From 2ce6b80e997c1e5226b958eacbe7da336407f551 Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Wed, 4 Dec 2024 01:16:25 +0100 Subject: [PATCH] work with masks --- src/main/java/ai/nets/samj/models/EfficientViTSamJ.java | 8 +++++--- src/main/java/ai/nets/samj/models/Sam2.java | 6 +++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java b/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java index 7a75eab..699f492 100644 --- a/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java +++ b/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java @@ -79,6 +79,7 @@ public class EfficientViTSamJ extends AbstractSamJ { + "from skimage import measure" + System.lineSeparator() + "measure.label(np.ones((10, 10)), connectivity=1)" + System.lineSeparator() + "import torch" + System.lineSeparator() + + "from scipy.ndimage import label" + System.lineSeparator() + "from scipy.ndimage import binary_fill_holes" + System.lineSeparator() + "import sys" + System.lineSeparator() + "import os" + System.lineSeparator() @@ -108,6 +109,7 @@ public class EfficientViTSamJ extends AbstractSamJ { + "globals()['measure'] = measure" + System.lineSeparator() + "globals()['np'] = np" + System.lineSeparator() + "globals()['torch'] = torch" + System.lineSeparator() + + "globals()['label'] = label" + System.lineSeparator() + "globals()['binary_fill_holes'] = binary_fill_holes" + System.lineSeparator() + "globals()['predictor'] = predictor" + System.lineSeparator(); /** @@ -517,8 +519,8 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr, + " extracted_point_prompts += [[inds[0][random_positions[pp]], inds[1][random_positions[pp]]]]" + System.lineSeparator() + " extracted_point_labels += [n_feat]" + System.lineSeparator() + " mask, _, _ = predictor.predict(" + System.lineSeparator() - + " point_coords=point_prompts," + System.lineSeparator() - + " point_labels=point_labels," + System.lineSeparator() + + " point_coords=np.array(extracted_point_prompts)," + System.lineSeparator() + + " point_labels=np.array(extracted_point_labels)," + System.lineSeparator() + " multimask_output=False," + System.lineSeparator() + " box=None,)" + System.lineSeparator() + (this.isIJROIManager ? " mask[0, 1:, 1:] += mask[0, :-1, :-1]" : "") + System.lineSeparator() @@ -583,7 +585,7 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr, + "task.outputs['contours_x'] = contours_x" + System.lineSeparator() + "task.outputs['contours_y'] = contours_y" + System.lineSeparator() + "task.outputs['rle'] = rle_masks" + System.lineSeparator(); - code += "mask = 0" + System.lineSeparator(); + code += "mask_batch = None" + System.lineSeparator(); if (shmArr != null) { code += "shm_mask.close()" + System.lineSeparator(); code += "shm_mask.unlink()" + System.lineSeparator(); diff --git a/src/main/java/ai/nets/samj/models/Sam2.java b/src/main/java/ai/nets/samj/models/Sam2.java index 6585da6..a7bfc3f 100644 --- a/src/main/java/ai/nets/samj/models/Sam2.java +++ b/src/main/java/ai/nets/samj/models/Sam2.java @@ -493,8 +493,8 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr, boolean retu + " extracted_point_prompts += [[inds[0][random_positions[pp]], inds[1][random_positions[pp]]]]" + System.lineSeparator() + " extracted_point_labels += [n_feat]" + System.lineSeparator() + " mask, _, _ = predictor.predict(" + System.lineSeparator() - + " point_coords=point_prompts," + System.lineSeparator() - + " point_labels=point_labels," + System.lineSeparator() + + " point_coords=np.array(extracted_point_prompts)," + System.lineSeparator() + + " point_labels=np.array(extracted_point_labels)," + System.lineSeparator() + " multimask_output=False," + System.lineSeparator() + " box=None,)" + System.lineSeparator() + (this.isIJROIManager ? " mask[0, 1:, 1:] += mask[0, :-1, :-1]" : "") + System.lineSeparator() @@ -559,7 +559,7 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr, boolean retu + "task.outputs['contours_x'] = contours_x" + System.lineSeparator() + "task.outputs['contours_y'] = contours_y" + System.lineSeparator() + "task.outputs['rle'] = rle_masks" + System.lineSeparator(); - code += "mask = 0" + System.lineSeparator(); + code += "mask_batch = None" + System.lineSeparator(); if (shmArr != null) { code += "shm_mask.close()" + System.lineSeparator(); code += "shm_mask.unlink()" + System.lineSeparator();