diff --git a/src/main/java/ai/nets/samj/gui/MainGUI.java b/src/main/java/ai/nets/samj/gui/MainGUI.java index 28b0d0c..5d2e0ea 100644 --- a/src/main/java/ai/nets/samj/gui/MainGUI.java +++ b/src/main/java/ai/nets/samj/gui/MainGUI.java @@ -436,7 +436,9 @@ private void toggleDrawer() { } private < T extends RealType< T > & NativeType< T > > void batchSAMize() throws IOException, RuntimeException, InterruptedException { - RandomAccessibleInterval rai = this.consumer.getFocusedImageAsRai(); + RandomAccessibleInterval rai = null; + if (this.consumer.getFocusedImage() != this.cmbImages.getSelectedObject()) + rai = this.consumer.getFocusedImageAsRai(); List pointPrompts = this.consumer.getPointRoisOnFocusImage(); List rectPrompts = this.consumer.getRectRoisOnFocusImage(); if (pointPrompts.size() == 0 && rectPrompts.size() == 0 && !(rai.getType() instanceof IntegerType)){ diff --git a/src/main/java/ai/nets/samj/models/AbstractSamJ.java b/src/main/java/ai/nets/samj/models/AbstractSamJ.java index 64b1abf..71c4749 100644 --- a/src/main/java/ai/nets/samj/models/AbstractSamJ.java +++ b/src/main/java/ai/nets/samj/models/AbstractSamJ.java @@ -167,8 +167,7 @@ public interface DebugTextPrinter { void printText(String text); } protected abstract void cellSAM(List grid, boolean returnAll); - protected abstract void - processPromptsBatchWithSAM(int npoints, int nrects, SharedMemoryArray shmArr, boolean returnAll); + protected abstract void processPromptsBatchWithSAM(SharedMemoryArray shmArr, boolean returnAll); protected abstract void processPointsWithSAM(int nPoints, int nNegPoints, boolean returnAll); @@ -436,13 +435,14 @@ List processBatchOfPrompts(List pointsList, List rects, rectPrompts = rects.stream().map(rr -> new int[] {rr.x, rr.y, rr.x + rr.width, rr.y + rr.height}) .collect(Collectors.toList()); inputs.put("rect_prompts", rectPrompts); - processPromptsBatchWithSAM(pointsList.size(), rects.size(), maskShma, returnAll); + processPromptsBatchWithSAM(maskShma, returnAll); printScript(script, "Batch of prompts inference"); - List polys = processAndRetrieveContours(null); + List polys = processAndRetrieveContours(inputs); recalculatePolys(polys, encodeCoords); return polys; } catch (IOException | RuntimeException | InterruptedException ex) { - maskShma.close(); + if (maskShma != null) + maskShma.close(); throw ex; } } diff --git a/src/main/java/ai/nets/samj/models/EfficientSamJ.java b/src/main/java/ai/nets/samj/models/EfficientSamJ.java index 110b2c3..c820e42 100644 --- a/src/main/java/ai/nets/samj/models/EfficientSamJ.java +++ b/src/main/java/ai/nets/samj/models/EfficientSamJ.java @@ -19,7 +19,6 @@ */ package ai.nets.samj.models; -import java.awt.Rectangle; import java.io.File; import java.io.IOException; import java.util.ArrayList; @@ -387,7 +386,7 @@ public String deleteEncodingScript(String encodingName) { } @Override - protected void processPromptsBatchWithSAM(int npoints, int nrects, SharedMemoryArray shmArr, + protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr, boolean returnAll) { String code = ""; code += "shm_mask = shared_memory.SharedMemory(name='" + shmArr.getNameForPython() + "')" + System.lineSeparator(); diff --git a/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java b/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java index bf30b86..09862cd 100644 --- a/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java +++ b/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java @@ -27,7 +27,6 @@ import ai.nets.samj.install.EfficientViTSamEnvManager; import ai.nets.samj.install.SamEnvManagerAbstract; -import java.awt.Rectangle; import java.io.File; import java.io.IOException; @@ -462,7 +461,7 @@ public String deleteEncodingScript(String encodingName) { } @Override - protected void processPromptsBatchWithSAM(int npoints, int nrects, SharedMemoryArray shmArr, + protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr, boolean returnAll) { String code = ""; code += "shm_mask = shared_memory.SharedMemory(name='" + shmArr.getNameForPython() + "')" + System.lineSeparator(); diff --git a/src/main/java/ai/nets/samj/models/Sam2.java b/src/main/java/ai/nets/samj/models/Sam2.java index d343deb..dcc1d8e 100644 --- a/src/main/java/ai/nets/samj/models/Sam2.java +++ b/src/main/java/ai/nets/samj/models/Sam2.java @@ -446,19 +446,23 @@ public String deleteEncodingScript(String encodingName) { } @Override - protected void processPromptsBatchWithSAM(int npoints, int nrects, SharedMemoryArray shmArr, boolean returnAll) { - String code = ""; - code += "shm_mask = shared_memory.SharedMemory(name='" + shmArr.getNameForPython() + "')" + System.lineSeparator(); - code += "mask = np.frombuffer(buffer=shm_mask.buf, dtype='" + shmArr.getOriginalDataType() + "').reshape(["; - for (long l : shmArr.getOriginalShape()) - code += l + ","; - code += "])" + System.lineSeparator(); + protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr, boolean returnAll) { + String code = "" + + "num_features = 0" + System.lineSeparator(); + if (shmArr != null) { + code += "" + + "shm_mask = shared_memory.SharedMemory(name='" + shmArr.getNameForPython() + "')" + System.lineSeparator() + + "mask_batch = np.frombuffer(buffer=shm_mask.buf, dtype='" + shmArr.getOriginalDataType() + "').reshape(["; + for (long l : shmArr.getOriginalShape()) + code += l + ","; + code += "])" + System.lineSeparator(); + code += "labeled_array, num_features = label(mask_batch)" + System.lineSeparator(); + } code += "" + "point_prompts = []" + System.lineSeparator() + "point_labels = []" + System.lineSeparator() + "mask_prompts = []" + System.lineSeparator() + "mask_labels = []" + System.lineSeparator() - + "labeled_array, num_features = label(mask)" + System.lineSeparator() + "contours_x = []" + System.lineSeparator() + "contours_y = []" + System.lineSeparator() + "rle_masks = []" + System.lineSeparator() @@ -483,7 +487,7 @@ protected void processPromptsBatchWithSAM(int npoints, int nrects, SharedMemoryA + " contours_y += c_Y" + System.lineSeparator() + " rle_masks += r_m" + System.lineSeparator() + "" + System.lineSeparator() - + "for p_prompt in range(" + npoints + "):" + System.lineSeparator() + + "for p_prompt in point_prompts:" + System.lineSeparator() + " mask, _, _ = predictor.predict(" + System.lineSeparator() + " point_coords=np.array(p_prompt).reshape(1, 2)," + System.lineSeparator() + " point_labels=np.array([1])," + System.lineSeparator() @@ -496,7 +500,7 @@ protected void processPromptsBatchWithSAM(int npoints, int nrects, SharedMemoryA + " rle_masks += r_m" + System.lineSeparator() + "" + System.lineSeparator() + "" + System.lineSeparator() - + "for rect_prompt in range(" + nrects + "):" + System.lineSeparator() + + "for rect_prompt in rect_prompts:" + System.lineSeparator() + " input_box = np.array([[input_box[0], input_box[1]], [input_box[2], input_box[3]]])" + System.lineSeparator() + " mask, _, _ = predictor.predict(" + System.lineSeparator() + " point_coords=None," + System.lineSeparator() @@ -519,8 +523,10 @@ protected void processPromptsBatchWithSAM(int npoints, int nrects, SharedMemoryA + "task.outputs['contours_y'] = contours_y" + System.lineSeparator() + "task.outputs['rle'] = rle_masks" + System.lineSeparator(); code += "mask = 0" + System.lineSeparator(); - code += "shm_mask.close()" + System.lineSeparator(); - code += "shm_mask.unlink()" + System.lineSeparator(); + if (shmArr != null) { + code += "shm_mask.close()" + System.lineSeparator(); + code += "shm_mask.unlink()" + System.lineSeparator(); + } this.script = code; } }