From e89d862452e2deef6e0f02dc70636b68c9f4b476 Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Tue, 23 Apr 2024 16:33:58 +0200 Subject: [PATCH] adapt efficientSAMJ to new modularity --- src/main/java/ai/nets/samj/EfficientSamJ.java | 775 +----------------- 1 file changed, 39 insertions(+), 736 deletions(-) diff --git a/src/main/java/ai/nets/samj/EfficientSamJ.java b/src/main/java/ai/nets/samj/EfficientSamJ.java index 6d38579..499a5e9 100644 --- a/src/main/java/ai/nets/samj/EfficientSamJ.java +++ b/src/main/java/ai/nets/samj/EfficientSamJ.java @@ -55,48 +55,7 @@ * @author Carlos Garcia * @author vladimir Ulman */ -public class EfficientSamJ extends AbstractSamJ implements AutoCloseable { - /** - * Instance referencing the Python environment that is going to be used to run EfficientSAM - */ - private final Environment env; - /** - * Instance of {@link Service} that is in charge of opening a Python process and running the - * scripts provided in that Python process in order to be able to use EfficientSAM - */ - private final Service python; - /** - * The scripts that want to be run in Python - */ - private String script = ""; - /** - * Shared memory array used to share between Java and Python the image that wants to be processed by EfficientSAM - */ - private SharedMemoryArray shma; - /** - * Target dimensions of the image that is going to be encoded. If a single-channel 2D image is provided, that image is - * converted into a 3-channel image that EfficientSAM requires - */ - private long[] targetDims; - /** - * Coordinates of the vertex of the crop/zoom of hte image of interest that has been encoded. - * It is the closest vertex to the origin. - * Usually the vertex is at 0,0 and the encoded image is all the pixels. This feature is useful for when the image - * is big and reeconding needs to happen while the user pans and zooms in the image. - */ - private long[] encodeCoords; - /** - * Scale factor of x and y applied to the image that is going to be annotated. - * The image of interest does not need to be encoded normally. However, it is optimal to scale big images - * as the resolution of the segmentation depends on the ratio between the size of the image and the size of - * the object, thus - */ - private double[] scale; - /** - * Complete image being annotated. Usually this image is encoded completely - * but for larger images, zooms of it might be encoded instead of the whole image - */ - private RandomAccessibleInterval img; +public class EfficientSamJ extends AbstractSamJ2 { /** * All the Python imports and configurations needed to start using EfficientSAM. */ @@ -205,7 +164,7 @@ else if (task.status == TaskStatus.CRASHED) try{ sam = new EfficientSamJ(manager, debugPrinter, printPythonCode); sam.encodeCoords = new long[] {0, 0}; - sam.addImage(image); + sam.updateImage(image); } catch (IOException | RuntimeException | InterruptedException ex) { if (sam != null) sam.close(); throw ex; @@ -237,7 +196,7 @@ else if (task.status == TaskStatus.CRASHED) try{ sam = new EfficientSamJ(manager); sam.encodeCoords = new long[] {0, 0}; - sam.addImage(image); + sam.updateImage(image); sam.img = image; } catch (IOException | RuntimeException | InterruptedException ex) { if (sam != null) sam.close(); @@ -245,209 +204,16 @@ else if (task.status == TaskStatus.CRASHED) } return sam; } - - /** - * Change the image encoded by the EfficientSAM model - * @param - * ImgLib2 data type of the image of interest - * @param rai - * image (n-dimensional array) that is going to be encoded as a {@link RandomAccessibleInterval} - * @throws IOException if any of the files to run a Python process is missing - * @throws RuntimeException if there is any error running the Python code - * @throws InterruptedException if the process is interrupted - */ - public & NativeType> - void updateImage(RandomAccessibleInterval rai) throws IOException, RuntimeException, InterruptedException { - addImage(rai); - this.img = rai; - } - - /** - * Encode an image (n-dimensional array) with an EfficientSAM model - * @param - * ImgLib2 data type of the image of interest - * @param rai - * image (n-dimensional array) that is going to be encoded as a {@link RandomAccessibleInterval} - * @throws IOException if any of the files to run a Python process is missing - * @throws RuntimeException if there is any error running the Python code - * @throws InterruptedException if the process is interrupted - */ - private & NativeType> - void addImage(RandomAccessibleInterval rai) - throws IOException, RuntimeException, InterruptedException { - if (rai.dimensionsAsLongArray()[0] * rai.dimensionsAsLongArray()[1] > MAX_ENCODED_AREA_RS * MAX_ENCODED_AREA_RS - || rai.dimensionsAsLongArray()[0] > MAX_ENCODED_SIDE || rai.dimensionsAsLongArray()[1] > MAX_ENCODED_SIDE) { - this.targetDims = new long[] {0, 0, 0}; - this.img = rai; - return; - } - this.script = ""; - sendImgLib2AsNp(rai); - this.script += "" - + "task.update(str(im.shape))" + System.lineSeparator() - + "aa = predictor.get_image_embeddings(im[None, ...])"; - try { - printScript(script, "Creation of initial embeddings"); - Task task = python.task(script); - task.waitFor(); - if (task.status == TaskStatus.CANCELED) - throw new RuntimeException(); - else if (task.status == TaskStatus.FAILED) - throw new RuntimeException(); - else if (task.status == TaskStatus.CRASHED) - throw new RuntimeException(); - this.shma.close(); - } catch (IOException | InterruptedException | RuntimeException e) { - try { - this.shma.close(); - } catch (IOException e1) { - throw new IOException(e.toString() + System.lineSeparator() + e1.toString()); - } - throw e; - } - } - - private void reencodeCrop() throws IOException, InterruptedException, RuntimeException { - reencodeCrop(null); - } - - private void reencodeCrop(long[] cropSize) throws IOException, InterruptedException, RuntimeException { - this.script = ""; - sendCropAsNp(cropSize); - this.script += "" + + @Override + protected void createEncodeImageScript() { + this.script = "" + "task.update(str(im.shape))" + System.lineSeparator() + "aa = predictor.get_image_embeddings(im[None, ...])"; - try { - printScript(script, "Creation of the cropped embeddings"); - Task task = python.task(script); - task.waitFor(); - if (task.status == TaskStatus.CANCELED) - throw new RuntimeException(); - else if (task.status == TaskStatus.FAILED) - throw new RuntimeException(); - else if (task.status == TaskStatus.CRASHED) - throw new RuntimeException(); - this.shma.close(); - } catch (IOException | InterruptedException | RuntimeException e) { - try { - this.shma.close(); - } catch (IOException e1) { - throw new IOException(e.toString() + System.lineSeparator() + e1.toString()); - } - throw e; - } - } - - private List processAndRetrieveContours(HashMap inputs) - throws IOException, RuntimeException, InterruptedException { - Map results = null; - try { - Task task = python.task(script, inputs); - task.waitFor(); - if (task.status == TaskStatus.CANCELED) - throw new RuntimeException(); - else if (task.status == TaskStatus.FAILED) - throw new RuntimeException(); - else if (task.status == TaskStatus.CRASHED) - throw new RuntimeException(); - else if (task.status != TaskStatus.COMPLETE) - throw new RuntimeException(); - else if (task.outputs.get("contours_x") == null) - throw new RuntimeException(); - else if (task.outputs.get("contours_y") == null) - throw new RuntimeException(); - results = task.outputs; - } catch (IOException | InterruptedException | RuntimeException e) { - try { - this.shma.close(); - } catch (IOException e1) { - throw new IOException(e.toString() + System.lineSeparator() + e1.toString()); - } - throw e; - } - final List> contours_x_container = (List>)results.get("contours_x"); - final Iterator> contours_x = contours_x_container.iterator(); - final Iterator> contours_y = ((List>)results.get("contours_y")).iterator(); - final List polys = new ArrayList<>(contours_x_container.size()); - while (contours_x.hasNext()) { - int[] xArr = contours_x.next().stream().mapToInt(Number::intValue).toArray(); - int[] yArr = contours_y.next().stream().mapToInt(Number::intValue).toArray(); - polys.add( new Polygon(xArr, yArr, xArr.length) ); - } - return polys; - } - - /** - * Method used that runs EfficientSAM using a mask as the prompt. The mask should be a 2D single-channel - * image {@link RandomAccessibleInterval} of the same x and y sizes as the image of interest, the image - * where the model is finding the segmentations. - * Note that the quality of this prompting method is not good, it is still experimental as it barely works. - * It returns a list of polygons that corresponds to the contours of the masks found by EfficientSAM. - * - * @param - * ImgLib2 datatype of the mask - * @param img - * mask used as the prompt - * @return a list of polygons where each polygon is the contour of a mask that has been found by EfficientSAM - * @throws IOException if any of the files needed to run the Python script is missing - * @throws RuntimeException if there is any error running the Python process - * @throws InterruptedException if the process in interrupted - */ - public & NativeType> - List processMask(RandomAccessibleInterval img) - throws IOException, RuntimeException, InterruptedException { - return processMask(img, true); - } - - /** - * Method used that runs EfficientSAM using a mask as the prompt. The mask should be a 2D single-channel - * image {@link RandomAccessibleInterval} of the same x and y sizes as the image of interest, the image - * where the model is finding the segmentations. - * Note that the quality of this prompting method is not good, it is still experimental as it barely works - * - * @param - * ImgLib2 datatype of the mask - * @param img - * mask used as the prompt - * @param returnAll - * whether to return all the polygons created by EfficientSAM of only the biggest - * @return a list of polygons where each polygon is the contour of a mask that has been found by EfficientSAM - * @throws IOException if any of the files needed to run the Python script is missing - * @throws RuntimeException if there is any error running the Python process - * @throws InterruptedException if the process in interrupted - */ - public & NativeType> - List processMask(RandomAccessibleInterval img, boolean returnAll) - throws IOException, RuntimeException, InterruptedException { - long[] dims = img.dimensionsAsLongArray(); - if (dims.length == 2 && dims[1] == this.shma.getOriginalShape()[1] && dims[0] == this.shma.getOriginalShape()[0]) { - img = Views.permute(img, 0, 1); - } else if (dims.length != 2 && dims[0] != this.shma.getOriginalShape()[1] && dims[1] != this.shma.getOriginalShape()[0]) { - throw new IllegalArgumentException("The provided mask should be a 2d image with just one channel of width " - + this.shma.getOriginalShape()[1] + " and height " + this.shma.getOriginalShape()[0]); - } - SharedMemoryArray maskShma = SharedMemoryArray.buildSHMA(img); - try { - return processMask(maskShma, returnAll); - } catch (IOException | RuntimeException | InterruptedException ex) { - maskShma.close(); - throw ex; - } - } - - private List processMask(SharedMemoryArray shmArr, boolean returnAll) - throws IOException, RuntimeException, InterruptedException { - this.script = ""; - processMasksWithSam(shmArr, returnAll); - printScript(script, "Pre-computed mask inference"); - List polys = processAndRetrieveContours(null); - debugPrinter.printText("processMask() obtained " + polys.size() + " polygons"); - return polys; - } - - private void processMasksWithSam(SharedMemoryArray shmArr, boolean returnAll) { + @Override + protected void processMasksWithSam(SharedMemoryArray shmArr, boolean returnAll) { String code = ""; code += "shm_mask = shared_memory.SharedMemory(name='" + shmArr.getNameForPython() + "')" + System.lineSeparator(); code += "mask = np.frombuffer(buffer=shm_mask.buf, dtype='" + shmArr.getOriginalDataType() + "').reshape(["; @@ -499,485 +265,9 @@ private void processMasksWithSam(SharedMemoryArray shmArr, boolean returnAll) { code += "shm_mask.unlink()" + System.lineSeparator(); this.script = code; } - - /** - * Method used that runs EfficientSAM using a list of points as the prompt. This method runs - * the prompt encoder and the EfficientSAM decoder only, the image encoder was run when the model - * was initialized with the image, thus it is quite fast. - * It returns a list of polygons that corresponds to the contours of the masks found by EfficientSAM - * @param pointsList - * the list of points that serve as a prompt for EfficientSAM. Each point is an int array - * of length 2, first position is x-axis, second y-axis - * @return a list of polygons where each polygon is the contour of a mask that has been found by EfficientSAM - * @throws IOException if any of the files needed to run the Python script is missing - * @throws RuntimeException if there is any error running the Python process - * @throws InterruptedException if the process in interrupted - */ - public List processPoints(List pointsList) - throws IOException, RuntimeException, InterruptedException{ - return processPoints(pointsList, true); - } - - /** - * Method used that runs EfficientSAM using a list of points as the prompt. This method runs - * the prompt encoder and the EfficientSAM decoder only, the image encoder was run when the model - * was initialized with the image, thus it is quite fast. - * It returns a list of polygons that corresponds to the contours of the masks found by EfficientSAM - * @param pointsList - * the list of points that serve as a prompt for EfficientSAM. Each point is an int array - * of length 2, first position is x-axis, second y-axis - * @param returnAll - * whether to return all the polygons created by EfficientSAM of only the biggest - * @return a list of polygons where each polygon is the contour of a mask that has been found by EfficientSAM - * @throws IOException if any of the files needed to run the Python script is missing - * @throws RuntimeException if there is any error running the Python process - * @throws InterruptedException if the process in interrupted - */ - public List processPoints(List pointsList, boolean returnAll) - throws IOException, RuntimeException, InterruptedException{ - Rectangle rect = new Rectangle(); - rect.x = -1; - rect.y = -1; - rect.height = -1; - rect.width = -1; - return processPoints(pointsList, rect, returnAll); - } - - public List processPoints(List pointsList, Rectangle encodingArea, boolean returnAll) - throws IOException, RuntimeException, InterruptedException { - Objects.requireNonNull(encodingArea, "Second argument cannot be null. Use the method " - + "'processPoints(List pointsList, Rectangle zoomedArea, boolean returnAll)'" - + " instead"); - return processPoints(pointsList, new ArrayList(), encodingArea, returnAll); - } - - /** - * Method used that runs EfficientSAM using a list of points as the prompt. This method also accepts another - * list of points as the negative prompt, the points that represent the background class wrt the object of interest. This method runs - * the prompt encoder and the EfficientSAM decoder only, the image encoder was run when the model - * was initialized with the image, thus it is quite fast. - * It returns a list of polygons that corresponds to the contours of the masks found by EfficientSAM - * @param pointsList - * the list of points that serve as a prompt for EfficientSAM. Each point is an int array - * of length 2, first position is x-axis, second y-axis - * @param pointsNegList - * the list of points that does not point to the instance of interest, but the background - * @param returnAll - * whether to return all the polygons created by EfficientSAM of only the biggest - * @return a list of polygons where each polygon is the contour of a mask that has been found by EfficientSAM - * @throws IOException if any of the files needed to run the Python script is missing - * @throws RuntimeException if there is any error running the Python process - * @throws InterruptedException if the process in interrupted - */ - public List processPoints(List pointsList, List pointsNegList, - Rectangle encodingArea, boolean returnAll) - throws IOException, RuntimeException, InterruptedException { - Objects.requireNonNull(encodingArea, "Third argument cannot be null. Use the method " - + "'processPoints(List pointsList, List pointsNegList, Rectangle zoomedArea, boolean returnAll)'" - + " instead"); - - if (encodingArea.x == -1) { - encodingArea = getCurrentlyEncodedArea(); - } else { - ArrayList outsideP = getPointsNotInRect(pointsList, pointsNegList, encodingArea); - if (outsideP.size() != 0) - throw new IllegalArgumentException("The Rectangle containing the area to be encoded should " - + "contain all the points. Point {x=" + outsideP.get(0)[0] + ", y=" + outsideP.get(0)[1] + "} is out of the region."); - } - evaluateReencodingNeeded(pointsList, pointsNegList, encodingArea); - this.script = ""; - processPointsWithSAM(pointsList.size(), pointsNegList.size(), returnAll); - HashMap inputs = new HashMap(); - inputs.put("input_points", pointsList); - inputs.put("input_neg_points", pointsNegList); - printScript(script, "Points and negative points inference"); - List polys = processAndRetrieveContours(inputs); - recalculatePolys(polys, encodeCoords); - debugPrinter.printText("processPoints() obtained " + polys.size() + " polygons"); - return polys; - } - - private ArrayList getPointsNotInRect(List pointsList, List pointsNegList, Rectangle encodingArea) { - ArrayList points = new ArrayList(); - ArrayList not = new ArrayList(); - points.addAll(pointsNegList); - points.addAll(pointsList); - for (int[] pp : points) { - if (!encodingArea.contains(pp[0], pp[1])) - not.add(pp); - } - return not; - } - - public Rectangle getCurrentlyEncodedArea() { - int xMargin = (int) (targetDims[1] * 0.1); - int yMargin = (int) (targetDims[0] * 0.1); - Rectangle alreadyEncoded; - if (encodeCoords[0] != 0 || encodeCoords[1] != 0 || targetDims[1] != this.img.dimensionsAsLongArray()[1] - || targetDims[0] != this.img.dimensionsAsLongArray()[0]) { - alreadyEncoded = new Rectangle((int) encodeCoords[0] + xMargin / 2, (int) encodeCoords[1] + yMargin / 2, - (int) targetDims[1] - xMargin, (int) targetDims[0] - yMargin); - } else { - alreadyEncoded = new Rectangle((int) encodeCoords[0], (int) encodeCoords[1], - (int) targetDims[1], (int) targetDims[0]); - } - return alreadyEncoded; - } - - /** - * TODO Explain reencoding logic - * TODO Explain reencoding logic - * TODO Explain reencoding logic - * TODO Explain reencoding logic - * TODO Explain reencoding logic - * @param pointsList - * @param pointsNegList - * @param rect - * @throws IOException - * @throws InterruptedException - * @throws RuntimeException - */ - private void evaluateReencodingNeeded(List pointsList, List pointsNegList, Rectangle rect) - throws IOException, InterruptedException, RuntimeException { - Rectangle alreadyEncoded = getCurrentlyEncodedArea(); - Rectangle neededArea = getApproximateAreaNeeded(pointsList, pointsNegList, rect); - ArrayList notInRect = getPointsNotInRect(pointsList, pointsNegList, rect); - if (alreadyEncoded.x <= rect.x && alreadyEncoded.y <= rect.y - && alreadyEncoded.width + alreadyEncoded.x >= rect.width + rect.x - && alreadyEncoded.height + alreadyEncoded.y >= rect.width + rect.y - && alreadyEncoded.width * 0.9 < rect.width && alreadyEncoded.height * 0.9 < rect.height - && notInRect.size() == 0) { - return; - } else if (notInRect.size() != 0) { - this.encodeCoords = new long[] {rect.x, rect.y}; - this.reencodeCrop(new long[] {rect.width, rect.height}); - } else if (alreadyEncoded.x <= rect.x && alreadyEncoded.y <= rect.y - && alreadyEncoded.width + alreadyEncoded.x >= rect.width + rect.x - && alreadyEncoded.height + alreadyEncoded.y >= rect.width + rect.y - && (alreadyEncoded.width * 0.9 > rect.width || alreadyEncoded.height * 0.9 > rect.height)) { - this.encodeCoords = new long[] {rect.x, rect.y}; - this.reencodeCrop(new long[] {rect.width, rect.height}); - } else if (alreadyEncoded.x <= neededArea.x && alreadyEncoded.y <= neededArea.y - && alreadyEncoded.width + alreadyEncoded.x >= neededArea.width + neededArea.x - && alreadyEncoded.height + alreadyEncoded.y >= neededArea.width + neededArea.y - && alreadyEncoded.width * 0.9 < neededArea.width && alreadyEncoded.height * 0.9 < neededArea.height - && notInRect.size() == 0) { - return; - } else { - this.encodeCoords = new long[] {rect.x, rect.y}; - this.reencodeCrop(new long[] {rect.width, rect.height}); - } - } - - private Rectangle getApproximateAreaNeeded(List pointsList, List pointsNegList) { - ArrayList points = new ArrayList(); - points.addAll(pointsNegList); - points.addAll(pointsList); - int minY = Integer.MAX_VALUE; - int minX = Integer.MAX_VALUE; - int maxY = 0; - int maxX = 0; - for (int[] pp : points) { - if (pp[0] < minX) - minX = pp[0]; - if (pp[0] > maxX) - maxX = pp[0]; - if (pp[1] < minY) - minY = pp[1]; - if (pp[1] > maxY) - maxY = pp[1]; - } - minX = (int) Math.max(0, minX - Math.max((maxX - minX) * 0.1, ENCODE_MARGIN)); - minY = (int) Math.max(0, minY - Math.max((maxY - minY) * 0.1, ENCODE_MARGIN)); - Rectangle rect = new Rectangle(); - rect.x = minX; - rect.y = minY; - rect.width = maxX - minY; - rect.height = maxY - minY; - return rect; - } - - private Rectangle getApproximateAreaNeeded(List pointsList, List pointsNegList, Rectangle focusedArea) { - ArrayList points = new ArrayList(); - points.addAll(pointsNegList); - points.addAll(pointsList); - int minY = Integer.MAX_VALUE; - int minX = Integer.MAX_VALUE; - int maxY = 0; - int maxX = 0; - for (int[] pp : points) { - if (pp[0] < minX) - minX = pp[0]; - if (pp[0] > maxX) - maxX = pp[0]; - if (pp[1] < minY) - minY = pp[1]; - if (pp[1] > maxY) - maxY = pp[1]; - } - minX = (int) Math.max(0, minX - Math.max(focusedArea.width * 0.1, ENCODE_MARGIN)); - minY = (int) Math.max(0, minY - Math.max(focusedArea.height * 0.1, ENCODE_MARGIN)); - maxX = (int) Math.min(img.dimensionsAsLongArray()[1], maxX + Math.max(focusedArea.width * 0.1, ENCODE_MARGIN)); - maxY = (int) Math.min(img.dimensionsAsLongArray()[0], maxY + Math.max(focusedArea.height * 0.1, ENCODE_MARGIN)); - Rectangle rect = new Rectangle(); - rect.x = minX; - rect.y = minY; - rect.width = maxX - minY; - rect.height = maxY - minY; - return rect; - } - - /** - * Method used that runs EfficientSAM using a list of points as the prompt. This method also accepts another - * list of points as the negative prompt, the points that represent the background class wrt the object of interest. This method runs - * the prompt encoder and the EfficientSAM decoder only, the image encoder was run when the model - * was initialized with the image, thus it is quite fast. - * It returns a list of polygons that corresponds to the contours of the masks found by EfficientSAM - * @param pointsList - * the list of points that serve as a prompt for EfficientSAM. Each point is an int array - * of length 2, first position is x-axis, second y-axis - * @param pointsNegList - * the list of points that does not point to the instance of interest, but the background - * @return a list of polygons where each polygon is the contour of a mask that has been found by EfficientSAM - * @throws IOException if any of the files needed to run the Python script is missing - * @throws RuntimeException if there is any error running the Python process - * @throws InterruptedException if the process in interrupted - */ - public List processPoints(List pointsList, List pointsNegList) - throws IOException, RuntimeException, InterruptedException { - Rectangle rect = new Rectangle(); - rect.x = (int) this.encodeCoords[0]; - rect.y = (int) this.encodeCoords[1]; - rect.height = (int) this.targetDims[0]; - rect.width = (int) this.targetDims[1]; - return processPoints(pointsList, pointsNegList, rect, true); - } - - public List processPoints(List pointsList, List pointsNegList, - Rectangle zoomedArea) - throws IOException, RuntimeException, InterruptedException { - return processPoints(pointsList, pointsNegList, zoomedArea, true); - } - - /** - * Method used that runs EfficientSAM using a list of points as the prompt. This method also accepts another - * list of points as the negative prompt, the points that represent the background class wrt the object of interest. This method runs - * the prompt encoder and the EfficientSAM decoder only, the image encoder was run when the model - * was initialized with the image, thus it is quite fast. - * It returns a list of polygons that corresponds to the contours of the masks found by EfficientSAM - * @param pointsList - * the list of points that serve as a prompt for EfficientSAM. Each point is an int array - * of length 2, first position is x-axis, second y-axis - * @param pointsNegList - * the list of points that does not point to the instance of interest, but the background - * @param returnAll - * whether to return all the polygons created by EfficientSAM of only the biggest - * @return a list of polygons where each polygon is the contour of a mask that has been found by EfficientSAM - * @throws IOException if any of the files needed to run the Python script is missing - * @throws RuntimeException if there is any error running the Python process - * @throws InterruptedException if the process in interrupted - */ - public List processPoints(List pointsList, List pointsNegList, boolean returnAll) - throws IOException, RuntimeException, InterruptedException { - Rectangle rect = new Rectangle(); - rect.x = (int) this.encodeCoords[0]; - rect.y = (int) this.encodeCoords[1]; - rect.height = (int) this.targetDims[0]; - rect.width = (int) this.targetDims[1]; - return processPoints(pointsList, pointsNegList, rect, returnAll); - } - - /** - * Method used that runs EfficientSAM using a bounding box as the prompt. The bounding box should - * be a int array of length 4 of the form [x0, y0, x1, y1]. - * This method runs the prompt encoder and the EfficientSAM decoder only, the image encoder was run when the model - * was initialized with the image, thus it is quite fast. - * - * Returns a list of all the polygons found by EfficientSAM - * - * @param boundingBox - * the bounding box that serves as the prompt for EfficientSAM - * @return a list of polygons where each polygon is the contour of a mask that has been found by EfficientSAM - * @throws IOException if any of the files needed to run the Python script is missing - * @throws RuntimeException if there is any error running the Python process - * @throws InterruptedException if the process in interrupted - */ - public List processBox(int[] boundingBox) - throws IOException, RuntimeException, InterruptedException { - return processBox(boundingBox, true); - } - - /** - * Method used that runs EfficientSAM using a bounding box as the prompt. The bounding box should - * be a int array of length 4 of the form [x0, y0, x1, y1]. - * This method runs the prompt encoder and the EfficientSAM decoder only, the image encoder was run when the model - * was initialized with the image, thus it is quite fast. - * - * @param boundingBox - * the bounding box that serves as the prompt for EfficientSAM - * @param returnAll - * whether to return all the polygons created by EfficientSAM of only the biggest - * @return a list of polygons where each polygon is the contour of a mask that has been found by EfficientSAM - * @throws IOException if any of the files needed to run the Python script is missing - * @throws RuntimeException if there is any error running the Python process - * @throws InterruptedException if the process in interrupted - */ - public List processBox(int[] boundingBox, boolean returnAll) - throws IOException, RuntimeException, InterruptedException { - if (needsMoreResolution(boundingBox)) { - this.encodeCoords = calculateEncodingNewCoords(boundingBox, this.img.dimensionsAsLongArray()); - reencodeCrop(); - } else if (!isAreaEncoded(boundingBox)) { - this.encodeCoords = calculateEncodingNewCoords(boundingBox, this.img.dimensionsAsLongArray()); - reencodeCrop(); - } - int[] adaptedBoundingBox = new int[] {(int) (boundingBox[0] - encodeCoords[0]), (int) (boundingBox[1] - encodeCoords[1]), - (int) (boundingBox[2] - encodeCoords[0]), (int) (boundingBox[3] - encodeCoords[1])};; - this.script = ""; - processBoxWithSAM(returnAll); - HashMap inputs = new HashMap(); - inputs.put("input_box", adaptedBoundingBox); - printScript(script, "Rectangle inference"); - List polys = processAndRetrieveContours(inputs); - recalculatePolys(polys, encodeCoords); - debugPrinter.printText("processBox() obtained " + polys.size() + " polygons"); - return polys; - } - - /** - * Check whether the bounding box is inside the area that is encoded or not - * @param boundingBox - * the vertices of the bounding box - * @return whether the bounding box is within the encoded area or not - */ - public boolean isAreaEncoded(int[] boundingBox) { - boolean upperLeftVertex = (boundingBox[0] > this.encodeCoords[0]) && (boundingBox[0] < this.encodeCoords[2]); - boolean upperRightVertex = (boundingBox[2] > this.encodeCoords[0]) && (boundingBox[2] < this.encodeCoords[2]); - boolean downLeftVertex = (boundingBox[1] > this.encodeCoords[1]) && (boundingBox[1] < this.encodeCoords[3]); - boolean downRightVertex = (boundingBox[3] > this.encodeCoords[1]) && (boundingBox[3] < this.encodeCoords[3]); - - if (upperLeftVertex && upperRightVertex && downLeftVertex && downRightVertex) - return true; - return false; - } - - /** - * For bounding box masks, check whether the its size is too small compared to the size - * of the encoded image. - * - * Approximately, if the original image encoded is about 20 times bigger than the bounding box size, - * the resolution of the SAM-based model encodings will not be enough to identify the object of interest, - * thus re-encoding of a zoomed part of the image will be necessary. - * - * @param boundingBox - * bounding box of interest - * @return whether the bounding box of interest is big enough to produce good results or not - */ - public boolean needsMoreResolution(int[] boundingBox) { - long xSize = boundingBox[2] - boundingBox[0]; - long ySize = boundingBox[3] - boundingBox[1]; - long encodedX = targetDims[1]; - long encodedY = targetDims[0]; - if (xSize * LOWER_REENCODE_THRESH < encodedX && ySize * LOWER_REENCODE_THRESH < encodedY) - return true; - return false; - } - - /** - * TODO what to do, is there a bounding box that is too big with respect to the encoded crop? - * @param boundingBox - * @return - */ - public boolean boundingBoxTooBig(int[] boundingBox) { - long xSize = boundingBox[2] - boundingBox[0]; - long ySize = boundingBox[3] - boundingBox[1]; - long encodedX = targetDims[1]; - long encodedY = targetDims[0]; - if (xSize * UPPER_REENCODE_THRESH > encodedX && ySize * UPPER_REENCODE_THRESH > encodedY) - return true; - return false; - } @Override - /** - * {@inheritDoc} - * Close the Python process and clean the memory - */ - public void close() { - if (python != null) python.close(); - } - - private & NativeType> - void sendImgLib2AsNp(RandomAccessibleInterval targetImg) { - shma = createEfficientSAMInputSHM(reescaleIfNeeded(targetImg)); - adaptImageToModel(targetImg, shma.getSharedRAI()); - String code = ""; - // This line wants to recreate the original numpy array. Should look like: - // input0_appose_shm = shared_memory.SharedMemory(name=input0) - // input0 = np.ndarray(size, dtype="float64", buffer=input0_appose_shm.buf).reshape([64, 64]) - code += "im_shm = shared_memory.SharedMemory(name='" - + shma.getNameForPython() + "', size=" + shma.getSize() - + ")" + System.lineSeparator(); - int size = 1; - for (long l : targetDims) {size *= l;} - code += "im = np.ndarray(" + size + ", dtype='float32', buffer=im_shm.buf).reshape(["; - for (long ll : targetDims) - code += ll + ", "; - code = code.substring(0, code.length() - 2); - code += "])" + System.lineSeparator(); - code += "input_h = im.shape[0]" + System.lineSeparator(); - code += "input_w = im.shape[1]" + System.lineSeparator(); - code += "globals()['input_h'] = input_h" + System.lineSeparator(); - code += "globals()['input_w'] = input_w" + System.lineSeparator(); - code += "im = torch.from_numpy(np.transpose(im.astype('float32'), (2, 0, 1)))" + System.lineSeparator(); - code += "im_shm.unlink()" + System.lineSeparator(); - //code += "box_shm.close()" + System.lineSeparator(); - this.script += code; - } - - private & NativeType> - void sendCropAsNp() { - sendCropAsNp(null); - } - - private & NativeType> void sendCropAsNp(long[] cropSize) { - if (cropSize == null) - cropSize = new long[] {encodeCoords[3] - encodeCoords[1], encodeCoords[2] - encodeCoords[0], 3}; - //RandomAccessibleInterval crop = - // Views.interval( Cast.unchecked(img), new long[] {encodeCoords[1], encodeCoords[0], 0}, interValSize ); - - //RandomAccessibleInterval crop = Views.offsetInterval(crop, new long[] {encodeCoords[1], encodeCoords[0], 0}, interValSize); - RandomAccessibleInterval crop = Views.offsetInterval(Cast.unchecked(img), new long[] {encodeCoords[1], encodeCoords[0], 0}, cropSize); - targetDims = crop.dimensionsAsLongArray(); - shma = SharedMemoryArray.buildMemorySegmentForImage(new long[] {targetDims[0], targetDims[1], targetDims[2]}, - Util.getTypeFromInterval(crop)); - RealTypeConverters.copyFromTo(crop, shma.getSharedRAI()); - String code = ""; - // This line wants to recreate the original numpy array. Should look like: - // input0_appose_shm = shared_memory.SharedMemory(name=input0) - // input0 = np.ndarray(size, dtype="float64", buffer=input0_appose_shm.buf).reshape([64, 64]) - code += "im_shm = shared_memory.SharedMemory(name='" - + shma.getNameForPython() + "', size=" + shma.getSize() - + ")" + System.lineSeparator(); - int size = 1; - for (long l : targetDims) {size *= l;} - code += "im = np.ndarray(" + size + ", dtype='" + CommonUtils.getDataType(Util.getTypeFromInterval(crop)) + "', buffer=im_shm.buf).reshape(["; - for (long ll : targetDims) - code += ll + ", "; - code = code.substring(0, code.length() - 2); - code += "])" + System.lineSeparator(); - code += "input_h = im.shape[0]" + System.lineSeparator(); - code += "input_w = im.shape[1]" + System.lineSeparator(); - //code += "np.save('/home/carlos/git/cropped.npy', im)" + System.lineSeparator(); - code += "globals()['input_h'] = input_h" + System.lineSeparator(); - code += "globals()['input_w'] = input_w" + System.lineSeparator(); - code += "im = torch.from_numpy(np.transpose(im.astype('float32'), (2, 0, 1)))" + System.lineSeparator(); - code += "im_shm.unlink()" + System.lineSeparator(); - //code += "box_shm.close()" + System.lineSeparator(); - this.script += code; - } - - private void processPointsWithSAM(int nPoints, int nNegPoints, boolean returnAll) { + protected void processPointsWithSAM(int nPoints, int nNegPoints, boolean returnAll) { String code = "" + System.lineSeparator() + "task.update('start predict')" + System.lineSeparator() + "input_points_list = []" + System.lineSeparator() @@ -1015,8 +305,9 @@ private void processPointsWithSAM(int nPoints, int nNegPoints, boolean returnAll + "task.outputs['contours_y'] = contours_y" + System.lineSeparator(); this.script = code; } - - private void processBoxWithSAM(boolean returnAll) { + + @Override + protected void processBoxWithSAM(boolean returnAll) { String code = "" + System.lineSeparator() + "task.update('start predict')" + System.lineSeparator() + "input_box = np.array([[input_box[0], input_box[1]], [input_box[2], input_box[3]]])" + System.lineSeparator() @@ -1045,24 +336,14 @@ private void processBoxWithSAM(boolean returnAll) { this.script = code; } - private static & NativeType> - SharedMemoryArray createEfficientSAMInputSHM(final RandomAccessibleInterval inImg) { - long[] dims = inImg.dimensionsAsLongArray(); - if ((dims.length != 3 && dims.length != 2) || (dims.length == 3 && dims[2] != 3 && dims[2] != 1)){ - throw new IllegalArgumentException("Currently SAMJ only supports 1-channel (grayscale) or 3-channel (RGB, BGR, ...) 2D images." - + "The image dimensions order should be 'yxc', first dimension height, second width and third channels."); - } - return SharedMemoryArray.buildMemorySegmentForImage(new long[] {dims[0], dims[1], 3}, new FloatType()); - } - private & NativeType> void adaptImageToModel(final RandomAccessibleInterval ogImg, RandomAccessibleInterval targetImg) { if (ogImg.numDimensions() == 3 && ogImg.dimensionsAsLongArray()[2] == 3) { for (int i = 0; i < 3; i ++) - RealTypeConverters.copyFromTo( normalizedView(Views.hyperSlice(ogImg, 2, i)), Views.hyperSlice(targetImg, 2, i) ); + RealTypeConverters.copyFromTo( ImgLib2SAMUtils.normalizedView(Views.hyperSlice(ogImg, 2, i), this.debugPrinter), Views.hyperSlice(targetImg, 2, i) ); } else if (ogImg.numDimensions() == 3 && ogImg.dimensionsAsLongArray()[2] == 1) { debugPrinter.printText("CONVERTED 1 CHANNEL IMAGE INTO 3 TO BE FEEDED TO SAMJ"); - IntervalView resIm = Views.interval( Views.expandMirrorDouble(normalizedView(ogImg), new long[] {0, 0, 2}), + IntervalView resIm = Views.interval( Views.expandMirrorDouble(ImgLib2SAMUtils.normalizedView(ogImg, this.debugPrinter), new long[] {0, 0, 2}), Intervals.createMinMax(new long[] {0, 0, 0, ogImg.dimensionsAsLongArray()[0], ogImg.dimensionsAsLongArray()[1], 2}) ); RealTypeConverters.copyFromTo( resIm, targetImg ); } else if (ogImg.numDimensions() == 2) { @@ -1071,8 +352,6 @@ void adaptImageToModel(final RandomAccessibleInterval ogImg, RandomAccessible throw new IllegalArgumentException("Currently SAMJ only supports 1-channel (grayscale) or 3-channel (RGB, BGR, ...) 2D images." + "The image dimensions order should be 'yxc', first dimension height, second width and third channels."); } - this.img = targetImg; - this.targetDims = targetImg.dimensionsAsLongArray(); } /** @@ -1090,4 +369,28 @@ public static void main(String[] args) throws IOException, RuntimeException, Int sam.processBox(new int[] {0, 5, 10, 26}); } } + + @Override + protected & NativeType> void setImageOfInterest(RandomAccessibleInterval rai) { + checkImageIsFine(rai); + long[] dims = rai.dimensionsAsLongArray(); + this.img = Views.interval(rai, new long[] {0, 0, 0}, new long[] {dims[0] - 1, dims[1] - 1, 2}); + this.targetDims = img.dimensionsAsLongArray(); + } + + private & NativeType> void checkImageIsFine(RandomAccessibleInterval inImg) { + long[] dims = inImg.dimensionsAsLongArray(); + if ((dims.length != 3 && dims.length != 2) || (dims.length == 3 && dims[2] != 3 && dims[2] != 1)){ + throw new IllegalArgumentException("Currently EfficientSAMJ only supports 1-channel (grayscale) or 3-channel (RGB, BGR, ...) 2D images." + + "The image dimensions order should be 'xyc', first dimension width, second height and third channels."); + } + } + + @Override + protected & NativeType> void createSHMArray(RandomAccessibleInterval imShared) { + RandomAccessibleInterval imageToBeSent = ImgLib2SAMUtils.reescaleIfNeeded(imShared); + long[] dims = imageToBeSent.dimensionsAsLongArray(); + shma = SharedMemoryArray.buildMemorySegmentForImage(new long[] {dims[0], dims[1], dims[2]}, new FloatType()); + adaptImageToModel(imageToBeSent, shma.getSharedRAI()); + } }