Skip to content

Commit

Permalink
BATCH OF POINT PROMPTS WORKS
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Nov 30, 2024
1 parent 1ae2e84 commit c27696f
Showing 1 changed file with 10 additions and 13 deletions.
23 changes: 10 additions & 13 deletions src/main/java/ai/nets/samj/models/Sam2.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import ai.nets.samj.install.Sam2EnvManager;
import ai.nets.samj.install.SamEnvManagerAbstract;

import java.awt.Rectangle;
import java.io.IOException;

import io.bioimage.modelrunner.apposed.appose.Environment;
Expand Down Expand Up @@ -459,32 +458,30 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr, boolean retu
code += "labeled_array, num_features = label(mask_batch)" + System.lineSeparator();
}
code += ""
+ "point_prompts = []" + System.lineSeparator()
+ "point_labels = []" + System.lineSeparator()
+ "mask_prompts = []" + System.lineSeparator()
+ "mask_labels = []" + System.lineSeparator()
+ "contours_x = []" + System.lineSeparator()
+ "contours_y = []" + System.lineSeparator()
+ "rle_masks = []" + System.lineSeparator()
// TODO right now is geetting the mask after each prompt
// TODO test processing first every prompt and then getting the masks
+ "" + System.lineSeparator()
+ "for n_feat in range(num_features):" + System.lineSeparator()
+ " extracted_point_prompts = []" + System.lineSeparator()
+ " extracted_point_labels = []" + System.lineSeparator()
+ " inds = np.where(labeled_array == n_feat)" + System.lineSeparator()
+ " n_points = np.min([3, inds[0].shape[0]])" + System.lineSeparator()
+ " random_positions = np.random.choice(inds[0].shape[0], n_points, replace=False)" + System.lineSeparator()
+ " for pp in range(n_points):" + System.lineSeparator()
+ " point_prompts += [[inds[0][random_positions[pp]], inds[1][random_positions[pp]]]]" + System.lineSeparator()
+ " point_labels += [n_feat]" + System.lineSeparator()
+ " extracted_point_prompts += [[inds[0][random_positions[pp]], inds[1][random_positions[pp]]]]" + System.lineSeparator()
+ " extracted_point_labels += [n_feat]" + System.lineSeparator()
+ " mask, _, _ = predictor.predict(" + System.lineSeparator()
+ " point_coords=point_prompts," + System.lineSeparator()
+ " point_labels=point_labels," + System.lineSeparator()
+ " multimask_output=False," + System.lineSeparator()
+ " box=None,)" + System.lineSeparator()
+ (this.isIJROIManager ? " mask += mask[0, :-1, :-1]" : "") + System.lineSeparator()
+ (this.isIJROIManager ? " mask[0, 1:, 1:] += mask[0, :-1, :-1]" : "") + System.lineSeparator()
+ " c_x, c_y, r_m = get_polygons_from_binary_mask(mask[0], only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator()
+ " contours_x += c_x" + System.lineSeparator()
+ " contours_y += c_Y" + System.lineSeparator()
+ " contours_y += c_y" + System.lineSeparator()
+ " rle_masks += r_m" + System.lineSeparator()
+ "" + System.lineSeparator()
+ "for p_prompt in point_prompts:" + System.lineSeparator()
Expand All @@ -493,10 +490,10 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr, boolean retu
+ " point_labels=np.array([1])," + System.lineSeparator()
+ " multimask_output=False," + System.lineSeparator()
+ " box=None,)" + System.lineSeparator()
+ (this.isIJROIManager ? " mask += mask[0, :-1, :-1]" : "") + System.lineSeparator()
+ (this.isIJROIManager ? " mask[0, 1:, 1:] += mask[0, :-1, :-1]" : "") + System.lineSeparator()
+ " c_x, c_y, r_m = get_polygons_from_binary_mask(mask[0], only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator()
+ " contours_x += c_x" + System.lineSeparator()
+ " contours_y += c_Y" + System.lineSeparator()
+ " contours_y += c_y" + System.lineSeparator()
+ " rle_masks += r_m" + System.lineSeparator()
+ "" + System.lineSeparator()
+ "" + System.lineSeparator()
Expand All @@ -507,10 +504,10 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr, boolean retu
+ " point_labels=np.array([1])," + System.lineSeparator()
+ " multimask_output=False," + System.lineSeparator()
+ " box=input_box,)" + System.lineSeparator()
+ (this.isIJROIManager ? " mask += mask[0, :-1, :-1]" : "") + System.lineSeparator()
+ (this.isIJROIManager ? " mask[0, 1:, 1:] += mask[0, :-1, :-1]" : "") + System.lineSeparator()
+ " c_x, c_y, r_m = get_polygons_from_binary_mask(mask[0], only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator()
+ " contours_x += c_x" + System.lineSeparator()
+ " contours_y += c_Y" + System.lineSeparator()
+ " contours_y += c_y" + System.lineSeparator()
+ " rle_masks += r_m" + System.lineSeparator()
+ "" + System.lineSeparator()
+ "" + System.lineSeparator()
Expand Down

0 comments on commit c27696f

Please sign in to comment.