diff --git a/src/main/java/ai/nets/samj/models/AbstractSamJ.java b/src/main/java/ai/nets/samj/models/AbstractSamJ.java index db612e2..64b1abf 100644 --- a/src/main/java/ai/nets/samj/models/AbstractSamJ.java +++ b/src/main/java/ai/nets/samj/models/AbstractSamJ.java @@ -168,7 +168,7 @@ public interface DebugTextPrinter { void printText(String text); } protected abstract void cellSAM(List grid, boolean returnAll); protected abstract void - processPromptsBatchWithSAM(List points, List rects, SharedMemoryArray shmArr, boolean returnAll); + processPromptsBatchWithSAM(int npoints, int nrects, SharedMemoryArray shmArr, boolean returnAll); protected abstract void processPointsWithSAM(int nPoints, int nNegPoints, boolean returnAll); @@ -429,7 +429,14 @@ List processBatchOfPrompts(List pointsList, List rects, maskShma = SharedMemoryArray.createSHMAFromRAI(rai, false, false); try { - processPromptsBatchWithSAM(pointsList, rects, maskShma, returnAll); + HashMap inputs = new HashMap(); + inputs.put("point_prompts", pointsList == null ? new ArrayList() : pointsList); + List rectPrompts = new ArrayList(); + if (rects != null && rects.size() > 0) + rectPrompts = rects.stream().map(rr -> new int[] {rr.x, rr.y, rr.x + rr.width, rr.y + rr.height}) + .collect(Collectors.toList()); + inputs.put("rect_prompts", rectPrompts); + processPromptsBatchWithSAM(pointsList.size(), rects.size(), maskShma, returnAll); printScript(script, "Batch of prompts inference"); List polys = processAndRetrieveContours(null); recalculatePolys(polys, encodeCoords); diff --git a/src/main/java/ai/nets/samj/models/EfficientSamJ.java b/src/main/java/ai/nets/samj/models/EfficientSamJ.java index 18a2ecf..110b2c3 100644 --- a/src/main/java/ai/nets/samj/models/EfficientSamJ.java +++ b/src/main/java/ai/nets/samj/models/EfficientSamJ.java @@ -387,7 +387,7 @@ public String deleteEncodingScript(String encodingName) { } @Override - protected void processPromptsBatchWithSAM(List points, List rects, SharedMemoryArray shmArr, + protected void processPromptsBatchWithSAM(int npoints, int nrects, SharedMemoryArray shmArr, boolean returnAll) { String code = ""; code += "shm_mask = shared_memory.SharedMemory(name='" + shmArr.getNameForPython() + "')" + System.lineSeparator(); diff --git a/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java b/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java index 0c63fcd..bf30b86 100644 --- a/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java +++ b/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java @@ -462,7 +462,7 @@ public String deleteEncodingScript(String encodingName) { } @Override - protected void processPromptsBatchWithSAM(List points, List rects, SharedMemoryArray shmArr, + protected void processPromptsBatchWithSAM(int npoints, int nrects, SharedMemoryArray shmArr, boolean returnAll) { String code = ""; code += "shm_mask = shared_memory.SharedMemory(name='" + shmArr.getNameForPython() + "')" + System.lineSeparator(); diff --git a/src/main/java/ai/nets/samj/models/Sam2.java b/src/main/java/ai/nets/samj/models/Sam2.java index ecdcd4a..d343deb 100644 --- a/src/main/java/ai/nets/samj/models/Sam2.java +++ b/src/main/java/ai/nets/samj/models/Sam2.java @@ -446,8 +446,7 @@ public String deleteEncodingScript(String encodingName) { } @Override - protected void processPromptsBatchWithSAM(List points, - List rects, SharedMemoryArray shmArr, boolean returnAll) { + protected void processPromptsBatchWithSAM(int npoints, int nrects, SharedMemoryArray shmArr, boolean returnAll) { String code = ""; code += "shm_mask = shared_memory.SharedMemory(name='" + shmArr.getNameForPython() + "')" + System.lineSeparator(); code += "mask = np.frombuffer(buffer=shm_mask.buf, dtype='" + shmArr.getOriginalDataType() + "').reshape(["; @@ -484,7 +483,7 @@ protected void processPromptsBatchWithSAM(List points, + " contours_y += c_Y" + System.lineSeparator() + " rle_masks += r_m" + System.lineSeparator() + "" + System.lineSeparator() - + "for p_prompt in range(len(point_prompts)):" + System.lineSeparator() + + "for p_prompt in range(" + npoints + "):" + System.lineSeparator() + " mask, _, _ = predictor.predict(" + System.lineSeparator() + " point_coords=np.array(p_prompt).reshape(1, 2)," + System.lineSeparator() + " point_labels=np.array([1])," + System.lineSeparator() @@ -497,7 +496,7 @@ protected void processPromptsBatchWithSAM(List points, + " rle_masks += r_m" + System.lineSeparator() + "" + System.lineSeparator() + "" + System.lineSeparator() - + "for rect_prompt in range(len(rect_prompts)):" + System.lineSeparator() + + "for rect_prompt in range(" + nrects + "):" + System.lineSeparator() + " input_box = np.array([[input_box[0], input_box[1]], [input_box[2], input_box[3]]])" + System.lineSeparator() + " mask, _, _ = predictor.predict(" + System.lineSeparator() + " point_coords=None," + System.lineSeparator()