From e131eb8c7bf99430a5ca774255cd1f3a18600337 Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Wed, 24 Apr 2024 00:52:05 +0200 Subject: [PATCH] correct small issues that make every model work on big images --- .../ai/nets/samj/models/AbstractSamJ.java | 35 +++++++++++++------ .../ai/nets/samj/models/EfficientSamJ.java | 3 +- .../ai/nets/samj/models/EfficientViTSamJ.java | 2 +- 3 files changed, 27 insertions(+), 13 deletions(-) diff --git a/src/main/java/ai/nets/samj/models/AbstractSamJ.java b/src/main/java/ai/nets/samj/models/AbstractSamJ.java index a2f5066..f2fbd70 100644 --- a/src/main/java/ai/nets/samj/models/AbstractSamJ.java +++ b/src/main/java/ai/nets/samj/models/AbstractSamJ.java @@ -28,8 +28,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; - - +import java.util.stream.Collectors; import java.awt.Polygon; import java.awt.Rectangle; import java.io.IOException; @@ -39,12 +38,10 @@ import io.bioimage.modelrunner.apposed.appose.Service.Task; import io.bioimage.modelrunner.apposed.appose.Service.TaskStatus; import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray; -import io.bioimage.modelrunner.utils.CommonUtils; import net.imglib2.RandomAccessibleInterval; import net.imglib2.type.NativeType; import net.imglib2.type.numeric.RealType; import net.imglib2.util.Cast; -import net.imglib2.util.Util; import net.imglib2.view.Views; /** @@ -209,8 +206,10 @@ public void setDebugging(boolean newState) { void updateImage(RandomAccessibleInterval rai) throws IOException, RuntimeException, InterruptedException { setImageOfInterest(rai); if (img.dimensionsAsLongArray()[0] * img.dimensionsAsLongArray()[1] > MAX_ENCODED_AREA_RS * MAX_ENCODED_AREA_RS - || img.dimensionsAsLongArray()[0] > MAX_ENCODED_SIDE || img.dimensionsAsLongArray()[1] > MAX_ENCODED_SIDE) + || img.dimensionsAsLongArray()[0] > MAX_ENCODED_SIDE || img.dimensionsAsLongArray()[1] > MAX_ENCODED_SIDE) { + this.targetDims = new long[] {0, 0, 0}; return; + } 
this.script = ""; sendImgLib2AsNp(); createEncodeImageScript(); @@ -275,10 +274,12 @@ private & NativeType> void sendCropAsNp(long[] cropSiz cropSize = new long[] {encodeCoords[2] - encodeCoords[0], encodeCoords[3] - encodeCoords[1], 3}; else if (cropSize.length == 2) cropSize = new long[] {cropSize[0], cropSize[1], 3}; + else if (cropSize.length == 3 && cropSize[2] != 3) + throw new IllegalArgumentException("The size of the area that wants to be encoded needs to be defined as [width, height]."); else throw new IllegalArgumentException("The size of the area that wants to be encoded needs to be defined as [width, height]."); RandomAccessibleInterval crop = - Views.interval( Cast.unchecked(img), new long[] {encodeCoords[1], encodeCoords[0], 0}, cropSize ); + Views.offsetInterval( Cast.unchecked(img), new long[] {encodeCoords[0], encodeCoords[1], 0}, cropSize ); //RandomAccessibleInterval crop = Views.offsetInterval(crop, new long[] {encodeCoords[1], encodeCoords[0], 0}, interValSize); //RandomAccessibleInterval crop = Views.offsetInterval(Cast.unchecked(img), new long[] {encodeCoords[1], encodeCoords[0], 0}, cropSize); @@ -471,6 +472,8 @@ public List processPoints(List pointsList, List pointsNeg + "contain all the points. 
Point {x=" + outsideP.get(0)[0] + ", y=" + outsideP.get(0)[1] + "} is out of the region."); } evaluateReencodingNeeded(pointsList, pointsNegList, encodingArea); + pointsList = adaptPointPrompts(pointsList); + pointsNegList = adaptPointPrompts(pointsNegList); this.script = ""; processPointsWithSAM(pointsList.size(), pointsNegList.size(), returnAll); HashMap inputs = new HashMap(); @@ -483,6 +486,16 @@ public List processPoints(List pointsList, List pointsNeg return polys; } + private List adaptPointPrompts(List pointsList) { + pointsList = pointsList.stream().map(pp -> { + int[] newPoint = new int[2]; + newPoint[0] = (int) (pp[0] - this.encodeCoords[0]); + newPoint[1] = (int) (pp[1] - this.encodeCoords[1]); + return newPoint; + }).collect(Collectors.toList()); + return pointsList; + } + /** * Method used that runs EfficientSAM using a bounding box as the prompt. The bounding box should * be a int array of length 4 of the form [x0, y0, x1, y1]. @@ -728,12 +741,12 @@ private Rectangle getApproximateAreaNeeded(List pointsList, List p } minX = (int) Math.max(0, minX - Math.max(focusedArea.width * 0.1, ENCODE_MARGIN)); minY = (int) Math.max(0, minY - Math.max(focusedArea.height * 0.1, ENCODE_MARGIN)); - maxX = (int) Math.min(img.dimensionsAsLongArray()[1], maxX + Math.max(focusedArea.width * 0.1, ENCODE_MARGIN)); - maxY = (int) Math.min(img.dimensionsAsLongArray()[0], maxY + Math.max(focusedArea.height * 0.1, ENCODE_MARGIN)); + maxX = (int) Math.min(img.dimensionsAsLongArray()[0], maxX + Math.max(focusedArea.width * 0.1, ENCODE_MARGIN)); + maxY = (int) Math.min(img.dimensionsAsLongArray()[1], maxY + Math.max(focusedArea.height * 0.1, ENCODE_MARGIN)); Rectangle rect = new Rectangle(); rect.x = minX; rect.y = minY; - rect.width = maxX - minY; + rect.width = maxX - minX; rect.height = maxY - minY; return rect; } @@ -821,8 +834,8 @@ protected long[] calculateEncodingNewCoords(int[] boundingBox, long[] imageSize) long[] posWrtBbox = new long[4]; posWrtBbox[0] = (long) 
Math.max(0, Math.ceil((boundingBox[0] + xSize / 2) - newSize[0] / 2)); posWrtBbox[1] = (long) Math.max(0, Math.ceil((boundingBox[1] + ySize / 2) - newSize[1] / 2)); - posWrtBbox[2] = (long) Math.min(imageSize[1], Math.floor((boundingBox[2] + xSize / 2) + newSize[0] / 2)); - posWrtBbox[3] = (long) Math.min(imageSize[0], Math.floor((boundingBox[3] + ySize / 2) + newSize[1] / 2)); + posWrtBbox[2] = (long) Math.min(imageSize[0], Math.floor((boundingBox[2] + xSize / 2) + newSize[0] / 2)); + posWrtBbox[3] = (long) Math.min(imageSize[1], Math.floor((boundingBox[3] + ySize / 2) + newSize[1] / 2)); return posWrtBbox; } diff --git a/src/main/java/ai/nets/samj/models/EfficientSamJ.java b/src/main/java/ai/nets/samj/models/EfficientSamJ.java index 1aecef3..ae7c5ca 100644 --- a/src/main/java/ai/nets/samj/models/EfficientSamJ.java +++ b/src/main/java/ai/nets/samj/models/EfficientSamJ.java @@ -213,6 +213,7 @@ protected void createEncodeImageScript() { code += ll + ", "; code = code.substring(0, code.length() - 2); code += "])" + System.lineSeparator(); + //code += "np.save('/home/carlos/git/crop.npy', im)" + System.lineSeparator(); code += "input_h = im.shape[1]" + System.lineSeparator(); code += "input_w = im.shape[0]" + System.lineSeparator(); code += "globals()['input_h'] = input_h" + System.lineSeparator(); @@ -360,7 +361,7 @@ void adaptImageToModel(final RandomAccessibleInterval ogImg, RandomAccessible } else if (ogImg.numDimensions() == 3 && ogImg.dimensionsAsLongArray()[2] == 1) { debugPrinter.printText("CONVERTED 1 CHANNEL IMAGE INTO 3 TO BE FEEDED TO SAMJ"); IntervalView resIm = Views.interval( Views.expandMirrorDouble(ImgLib2Utils.normalizedView(ogImg, this.debugPrinter), new long[] {0, 0, 2}), - Intervals.createMinMax(new long[] {0, 0, 0, ogImg.dimensionsAsLongArray()[0], ogImg.dimensionsAsLongArray()[1], 2}) ); + Intervals.createMinMax(new long[] {0, 0, 0, ogImg.dimensionsAsLongArray()[0] - 1, ogImg.dimensionsAsLongArray()[1] - 1, 2}) ); RealTypeConverters.copyFromTo( 
resIm, targetImg ); } else if (ogImg.numDimensions() == 2) { adaptImageToModel(Views.addDimension(ogImg, 0, 0), targetImg); diff --git a/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java b/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java index 5bb9183..611f8a4 100644 --- a/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java +++ b/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java @@ -457,7 +457,7 @@ void adaptImageToModel(RandomAccessibleInterval ogImg, RandomAccessibleInterv debugPrinter.printText("CONVERTED 1 CHANNEL IMAGE INTO 3 TO BE FEEDED TO SAMJ"); IntervalView resIm = Views.interval( Views.expandMirrorDouble(ImgLib2Utils.convertViewToRGB(ogImg, this.debugPrinter), new long[] {0, 0, 2}), - Intervals.createMinMax(new long[] {0, 0, 0, ogImg.dimensionsAsLongArray()[0], ogImg.dimensionsAsLongArray()[1], 2}) ); + Intervals.createMinMax(new long[] {0, 0, 0, ogImg.dimensionsAsLongArray()[0] - 1, ogImg.dimensionsAsLongArray()[1] - 1, 2}) ); RealTypeConverters.copyFromTo( resIm, targetImg ); } else if (ogImg.numDimensions() == 2) { adaptImageToModel(Views.addDimension(ogImg, 0, 0), targetImg);