Skip to content

Commit

Permalink
correct small errors
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Dec 4, 2024
1 parent 2ce6b80 commit 82c9623
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
9 changes: 6 additions & 3 deletions src/main/java/ai/nets/samj/models/EfficientSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ public class EfficientSamJ extends AbstractSamJ {
+ "from skimage import measure" + System.lineSeparator()
+ "measure.label(np.ones((10, 10)), connectivity=1)" + System.lineSeparator()
+ "import torch" + System.lineSeparator()
+ "from scipy.ndimage import label" + System.lineSeparator()
+ "from scipy.ndimage import binary_fill_holes" + System.lineSeparator()
+ "import sys" + System.lineSeparator()
+ "sys.path.append(r'%s')" + System.lineSeparator()
Expand All @@ -80,6 +81,7 @@ public class EfficientSamJ extends AbstractSamJ {
+ "globals()['measure'] = measure" + System.lineSeparator()
+ "globals()['np'] = np" + System.lineSeparator()
+ "globals()['torch'] = torch" + System.lineSeparator()
+ "globals()['label'] = label" + System.lineSeparator()
+ "globals()['binary_fill_holes'] = binary_fill_holes" + System.lineSeparator()
+ "globals()['predictor'] = predictor" + System.lineSeparator();

Expand Down Expand Up @@ -413,6 +415,7 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr,
code += l + ",";
code += "])" + System.lineSeparator();
code += "labeled_array, num_features = label(mask_batch)" + System.lineSeparator();
code += "num_features -= 1" + System.lineSeparator();
}
code += ""
+ "contours_x = []" + System.lineSeparator()
Expand All @@ -426,7 +429,7 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr,
+ "with ThreadPoolExecutor(max_workers=num_threads) as executor:" + System.lineSeparator()
+ " futures = []" + System.lineSeparator()
+ " n_objects = 0" + System.lineSeparator()
+ " for n_feat in range(num_features):" + System.lineSeparator()
+ " for n_feat in range(1, num_features + 1):" + System.lineSeparator()
+ " extracted_point_prompts = []" + System.lineSeparator()
+ " extracted_point_labels = []" + System.lineSeparator()
+ " inds = np.where(labeled_array == n_feat)" + System.lineSeparator()
Expand All @@ -435,7 +438,7 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr,
+ " for pp in range(n_points):" + System.lineSeparator()
+ " extracted_point_prompts += [[inds[0][random_positions[pp]], inds[1][random_positions[pp]]]]" + System.lineSeparator()
+ " extracted_point_labels += [n_feat]" + System.lineSeparator()
+ " ip = torch.reshape(torch.tensor(np.array(extracted_point_prompts).reshape(1, 2)), [1, 1, -1, 2])" + System.lineSeparator()
+ " ip = torch.reshape(torch.tensor(np.array(extracted_point_prompts).reshape(len(extracted_point_prompts), 2)), [1, 1, -1, 2])" + System.lineSeparator()
+ " il = torch.reshape(torch.tensor(np.array(extracted_point_labels)), [1, 1, -1])" + System.lineSeparator()
+ " predicted_logits, predicted_iou = predictor.predict_masks(predictor.encoded_images," + System.lineSeparator()
+ " ip," + System.lineSeparator()
Expand Down Expand Up @@ -530,7 +533,7 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr,
+ "task.outputs['contours_x'] = contours_x" + System.lineSeparator()
+ "task.outputs['contours_y'] = contours_y" + System.lineSeparator()
+ "task.outputs['rle'] = rle_masks" + System.lineSeparator();
code += "mask = 0" + System.lineSeparator();
code += "mask_batch = None" + System.lineSeparator();
if (shmArr != null) {
code += "shm_mask.close()" + System.lineSeparator();
code += "shm_mask.unlink()" + System.lineSeparator();
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/ai/nets/samj/models/EfficientViTSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr,
code += l + ",";
code += "])" + System.lineSeparator();
code += "labeled_array, num_features = label(mask_batch)" + System.lineSeparator();
code += "num_features -= 1" + System.lineSeparator();
}
code += ""
+ "contours_x = []" + System.lineSeparator()
Expand All @@ -509,7 +510,7 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr,
+ "with ThreadPoolExecutor(max_workers=num_threads) as executor:" + System.lineSeparator()
+ " futures = []" + System.lineSeparator()
+ " n_objects = 0" + System.lineSeparator()
+ " for n_feat in range(num_features):" + System.lineSeparator()
+ " for n_feat in range(1, num_features + 1):" + System.lineSeparator()
+ " extracted_point_prompts = []" + System.lineSeparator()
+ " extracted_point_labels = []" + System.lineSeparator()
+ " inds = np.where(labeled_array == n_feat)" + System.lineSeparator()
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/ai/nets/samj/models/Sam2.java
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,7 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr, boolean retu
code += l + ",";
code += "])" + System.lineSeparator();
code += "labeled_array, num_features = label(mask_batch)" + System.lineSeparator();
code += "num_features -= 1" + System.lineSeparator();
}
code += ""
+ "contours_x = []" + System.lineSeparator()
Expand All @@ -483,7 +484,7 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr, boolean retu
+ "with ThreadPoolExecutor(max_workers=num_threads) as executor:" + System.lineSeparator()
+ " futures = []" + System.lineSeparator()
+ " n_objects = 0" + System.lineSeparator()
+ " for n_feat in range(num_features):" + System.lineSeparator()
+ " for n_feat in range(1, num_features + 1):" + System.lineSeparator()
+ " extracted_point_prompts = []" + System.lineSeparator()
+ " extracted_point_labels = []" + System.lineSeparator()
+ " inds = np.where(labeled_array == n_feat)" + System.lineSeparator()
Expand Down

0 comments on commit 82c9623

Please sign in to comment.