Skip to content

Commit

Permalink
add batchsamize to efficientvitsam
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Dec 3, 2024
1 parent 4c8250c commit d7901c4
Showing 1 changed file with 124 additions and 42 deletions.
166 changes: 124 additions & 42 deletions src/main/java/ai/nets/samj/models/EfficientViTSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -463,49 +463,131 @@ public String deleteEncodingScript(String encodingName) {
@Override
protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr,
boolean returnAll) {
String code = "";
code += "shm_mask = shared_memory.SharedMemory(name='" + shmArr.getNameForPython() + "')" + System.lineSeparator();
code += "mask = np.frombuffer(buffer=shm_mask.buf, dtype='" + shmArr.getOriginalDataType() + "').reshape([";
for (long l : shmArr.getOriginalShape())
code += l + ",";
code += "])" + System.lineSeparator();
code += "different_mask_vals = np.unique(mask)" + System.lineSeparator();
code += "contours_x = []" + System.lineSeparator();
code += "contours_y = []" + System.lineSeparator();
code += "rle_masks = []" + System.lineSeparator();
code += "for val in different_mask_vals:" + System.lineSeparator()
+ " if val < 1:" + System.lineSeparator()
+ " continue" + System.lineSeparator()
+ " locations = np.where(mask == val)" + System.lineSeparator()
+ " input_points_pos = np.zeros((locations[0].shape[0], 2))" + System.lineSeparator()
+ " input_labels_pos = np.ones((locations[0].shape[0]))" + System.lineSeparator()
+ " locations_neg = np.where((mask != val) & (mask != 0))" + System.lineSeparator()
+ " input_points_neg = np.zeros((locations_neg[0].shape[0], 2))" + System.lineSeparator()
+ " input_labels_neg = np.zeros((locations_neg[0].shape[0]))" + System.lineSeparator()
+ " input_points_pos[:, 0] = locations[0]" + System.lineSeparator()
+ " input_points_pos[:, 1] = locations[1]" + System.lineSeparator()
+ " input_points_neg[:, 0] = locations_neg[0]" + System.lineSeparator()
+ " input_points_neg[:, 1] = locations_neg[1]" + System.lineSeparator()
+ " input_points = np.concatenate((input_points_pos.reshape(-1, 2), input_points_neg.reshape(-1, 2)), axis=0)" + System.lineSeparator()
+ " input_label = np.concatenate((input_labels_pos, input_labels_neg), axis=0)" + System.lineSeparator()
+ " mask_val, _, _ = predictor.predict(" + System.lineSeparator()
+ " point_coords=input_points," + System.lineSeparator()
+ " point_labels=input_label," + System.lineSeparator()
+ " multimask_output=False," + System.lineSeparator()
+ " box=None,)" + System.lineSeparator()
//+ "np.save('/temp/aa.npy', mask)" + System.lineSeparator()
+ (this.isIJROIManager ? "mask_val[0, 1:, 1:] += mask_val[0, :-1, :-1]" : "") + System.lineSeparator()
+ " contours_x_val,contours_y_val,rle_val = get_polygons_from_binary_mask(mask_val[0], only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator()
+ " contours_x += contours_x_val" + System.lineSeparator()
+ " contours_y += contours_y_val" + System.lineSeparator()
+ " rle_masks += rle_val" + System.lineSeparator()
+ "task.update('all contours traced')" + System.lineSeparator()
+ "task.outputs['contours_x'] = contours_x" + System.lineSeparator()
+ "task.outputs['contours_y'] = contours_y" + System.lineSeparator()
+ "task.outputs['rle'] = rle_masks" + System.lineSeparator();
String code = ""
+ "num_threads = 3" + System.lineSeparator()
+ "finished_threads = []" + System.lineSeparator()
+ "print(('threading' not in globals().keys()))" + System.lineSeparator()
+ "if \"threading\" not in globals().keys():" + System.lineSeparator()
+ " import threading" + System.lineSeparator()
+ " globals()['threading'] = threading" + System.lineSeparator()
+ "if \"ThreadPoolExecutor\" not in globals().keys():" + System.lineSeparator()
+ " from concurrent.futures import ThreadPoolExecutor, as_completed" + System.lineSeparator()
+ " globals()['ThreadPoolExecutor'] = ThreadPoolExecutor" + System.lineSeparator()
+ " globals()['as_completed'] = as_completed" + System.lineSeparator()
+ "lock = threading.Lock()" + System.lineSeparator()
+ "def respond_in_thread(task, args, inds, lock, finished_threads):" + System.lineSeparator()
+ " task._respond(ResponseType.UPDATE, args)" + System.lineSeparator()
+ " with lock:" + System.lineSeparator()
+ " finished_threads.extend(inds)" + System.lineSeparator()
+ "" + System.lineSeparator()
+ "def cancel_unstarted_tasks(futures):" + System.lineSeparator()
+ " for future in futures:" + System.lineSeparator()
+ " if not future.running() and not future.done():" + System.lineSeparator()
+ " future.cancel()" + System.lineSeparator()
+ "" + System.lineSeparator()
+ "num_features = 0" + System.lineSeparator();
if (shmArr != null) {
code += ""
+ "shm_mask = shared_memory.SharedMemory(name='" + shmArr.getNameForPython() + "')" + System.lineSeparator()
+ "mask_batch = np.frombuffer(buffer=shm_mask.buf, dtype='" + shmArr.getOriginalDataType() + "').reshape([";
for (long l : shmArr.getOriginalShape())
code += l + ",";
code += "])" + System.lineSeparator();
code += "labeled_array, num_features = label(mask_batch)" + System.lineSeparator();
}
code += ""
+ "contours_x = []" + System.lineSeparator()
+ "contours_y = []" + System.lineSeparator()
+ "rle_masks = []" + System.lineSeparator()
+ "ntot = num_features + len(point_prompts) + len(rect_prompts)" + System.lineSeparator()
+ "args = {\"outputs\": {'n': str(ntot)}, \"message\": '" + AbstractSamJ.UPDATE_ID_N_CONTOURS + "'}" + System.lineSeparator()
+ "task._respond(ResponseType.UPDATE, args)" + System.lineSeparator()
// TODO right now is geetting the mask after each prompt
// TODO test processing first every prompt and then getting the masks
+ "with ThreadPoolExecutor(max_workers=num_threads) as executor:" + System.lineSeparator()
+ " futures = []" + System.lineSeparator()
+ " n_objects = 0" + System.lineSeparator()
+ " for n_feat in range(num_features):" + System.lineSeparator()
+ " extracted_point_prompts = []" + System.lineSeparator()
+ " extracted_point_labels = []" + System.lineSeparator()
+ " inds = np.where(labeled_array == n_feat)" + System.lineSeparator()
+ " n_points = np.min([3, inds[0].shape[0]])" + System.lineSeparator()
+ " random_positions = np.random.choice(inds[0].shape[0], n_points, replace=False)" + System.lineSeparator()
+ " for pp in range(n_points):" + System.lineSeparator()
+ " extracted_point_prompts += [[inds[0][random_positions[pp]], inds[1][random_positions[pp]]]]" + System.lineSeparator()
+ " extracted_point_labels += [n_feat]" + System.lineSeparator()
+ " mask, _, _ = predictor.predict(" + System.lineSeparator()
+ " point_coords=point_prompts," + System.lineSeparator()
+ " point_labels=point_labels," + System.lineSeparator()
+ " multimask_output=False," + System.lineSeparator()
+ " box=None,)" + System.lineSeparator()
+ (this.isIJROIManager ? " mask[0, 1:, 1:] += mask[0, :-1, :-1]" : "") + System.lineSeparator()
+ " c_x, c_y, r_m = get_polygons_from_binary_mask(mask[0], only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator()
+ " contours_x += c_x" + System.lineSeparator()
+ " contours_y += c_y" + System.lineSeparator()
+ " rle_masks += r_m" + System.lineSeparator()
+ " args = {\"outputs\": {'temp_x': c_x, 'temp_y': c_y, 'temp_mask': r_m}, \"message\": '" + AbstractSamJ.UPDATE_ID_CONTOUR + "'}" + System.lineSeparator()
+ " it_list = list(range(n_objects, n_objects := n_objects + len(r_m)))" + System.lineSeparator()
+ " future = executor.submit(respond_in_thread, task, args, it_list, lock, finished_threads)" + System.lineSeparator()
+ " futures.append(future)" + System.lineSeparator()
// TODO + " task._respond(ResponseType.UPDATE, args)" + System.lineSeparator()
+ "" + System.lineSeparator()
+ " for p_prompt in point_prompts:" + System.lineSeparator()
+ " mask, _, _ = predictor.predict(" + System.lineSeparator()
+ " point_coords=np.array(p_prompt).reshape(1, 2)," + System.lineSeparator()
+ " point_labels=np.array([1])," + System.lineSeparator()
+ " multimask_output=False," + System.lineSeparator()
+ " box=None,)" + System.lineSeparator()
+ (this.isIJROIManager ? " mask[0, 1:, 1:] += mask[0, :-1, :-1]" : "") + System.lineSeparator()
+ " c_x, c_y, r_m = get_polygons_from_binary_mask(mask[0], only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator()
+ " contours_x += c_x" + System.lineSeparator()
+ " contours_y += c_y" + System.lineSeparator()
+ " rle_masks += r_m" + System.lineSeparator()
+ " args = {\"outputs\": {'point': p_prompt, 'temp_x': c_x, 'temp_y': c_y, 'temp_mask': r_m}, \"message\": '" + AbstractSamJ.UPDATE_ID_CONTOUR + "'}" + System.lineSeparator()
+ " it_list = list(range(n_objects, n_objects := n_objects + len(r_m)))" + System.lineSeparator()
+ " future = executor.submit(respond_in_thread, task, args, it_list, lock, finished_threads)" + System.lineSeparator()
+ " futures.append(future)" + System.lineSeparator()
// TODO + " task._respond(ResponseType.UPDATE, args)" + System.lineSeparator()
+ "" + System.lineSeparator()
+ "" + System.lineSeparator()
+ " for rect_prompt in rect_prompts:" + System.lineSeparator()
+ " input_box = np.array([[rect_prompt[0], rect_prompt[1]], [rect_prompt[2], rect_prompt[3]]])" + System.lineSeparator()
+ " mask, _, _ = predictor.predict(" + System.lineSeparator()
+ " point_coords=None," + System.lineSeparator()
+ " point_labels=np.array([1])," + System.lineSeparator()
+ " multimask_output=False," + System.lineSeparator()
+ " box=input_box,)" + System.lineSeparator()
+ (this.isIJROIManager ? " mask[0, 1:, 1:] += mask[0, :-1, :-1]" : "") + System.lineSeparator()
+ " c_x, c_y, r_m = get_polygons_from_binary_mask(mask[0], only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator()
+ " contours_x += c_x" + System.lineSeparator()
+ " contours_y += c_y" + System.lineSeparator()
+ " rle_masks += r_m" + System.lineSeparator()
+ " args = {\"outputs\": {'rect': rect_prompt, 'temp_x': c_x, 'temp_y': c_y, 'temp_mask': r_m}, \"message\": '" + AbstractSamJ.UPDATE_ID_CONTOUR + "'}" + System.lineSeparator()
+ " it_list = list(range(n_objects, n_objects := n_objects + len(r_m)))" + System.lineSeparator()
+ " future = executor.submit(respond_in_thread, task, args, it_list, lock, finished_threads)" + System.lineSeparator()
+ " futures.append(future)" + System.lineSeparator()
// TODO + " task._respond(ResponseType.UPDATE, args)" + System.lineSeparator()
+ "" + System.lineSeparator()
+ "" + System.lineSeparator()
+ " finished_threads.sort()" + System.lineSeparator()
+ " cancel_unstarted_tasks(futures)" + System.lineSeparator()
+ " for i, future in enumerate(futures[::-1]):" + System.lineSeparator()
+ " if not future.cancelled():" + System.lineSeparator()
+ " future.result()" + System.lineSeparator()
+ " for i in finished_threads[::-1]:" + System.lineSeparator()
+ " contours_x.pop(i)" + System.lineSeparator()
+ " contours_y.pop(i)" + System.lineSeparator()
+ " rle_masks.pop(i)" + System.lineSeparator()
+ "" + System.lineSeparator()
+ "task.update('all contours traced')" + System.lineSeparator()
+ "task.outputs['contours_x'] = contours_x" + System.lineSeparator()
+ "task.outputs['contours_y'] = contours_y" + System.lineSeparator()
+ "task.outputs['rle'] = rle_masks" + System.lineSeparator();
code += "mask = 0" + System.lineSeparator();
code += "shm_mask.close()" + System.lineSeparator();
code += "shm_mask.unlink()" + System.lineSeparator();
if (shmArr != null) {
code += "shm_mask.close()" + System.lineSeparator();
code += "shm_mask.unlink()" + System.lineSeparator();
}
this.script = code;
}
}

0 comments on commit d7901c4

Please sign in to comment.