diff --git a/src/main/java/ai/nets/samj/EfficientSamJ.java b/src/main/java/ai/nets/samj/EfficientSamJ.java index cb5eb7a..26b0254 100644 --- a/src/main/java/ai/nets/samj/EfficientSamJ.java +++ b/src/main/java/ai/nets/samj/EfficientSamJ.java @@ -102,6 +102,8 @@ public class EfficientSamJ extends AbstractSamJ implements AutoCloseable { * paths and names */ private String IMPORTS_FORMATED; + + private static int REENCODE_THRESH = 50; /** * Create an instance of the class to be able to run EfficientSAM in Java. @@ -574,6 +576,7 @@ public List processBox(int[] boundingBox) */ public List processBox(int[] boundingBox, boolean returnAll) throws IOException, RuntimeException, InterruptedException { + needsMoreResolution(boundingBox); this.script = ""; processBoxWithSAM(returnAll); HashMap inputs = new HashMap(); @@ -584,6 +587,31 @@ public List processBox(int[] boundingBox, boolean returnAll) return polys; } + /** + * For bounding box masks, check whether the its size is too small compared to the size + * of the encoded image. + * + * Approximately, if the original image encoded is about 20 times bigger than the bounding box size, + * the resolution of the SAM-based model encodings will not be enough to identify the object of interest, + * thus re-encoding of a zoomed part of the image will be necessary. + * + * @param boundingBox + * bounding box of interest + * @return whether the bounding box of interest is big enough to produce good results or not + */ + public boolean needsMoreResolution(int[] boundingBox) { + long xSize = boundingBox[2] - boundingBox[0]; + long ySize = boundingBox[3] - boundingBox[1]; + long encodedX = targetDims[1]; + long encodedY = targetDims[0]; + if (xSize * REENCODE_THRESH < encodedX && ySize * REENCODE_THRESH < encodedY) + return true; + return false; + } + + public boolean checkBoundingBox() { + return false; + } @Override /** @@ -682,7 +710,7 @@ private void processBoxWithSAM(boolean returnAll) { + "mask = torch.ge(predicted_logits[0, 0, 0, :, :], 0).cpu().detach().numpy()" + System.lineSeparator() + "task.update('end predict')" + System.lineSeparator() + "task.update(str(mask.shape))" + System.lineSeparator() - //+ "np.save('/temp/aa.npy', mask)" + System.lineSeparator() + //+ "np.save('/home/carlos/git/mask.npy', mask)" + System.lineSeparator() + "contours_x,contours_y = get_polygons_from_binary_mask(mask, only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator() + "task.update('all contours traced')" + System.lineSeparator() + "task.outputs['contours_x'] = contours_x" + System.lineSeparator()