From f047094af11295e44cfc601db0b055eec1aa9c55 Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Mon, 22 Apr 2024 20:04:44 +0200 Subject: [PATCH] keep iterating for the point prompts --- src/main/java/ai/nets/samj/EfficientSamJ.java | 20 ++++++++---- .../java/ai/nets/samj/EfficientViTSamJ.java | 32 ++++++++++++------- 2 files changed, 35 insertions(+), 17 deletions(-) diff --git a/src/main/java/ai/nets/samj/EfficientSamJ.java b/src/main/java/ai/nets/samj/EfficientSamJ.java index 7d23c5b..6d38579 100644 --- a/src/main/java/ai/nets/samj/EfficientSamJ.java +++ b/src/main/java/ai/nets/samj/EfficientSamJ.java @@ -36,6 +36,7 @@ import io.bioimage.modelrunner.apposed.appose.Service.Task; import io.bioimage.modelrunner.apposed.appose.Service.TaskStatus; import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray; +import io.bioimage.modelrunner.utils.CommonUtils; import net.imglib2.RandomAccessibleInterval; import net.imglib2.converter.RealTypeConverters; import net.imglib2.img.array.ArrayImgs; @@ -277,6 +278,7 @@ void addImage(RandomAccessibleInterval rai) if (rai.dimensionsAsLongArray()[0] * rai.dimensionsAsLongArray()[1] > MAX_ENCODED_AREA_RS * MAX_ENCODED_AREA_RS || rai.dimensionsAsLongArray()[0] > MAX_ENCODED_SIDE || rai.dimensionsAsLongArray()[1] > MAX_ENCODED_SIDE) { this.targetDims = new long[] {0, 0, 0}; + this.img = rai; return; } this.script = ""; @@ -311,7 +313,7 @@ private void reencodeCrop() throws IOException, InterruptedException, RuntimeExc private void reencodeCrop(long[] cropSize) throws IOException, InterruptedException, RuntimeException { this.script = ""; - sendCropAsNp(); + sendCropAsNp(cropSize); this.script += "" + "task.update(str(im.shape))" + System.lineSeparator() + "aa = predictor.get_image_embeddings(im[None, ...])"; @@ -611,7 +613,8 @@ public Rectangle getCurrentlyEncodedArea() { int xMargin = (int) (targetDims[1] * 0.1); int yMargin = (int) (targetDims[0] * 0.1); Rectangle alreadyEncoded; - if (encodeCoords[0] != 0 || encodeCoords[1] != 0 || targetDims[1] != this.img.dimensionsAsLongArray()[1]) { + if (encodeCoords[0] != 0 || encodeCoords[1] != 0 || targetDims[1] != this.img.dimensionsAsLongArray()[1] + || targetDims[0] != this.img.dimensionsAsLongArray()[0]) { alreadyEncoded = new Rectangle((int) encodeCoords[0] + xMargin / 2, (int) encodeCoords[1] + yMargin / 2, (int) targetDims[1] - xMargin, (int) targetDims[0] - yMargin); } else { @@ -640,7 +643,8 @@ private void evaluateReencodingNeeded(List pointsList, List points Rectangle neededArea = getApproximateAreaNeeded(pointsList, pointsNegList, rect); ArrayList notInRect = getPointsNotInRect(pointsList, pointsNegList, rect); if (alreadyEncoded.x <= rect.x && alreadyEncoded.y <= rect.y - && alreadyEncoded.width <= rect.width && alreadyEncoded.height <= rect.width + && alreadyEncoded.width + alreadyEncoded.x >= rect.width + rect.x + && alreadyEncoded.height + alreadyEncoded.y >= rect.width + rect.y && alreadyEncoded.width * 0.9 < rect.width && alreadyEncoded.height * 0.9 < rect.height && notInRect.size() == 0) { return; @@ -648,12 +652,14 @@ private void evaluateReencodingNeeded(List pointsList, List points this.encodeCoords = new long[] {rect.x, rect.y}; this.reencodeCrop(new long[] {rect.width, rect.height}); } else if (alreadyEncoded.x <= rect.x && alreadyEncoded.y <= rect.y - && alreadyEncoded.width <= rect.width && alreadyEncoded.height <= rect.width + && alreadyEncoded.width + alreadyEncoded.x >= rect.width + rect.x + && alreadyEncoded.height + alreadyEncoded.y >= rect.width + rect.y && (alreadyEncoded.width * 0.9 > rect.width || alreadyEncoded.height * 0.9 > rect.height)) { this.encodeCoords = new long[] {rect.x, rect.y}; this.reencodeCrop(new long[] {rect.width, rect.height}); } else if (alreadyEncoded.x <= neededArea.x && alreadyEncoded.y <= neededArea.y - && alreadyEncoded.width <= neededArea.width && alreadyEncoded.height <= neededArea.width + && alreadyEncoded.width + alreadyEncoded.x >= neededArea.width + neededArea.x + && alreadyEncoded.height + alreadyEncoded.y >= neededArea.width + neededArea.y && alreadyEncoded.width * 0.9 < neededArea.width && alreadyEncoded.height * 0.9 < neededArea.height && notInRect.size() == 0) { return; @@ -711,6 +717,8 @@ private Rectangle getApproximateAreaNeeded(List pointsList, List p } minX = (int) Math.max(0, minX - Math.max(focusedArea.width * 0.1, ENCODE_MARGIN)); minY = (int) Math.max(0, minY - Math.max(focusedArea.height * 0.1, ENCODE_MARGIN)); + maxX = (int) Math.min(img.dimensionsAsLongArray()[1], maxX + Math.max(focusedArea.width * 0.1, ENCODE_MARGIN)); + maxY = (int) Math.min(img.dimensionsAsLongArray()[0], maxY + Math.max(focusedArea.height * 0.1, ENCODE_MARGIN)); Rectangle rect = new Rectangle(); rect.x = minX; rect.y = minY; @@ -953,7 +961,7 @@ private & NativeType> void sendCropAsNp(long[] cropSiz + ")" + System.lineSeparator(); int size = 1; for (long l : targetDims) {size *= l;} - code += "im = np.ndarray(" + size + ", dtype='float32', buffer=im_shm.buf).reshape(["; + code += "im = np.ndarray(" + size + ", dtype='" + CommonUtils.getDataType(Util.getTypeFromInterval(crop)) + "', buffer=im_shm.buf).reshape(["; for (long ll : targetDims) code += ll + ", "; code = code.substring(0, code.length() - 2); diff --git a/src/main/java/ai/nets/samj/EfficientViTSamJ.java b/src/main/java/ai/nets/samj/EfficientViTSamJ.java index 0b04d88..979bb3d 100644 --- a/src/main/java/ai/nets/samj/EfficientViTSamJ.java +++ b/src/main/java/ai/nets/samj/EfficientViTSamJ.java @@ -38,6 +38,7 @@ import io.bioimage.modelrunner.apposed.appose.Service.TaskStatus; import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray; +import io.bioimage.modelrunner.utils.CommonUtils; import net.imglib2.RandomAccessibleInterval; import net.imglib2.converter.RealTypeConverters; import net.imglib2.img.array.ArrayImgs; @@ -373,6 +374,7 @@ void addImage(RandomAccessibleInterval rai) if (rai.dimensionsAsLongArray()[0] * rai.dimensionsAsLongArray()[1] > MAX_ENCODED_AREA_RS * MAX_ENCODED_AREA_RS || rai.dimensionsAsLongArray()[0] > MAX_ENCODED_SIDE || rai.dimensionsAsLongArray()[1] > MAX_ENCODED_SIDE) { this.targetDims = new long[] {0, 0, 0}; + this.img = rai; return; } this.script = ""; @@ -407,10 +409,10 @@ private void reencodeCrop() throws IOException, InterruptedException, RuntimeExc private void reencodeCrop(long[] cropSize) throws IOException, InterruptedException, RuntimeException { this.script = ""; - sendCropAsNp(); + sendCropAsNp(cropSize); this.script += "" + "task.update(str(im.shape))" + System.lineSeparator() - + "aa = predictor.get_image_embeddings(im[None, ...])"; + + "predictor.set_image(im)"; try { printScript(script, "Creation of the cropped embeddings"); Task task = python.task(script); @@ -679,7 +681,8 @@ public Rectangle getCurrentlyEncodedArea() { int xMargin = (int) (targetDims[1] * 0.1); int yMargin = (int) (targetDims[0] * 0.1); Rectangle alreadyEncoded; - if (encodeCoords[0] != 0 || encodeCoords[1] != 0 || targetDims[1] != this.img.dimensionsAsLongArray()[1]) { + if (encodeCoords[0] != 0 || encodeCoords[1] != 0 || targetDims[1] != this.img.dimensionsAsLongArray()[1] + || targetDims[0] != this.img.dimensionsAsLongArray()[0]) { alreadyEncoded = new Rectangle((int) encodeCoords[0] + xMargin / 2, (int) encodeCoords[1] + yMargin / 2, (int) targetDims[1] - xMargin, (int) targetDims[0] - yMargin); } else { @@ -708,7 +711,8 @@ private void evaluateReencodingNeeded(List pointsList, List points Rectangle neededArea = getApproximateAreaNeeded(pointsList, pointsNegList, rect); ArrayList notInRect = getPointsNotInRect(pointsList, pointsNegList, rect); if (alreadyEncoded.x <= rect.x && alreadyEncoded.y <= rect.y - && alreadyEncoded.width <= rect.width && alreadyEncoded.height <= rect.width + && alreadyEncoded.width + alreadyEncoded.x >= rect.width + rect.x + && alreadyEncoded.height + alreadyEncoded.y >= rect.width + rect.y && alreadyEncoded.width * 0.9 < rect.width && alreadyEncoded.height * 0.9 < rect.height && notInRect.size() == 0) { return; @@ -716,12 +720,14 @@ private void evaluateReencodingNeeded(List pointsList, List points this.encodeCoords = new long[] {rect.x, rect.y}; this.reencodeCrop(new long[] {rect.width, rect.height}); } else if (alreadyEncoded.x <= rect.x && alreadyEncoded.y <= rect.y - && alreadyEncoded.width <= rect.width && alreadyEncoded.height <= rect.width + && alreadyEncoded.width + alreadyEncoded.x >= rect.width + rect.x + && alreadyEncoded.height + alreadyEncoded.y >= rect.width + rect.y && (alreadyEncoded.width * 0.9 > rect.width || alreadyEncoded.height * 0.9 > rect.height)) { this.encodeCoords = new long[] {rect.x, rect.y}; this.reencodeCrop(new long[] {rect.width, rect.height}); } else if (alreadyEncoded.x <= neededArea.x && alreadyEncoded.y <= neededArea.y - && alreadyEncoded.width <= neededArea.width && alreadyEncoded.height <= neededArea.width + && alreadyEncoded.width + alreadyEncoded.x >= neededArea.width + neededArea.x + && alreadyEncoded.height + alreadyEncoded.y >= neededArea.width + neededArea.y && alreadyEncoded.width * 0.9 < neededArea.width && alreadyEncoded.height * 0.9 < neededArea.height && notInRect.size() == 0) { return; @@ -751,6 +757,8 @@ private Rectangle getApproximateAreaNeeded(List pointsList, List p } minX = (int) Math.max(0, minX - Math.max(focusedArea.width * 0.1, ENCODE_MARGIN)); minY = (int) Math.max(0, minY - Math.max(focusedArea.height * 0.1, ENCODE_MARGIN)); + maxX = (int) Math.min(img.dimensionsAsLongArray()[1], maxX + Math.max(focusedArea.width * 0.1, ENCODE_MARGIN)); + maxY = (int) Math.min(img.dimensionsAsLongArray()[0], maxY + Math.max(focusedArea.height * 0.1, ENCODE_MARGIN)); Rectangle rect = new Rectangle(); rect.x = minX; rect.y = minY; @@ -910,11 +918,13 @@ void sendCropAsNp() { private & NativeType> void sendCropAsNp(long[] cropSize) { if (cropSize == null) cropSize = new long[] {encodeCoords[3] - encodeCoords[1], encodeCoords[2] - encodeCoords[0], 3}; - //RandomAccessibleInterval crop = - // Views.interval( Cast.unchecked(img), new long[] {encodeCoords[1], encodeCoords[0], 0}, interValSize ); + else if (cropSize.length == 2) + cropSize = new long[] {cropSize[1], cropSize[0], 3}; + RandomAccessibleInterval crop = + Views.interval( Cast.unchecked(img), new long[] {encodeCoords[1], encodeCoords[0], 0}, cropSize ); - //RandomAccessibleInterval crop = Views.offsetInterval(crop, new long[] {encodeCoords[1], encodeCoords[0], 0}, interValSize); - RandomAccessibleInterval crop = Views.offsetInterval(Cast.unchecked(img), new long[] {encodeCoords[1], encodeCoords[0], 0}, cropSize); + crop = Views.offsetInterval(crop, new long[] {encodeCoords[1], encodeCoords[0], 0}, cropSize); + //RandomAccessibleInterval crop = Views.offsetInterval(Cast.unchecked(img), new long[] {encodeCoords[1], encodeCoords[0], 0}, cropSize); targetDims = crop.dimensionsAsLongArray(); shma = SharedMemoryArray.buildMemorySegmentForImage(new long[] {targetDims[0], targetDims[1], targetDims[2]}, Util.getTypeFromInterval(crop)); @@ -928,7 +938,7 @@ private & NativeType> void sendCropAsNp(long[] cropSiz + ")" + System.lineSeparator(); int size = 1; for (long l : targetDims) {size *= l;} - code += "im = np.ndarray(" + size + ", dtype='float32', buffer=im_shm.buf).reshape(["; + code += "im = np.ndarray(" + size + ", dtype='" + CommonUtils.getDataType(Util.getTypeFromInterval(crop)) + "', buffer=im_shm.buf).reshape(["; for (long ll : targetDims) code += ll + ", "; code = code.substring(0, code.length() - 2);