From 21b3ad15721bd9d84109ccc679abdb4934d44056 Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Tue, 10 Dec 2024 02:12:14 +0100 Subject: [PATCH] correct two small bugs --- src/main/java/ai/nets/samj/models/EfficientSamJ.java | 3 +-- src/main/java/ai/nets/samj/models/EfficientViTSamJ.java | 3 +-- src/main/java/ai/nets/samj/models/Sam2.java | 3 +-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/main/java/ai/nets/samj/models/EfficientSamJ.java b/src/main/java/ai/nets/samj/models/EfficientSamJ.java index dbd672f..e4fff64 100644 --- a/src/main/java/ai/nets/samj/models/EfficientSamJ.java +++ b/src/main/java/ai/nets/samj/models/EfficientSamJ.java @@ -420,7 +420,6 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr, code = String.format(code, size); code += "])" + System.lineSeparator(); code += "labeled_array, num_features = label(mask_batch)" + System.lineSeparator(); - code += "num_features -= 1" + System.lineSeparator(); } code += "" + "contours_x = []" + System.lineSeparator() @@ -442,7 +441,7 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr, + " random_positions = np.random.choice(inds[0].shape[0], n_points, replace=False)" + System.lineSeparator() + " for pp in range(n_points):" + System.lineSeparator() + " extracted_point_prompts += [[inds[0][random_positions[pp]], inds[1][random_positions[pp]]]]" + System.lineSeparator() - + " extracted_point_labels += [n_feat]" + System.lineSeparator() + + " extracted_point_labels += [1]" + System.lineSeparator() + " ip = torch.reshape(torch.tensor(np.array(extracted_point_prompts).reshape(len(extracted_point_prompts), 2)), [1, 1, -1, 2])" + System.lineSeparator() + " il = torch.reshape(torch.tensor(np.array(extracted_point_labels)), [1, 1, -1])" + System.lineSeparator() + " predicted_logits, predicted_iou = predictor.predict_masks(predictor.encoded_images," + System.lineSeparator() diff --git a/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java b/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java index 881e4bd..f0f31b7 100644 --- a/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java +++ b/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java @@ -501,7 +501,6 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr, code = String.format(code, size); code += "])" + System.lineSeparator(); code += "labeled_array, num_features = label(mask_batch)" + System.lineSeparator(); - code += "num_features -= 1" + System.lineSeparator(); } code += "" + "contours_x = []" + System.lineSeparator() @@ -523,7 +522,7 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr, + " random_positions = np.random.choice(inds[0].shape[0], n_points, replace=False)" + System.lineSeparator() + " for pp in range(n_points):" + System.lineSeparator() + " extracted_point_prompts += [[inds[0][random_positions[pp]], inds[1][random_positions[pp]]]]" + System.lineSeparator() - + " extracted_point_labels += [n_feat]" + System.lineSeparator() + + " extracted_point_labels += [1]" + System.lineSeparator() + " mask, _, _ = predictor.predict(" + System.lineSeparator() + " point_coords=np.array(extracted_point_prompts)," + System.lineSeparator() + " point_labels=np.array(extracted_point_labels)," + System.lineSeparator() diff --git a/src/main/java/ai/nets/samj/models/Sam2.java b/src/main/java/ai/nets/samj/models/Sam2.java index 520241f..2a53070 100644 --- a/src/main/java/ai/nets/samj/models/Sam2.java +++ b/src/main/java/ai/nets/samj/models/Sam2.java @@ -475,7 +475,6 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr, boolean retu code = String.format(code, size); code += "])" + System.lineSeparator(); code += "labeled_array, num_features = label(mask_batch)" + System.lineSeparator(); - code += "num_features -= 1" + System.lineSeparator(); } code += "" + "contours_x = []" + System.lineSeparator() @@ -497,7 +496,7 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr, boolean retu + " random_positions = np.random.choice(inds[0].shape[0], n_points, replace=False)" + System.lineSeparator() + " for pp in range(n_points):" + System.lineSeparator() + " extracted_point_prompts += [[inds[0][random_positions[pp]], inds[1][random_positions[pp]]]]" + System.lineSeparator() - + " extracted_point_labels += [n_feat]" + System.lineSeparator() + + " extracted_point_labels += [1]" + System.lineSeparator() + " mask, _, _ = predictor.predict(" + System.lineSeparator() + " point_coords=np.array(extracted_point_prompts)," + System.lineSeparator() + " point_labels=np.array(extracted_point_labels)," + System.lineSeparator()