From 40e55411aa307c432b1657e72f3438a42b73c6ce Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Fri, 19 Apr 2024 14:28:10 +0200 Subject: [PATCH] first try to see if it works --- src/main/java/ai/nets/samj/AbstractSamJ.java | 22 +++++ src/main/java/ai/nets/samj/EfficientSamJ.java | 82 ++++++++++++++++--- 2 files changed, 93 insertions(+), 11 deletions(-) diff --git a/src/main/java/ai/nets/samj/AbstractSamJ.java b/src/main/java/ai/nets/samj/AbstractSamJ.java index cb9b6a2..a6ec747 100644 --- a/src/main/java/ai/nets/samj/AbstractSamJ.java +++ b/src/main/java/ai/nets/samj/AbstractSamJ.java @@ -32,7 +32,10 @@ import net.imglib2.util.Util; import net.imglib2.view.Views; +import java.awt.Polygon; import java.time.LocalDateTime; +import java.util.Arrays; +import java.util.List; /** * Class that contains methods that can be sued by SAMJ models @@ -270,4 +273,23 @@ protected long[] calculateEncodingNewCoords(int[] boundingBox, long[] imageSize) posWrtBbox[3] = boundingBox[3] + (long) Math.floor(newSize[1] / 2); return posWrtBbox; } + + /** + * Method that recalculates the coordinates of the polygons outputed by SAMJ. + * + * This method is usually for big images. In order to create encoding with enough resolution + * to detect small objects compared to the size of the whole image, SAMJ might encode crops of + * the total image, thus the coordinates of the polygons obtained need to be shifted in order + * to match the original image. 
+ * @param polys + * polys obtained by SAMJ on the encoded crop + * @param encodeCoords + * position of the crop in the total image + */ + protected void recalculatePolys(List polys, long[] encodeCoords) { + polys.stream().forEach(pp -> { + pp.xpoints = Arrays.stream(pp.xpoints).map(x -> x + (int) encodeCoords[0]).toArray(); + pp.ypoints = Arrays.stream(pp.ypoints).map(y -> y + (int) encodeCoords[1]).toArray(); + }); + } } diff --git a/src/main/java/ai/nets/samj/EfficientSamJ.java b/src/main/java/ai/nets/samj/EfficientSamJ.java index 2ce01a3..5d653c4 100644 --- a/src/main/java/ai/nets/samj/EfficientSamJ.java +++ b/src/main/java/ai/nets/samj/EfficientSamJ.java @@ -42,6 +42,7 @@ import net.imglib2.type.numeric.RealType; import net.imglib2.type.numeric.integer.UnsignedByteType; import net.imglib2.type.numeric.real.FloatType; +import net.imglib2.util.Cast; import net.imglib2.util.Intervals; import net.imglib2.view.IntervalView; import net.imglib2.view.Views; @@ -80,7 +81,7 @@ public class EfficientSamJ extends AbstractSamJ implements AutoCloseable { * Usually the vertex is at 0,0 and the encoded image is all the pixels. This feature is useful for when the image * is big and reeconding needs to happen while the user pans and zooms in the image. */ - private long[] encodeCoords; + private long[] encodeCoords = new long[] {0, 0, 0}; /** * Scale factor of x and y applied to the image that is going to be annotated. * The image of interest does not need to be encoded normally. 
However, it is optimal to scale big images @@ -116,11 +117,6 @@ public class EfficientSamJ extends AbstractSamJ implements AutoCloseable { + "globals()['np'] = np" + System.lineSeparator() + "globals()['torch'] = torch" + System.lineSeparator() + "globals()['predictor'] = predictor" + System.lineSeparator(); - /** - * String containing the Python imports code after it has been formatted with the correct - * paths and names - */ - private String IMPORTS_FORMATED; /** * Create an instance of the class to be able to run EfficientSAM in Java. @@ -162,7 +158,7 @@ private EfficientSamJ(SamEnvManager manager, }; python = env.python(); python.debug(debugPrinter::printText); - IMPORTS_FORMATED = String.format(IMPORTS, + String IMPORTS_FORMATED = String.format(IMPORTS, manager.getEfficientSamEnv() + File.separator + SamEnvManager.ESAM_NAME, manager.getEfficientSAMSmallWeightsPath()); printScript(IMPORTS_FORMATED + PythonMethods.TRACE_EDGES, "Edges tracing code"); @@ -301,6 +297,34 @@ else if (task.status == TaskStatus.CRASHED) } } + private void reencodeCrop() throws IOException, InterruptedException, RuntimeException { + this.script = ""; + sendCropAsNp(); + this.script += "" + + "task.update(str(im.shape))" + System.lineSeparator() + + "aa = predictor.get_image_embeddings(im[None, ...])"; + try { + printScript(script, "Creation of the cropped embeddings"); + Task task = python.task(script); + task.waitFor(); + if (task.status == TaskStatus.CANCELED) + throw new RuntimeException(); + else if (task.status == TaskStatus.FAILED) + throw new RuntimeException(); + else if (task.status == TaskStatus.CRASHED) + throw new RuntimeException(); + this.shma.close(); + } catch (IOException | InterruptedException | RuntimeException e) { + try { + this.shma.close(); + } catch (IOException e1) { + throw new IOException(e.toString() + System.lineSeparator() + e1.toString()); + } + throw e; + } + + } + private List processAndRetrieveContours(HashMap inputs) throws IOException, 
RuntimeException, InterruptedException { Map results = null; @@ -596,17 +620,22 @@ public List processBox(int[] boundingBox) */ public List processBox(int[] boundingBox, boolean returnAll) throws IOException, RuntimeException, InterruptedException { + int[] adaptedBoundingBox = boundingBox; if (needsMoreResolution(boundingBox)) { long[] cropPosWrtBbox = calculateEncodingNewCoords(boundingBox, this.img.dimensionsAsLongArray()); - reencodeCrop(cropPosWrtBbox); + this.encodeCoords = new long[] {boundingBox[0] + cropPosWrtBbox[0], boundingBox[1] + cropPosWrtBbox[1], + boundingBox[2] + cropPosWrtBbox[2], boundingBox[3] + cropPosWrtBbox[3]}; + adaptedBoundingBox = new int[] {(int) -cropPosWrtBbox[0], (int) -cropPosWrtBbox[1], + (int) (boundingBox[2] + cropPosWrtBbox[2]), (int) (boundingBox[3] + cropPosWrtBbox[3])}; + reencodeCrop(); } - boundingBox = recalculateBbox(boundingBox); this.script = ""; processBoxWithSAM(returnAll); HashMap inputs = new HashMap(); - inputs.put("input_box", boundingBox); + inputs.put("input_box", adaptedBoundingBox); printScript(script, "Rectangle inference"); List polys = processAndRetrieveContours(inputs); + recalculatePolys(polys, encodeCoords); debugPrinter.printText("processBox() obtained " + polys.size() + " polygons"); return polys; } @@ -660,7 +689,38 @@ void sendImgLib2AsNp(RandomAccessibleInterval targetImg) { // This line wants to recreate the original numpy array. 
Should look like: // input0_appose_shm = shared_memory.SharedMemory(name=input0) // input0 = np.ndarray(size, dtype="float64", buffer=input0_appose_shm.buf).reshape([64, 64]) - code += IMPORTS_FORMATED+"im_shm = shared_memory.SharedMemory(name='" + code += "im_shm = shared_memory.SharedMemory(name='" + + shma.getNameForPython() + "', size=" + shma.getSize() + + ")" + System.lineSeparator(); + int size = 1; + for (long l : targetDims) {size *= l;} + code += "im = np.ndarray(" + size + ", dtype='float32', buffer=im_shm.buf).reshape(["; + for (long ll : targetDims) + code += ll + ", "; + code = code.substring(0, code.length() - 2); + code += "])" + System.lineSeparator(); + code += "input_h = im.shape[0]" + System.lineSeparator(); + code += "input_w = im.shape[1]" + System.lineSeparator(); + code += "globals()['input_h'] = input_h" + System.lineSeparator(); + code += "globals()['input_w'] = input_w" + System.lineSeparator(); + code += "im = torch.from_numpy(np.transpose(im.astype('float32'), (2, 0, 1)))" + System.lineSeparator(); + code += "im_shm.unlink()" + System.lineSeparator(); + //code += "box_shm.close()" + System.lineSeparator(); + this.script += code; + } + + private & NativeType> + void sendCropAsNp() { + long[] intervalSize = new long[] {encodeCoords[3] - encodeCoords[1], encodeCoords[2] - encodeCoords[0], img.dimensionsAsLongArray()[2]}; + RandomAccessibleInterval crop = Views.interval( Cast.unchecked(img), new long[] {encodeCoords[1], encodeCoords[0], 0}, intervalSize ); + + shma = createEfficientSAMInputSHM(reescaleIfNeeded(crop)); + adaptImageToModel(crop, shma.getSharedRAI()); + String code = ""; + // This line wants to recreate the original numpy array. 
Should look like: + // input0_appose_shm = shared_memory.SharedMemory(name=input0) + // input0 = np.ndarray(size, dtype="float64", buffer=input0_appose_shm.buf).reshape([64, 64]) + code += "im_shm = shared_memory.SharedMemory(name='" + shma.getNameForPython() + "', size=" + shma.getSize() + ")" + System.lineSeparator(); int size = 1;