diff --git a/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java b/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java index 09862cd..7a75eab 100644 --- a/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java +++ b/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java @@ -463,49 +463,131 @@ public String deleteEncodingScript(String encodingName) { @Override protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr, boolean returnAll) { - String code = ""; - code += "shm_mask = shared_memory.SharedMemory(name='" + shmArr.getNameForPython() + "')" + System.lineSeparator(); - code += "mask = np.frombuffer(buffer=shm_mask.buf, dtype='" + shmArr.getOriginalDataType() + "').reshape(["; - for (long l : shmArr.getOriginalShape()) - code += l + ","; - code += "])" + System.lineSeparator(); - code += "different_mask_vals = np.unique(mask)" + System.lineSeparator(); - code += "contours_x = []" + System.lineSeparator(); - code += "contours_y = []" + System.lineSeparator(); - code += "rle_masks = []" + System.lineSeparator(); - code += "for val in different_mask_vals:" + System.lineSeparator() - + " if val < 1:" + System.lineSeparator() - + " continue" + System.lineSeparator() - + " locations = np.where(mask == val)" + System.lineSeparator() - + " input_points_pos = np.zeros((locations[0].shape[0], 2))" + System.lineSeparator() - + " input_labels_pos = np.ones((locations[0].shape[0]))" + System.lineSeparator() - + " locations_neg = np.where((mask != val) & (mask != 0))" + System.lineSeparator() - + " input_points_neg = np.zeros((locations_neg[0].shape[0], 2))" + System.lineSeparator() - + " input_labels_neg = np.zeros((locations_neg[0].shape[0]))" + System.lineSeparator() - + " input_points_pos[:, 0] = locations[0]" + System.lineSeparator() - + " input_points_pos[:, 1] = locations[1]" + System.lineSeparator() - + " input_points_neg[:, 0] = locations_neg[0]" + System.lineSeparator() - + " input_points_neg[:, 1] = locations_neg[1]" + System.lineSeparator() - + " input_points = np.concatenate((input_points_pos.reshape(-1, 2), input_points_neg.reshape(-1, 2)), axis=0)" + System.lineSeparator() - + " input_label = np.concatenate((input_labels_pos, input_labels_neg), axis=0)" + System.lineSeparator() - + " mask_val, _, _ = predictor.predict(" + System.lineSeparator() - + " point_coords=input_points," + System.lineSeparator() - + " point_labels=input_label," + System.lineSeparator() - + " multimask_output=False," + System.lineSeparator() - + " box=None,)" + System.lineSeparator() - //+ "np.save('/temp/aa.npy', mask)" + System.lineSeparator() - + (this.isIJROIManager ? "mask_val[0, 1:, 1:] += mask_val[0, :-1, :-1]" : "") + System.lineSeparator() - + " contours_x_val,contours_y_val,rle_val = get_polygons_from_binary_mask(mask_val[0], only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator() - + " contours_x += contours_x_val" + System.lineSeparator() - + " contours_y += contours_y_val" + System.lineSeparator() - + " rle_masks += rle_val" + System.lineSeparator() - + "task.update('all contours traced')" + System.lineSeparator() - + "task.outputs['contours_x'] = contours_x" + System.lineSeparator() - + "task.outputs['contours_y'] = contours_y" + System.lineSeparator() - + "task.outputs['rle'] = rle_masks" + System.lineSeparator(); + String code = "" + + "num_threads = 3" + System.lineSeparator() + + "finished_threads = []" + System.lineSeparator() + + "print(('threading' not in globals().keys()))" + System.lineSeparator() + + "if \"threading\" not in globals().keys():" + System.lineSeparator() + + " import threading" + System.lineSeparator() + + " globals()['threading'] = threading" + System.lineSeparator() + + "if \"ThreadPoolExecutor\" not in globals().keys():" + System.lineSeparator() + + " from concurrent.futures import ThreadPoolExecutor, as_completed" + System.lineSeparator() + + " globals()['ThreadPoolExecutor'] = ThreadPoolExecutor" + System.lineSeparator() + + " globals()['as_completed'] = as_completed" + System.lineSeparator() + + "lock = threading.Lock()" + System.lineSeparator() + + "def respond_in_thread(task, args, inds, lock, finished_threads):" + System.lineSeparator() + + " task._respond(ResponseType.UPDATE, args)" + System.lineSeparator() + + " with lock:" + System.lineSeparator() + + " finished_threads.extend(inds)" + System.lineSeparator() + + "" + System.lineSeparator() + + "def cancel_unstarted_tasks(futures):" + System.lineSeparator() + + " for future in futures:" + System.lineSeparator() + + " if not future.running() and not future.done():" + System.lineSeparator() + + " future.cancel()" + System.lineSeparator() + + "" + System.lineSeparator() + + "num_features = 0" + System.lineSeparator(); + if (shmArr != null) { + code += "" + + "shm_mask = shared_memory.SharedMemory(name='" + shmArr.getNameForPython() + "')" + System.lineSeparator() + + "mask_batch = np.frombuffer(buffer=shm_mask.buf, dtype='" + shmArr.getOriginalDataType() + "').reshape(["; + for (long l : shmArr.getOriginalShape()) + code += l + ","; + code += "])" + System.lineSeparator(); + code += "labeled_array, num_features = label(mask_batch)" + System.lineSeparator(); + } + code += "" + + "contours_x = []" + System.lineSeparator() + + "contours_y = []" + System.lineSeparator() + + "rle_masks = []" + System.lineSeparator() + + "ntot = num_features + len(point_prompts) + len(rect_prompts)" + System.lineSeparator() + + "args = {\"outputs\": {'n': str(ntot)}, \"message\": '" + AbstractSamJ.UPDATE_ID_N_CONTOURS + "'}" + System.lineSeparator() + + "task._respond(ResponseType.UPDATE, args)" + System.lineSeparator() + // TODO right now is geetting the mask after each prompt + // TODO test processing first every prompt and then getting the masks + + "with ThreadPoolExecutor(max_workers=num_threads) as executor:" + System.lineSeparator() + + " futures = []" + System.lineSeparator() + + " n_objects = 0" + System.lineSeparator() + + " for n_feat in range(num_features):" + System.lineSeparator() + + " extracted_point_prompts = []" + System.lineSeparator() + + " extracted_point_labels = []" + System.lineSeparator() + + " inds = np.where(labeled_array == n_feat)" + System.lineSeparator() + + " n_points = np.min([3, inds[0].shape[0]])" + System.lineSeparator() + + " random_positions = np.random.choice(inds[0].shape[0], n_points, replace=False)" + System.lineSeparator() + + " for pp in range(n_points):" + System.lineSeparator() + + " extracted_point_prompts += [[inds[0][random_positions[pp]], inds[1][random_positions[pp]]]]" + System.lineSeparator() + + " extracted_point_labels += [n_feat]" + System.lineSeparator() + + " mask, _, _ = predictor.predict(" + System.lineSeparator() + + " point_coords=point_prompts," + System.lineSeparator() + + " point_labels=point_labels," + System.lineSeparator() + + " multimask_output=False," + System.lineSeparator() + + " box=None,)" + System.lineSeparator() + + (this.isIJROIManager ? " mask[0, 1:, 1:] += mask[0, :-1, :-1]" : "") + System.lineSeparator() + + " c_x, c_y, r_m = get_polygons_from_binary_mask(mask[0], only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator() + + " contours_x += c_x" + System.lineSeparator() + + " contours_y += c_y" + System.lineSeparator() + + " rle_masks += r_m" + System.lineSeparator() + + " args = {\"outputs\": {'temp_x': c_x, 'temp_y': c_y, 'temp_mask': r_m}, \"message\": '" + AbstractSamJ.UPDATE_ID_CONTOUR + "'}" + System.lineSeparator() + + " it_list = list(range(n_objects, n_objects := n_objects + len(r_m)))" + System.lineSeparator() + + " future = executor.submit(respond_in_thread, task, args, it_list, lock, finished_threads)" + System.lineSeparator() + + " futures.append(future)" + System.lineSeparator() + // TODO + " task._respond(ResponseType.UPDATE, args)" + System.lineSeparator() + + "" + System.lineSeparator() + + " for p_prompt in point_prompts:" + System.lineSeparator() + + " mask, _, _ = predictor.predict(" + System.lineSeparator() + + " point_coords=np.array(p_prompt).reshape(1, 2)," + System.lineSeparator() + + " point_labels=np.array([1])," + System.lineSeparator() + + " multimask_output=False," + System.lineSeparator() + + " box=None,)" + System.lineSeparator() + + (this.isIJROIManager ? " mask[0, 1:, 1:] += mask[0, :-1, :-1]" : "") + System.lineSeparator() + + " c_x, c_y, r_m = get_polygons_from_binary_mask(mask[0], only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator() + + " contours_x += c_x" + System.lineSeparator() + + " contours_y += c_y" + System.lineSeparator() + + " rle_masks += r_m" + System.lineSeparator() + + " args = {\"outputs\": {'point': p_prompt, 'temp_x': c_x, 'temp_y': c_y, 'temp_mask': r_m}, \"message\": '" + AbstractSamJ.UPDATE_ID_CONTOUR + "'}" + System.lineSeparator() + + " it_list = list(range(n_objects, n_objects := n_objects + len(r_m)))" + System.lineSeparator() + + " future = executor.submit(respond_in_thread, task, args, it_list, lock, finished_threads)" + System.lineSeparator() + + " futures.append(future)" + System.lineSeparator() + // TODO + " task._respond(ResponseType.UPDATE, args)" + System.lineSeparator() + + "" + System.lineSeparator() + + "" + System.lineSeparator() + + " for rect_prompt in rect_prompts:" + System.lineSeparator() + + " input_box = np.array([[rect_prompt[0], rect_prompt[1]], [rect_prompt[2], rect_prompt[3]]])" + System.lineSeparator() + + " mask, _, _ = predictor.predict(" + System.lineSeparator() + + " point_coords=None," + System.lineSeparator() + + " point_labels=np.array([1])," + System.lineSeparator() + + " multimask_output=False," + System.lineSeparator() + + " box=input_box,)" + System.lineSeparator() + + (this.isIJROIManager ? " mask[0, 1:, 1:] += mask[0, :-1, :-1]" : "") + System.lineSeparator() + + " c_x, c_y, r_m = get_polygons_from_binary_mask(mask[0], only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator() + + " contours_x += c_x" + System.lineSeparator() + + " contours_y += c_y" + System.lineSeparator() + + " rle_masks += r_m" + System.lineSeparator() + + " args = {\"outputs\": {'rect': rect_prompt, 'temp_x': c_x, 'temp_y': c_y, 'temp_mask': r_m}, \"message\": '" + AbstractSamJ.UPDATE_ID_CONTOUR + "'}" + System.lineSeparator() + + " it_list = list(range(n_objects, n_objects := n_objects + len(r_m)))" + System.lineSeparator() + + " future = executor.submit(respond_in_thread, task, args, it_list, lock, finished_threads)" + System.lineSeparator() + + " futures.append(future)" + System.lineSeparator() + // TODO + " task._respond(ResponseType.UPDATE, args)" + System.lineSeparator() + + "" + System.lineSeparator() + + "" + System.lineSeparator() + + " finished_threads.sort()" + System.lineSeparator() + + " cancel_unstarted_tasks(futures)" + System.lineSeparator() + + " for i, future in enumerate(futures[::-1]):" + System.lineSeparator() + + " if not future.cancelled():" + System.lineSeparator() + + " future.result()" + System.lineSeparator() + + " for i in finished_threads[::-1]:" + System.lineSeparator() + + " contours_x.pop(i)" + System.lineSeparator() + + " contours_y.pop(i)" + System.lineSeparator() + + " rle_masks.pop(i)" + System.lineSeparator() + + "" + System.lineSeparator() + + "task.update('all contours traced')" + System.lineSeparator() + + "task.outputs['contours_x'] = contours_x" + System.lineSeparator() + + "task.outputs['contours_y'] = contours_y" + System.lineSeparator() + + "task.outputs['rle'] = rle_masks" + System.lineSeparator(); code += "mask = 0" + System.lineSeparator(); - code += "shm_mask.close()" + System.lineSeparator(); - code += "shm_mask.unlink()" + System.lineSeparator(); + if (shmArr != null) { + code += "shm_mask.close()" + System.lineSeparator(); + code += "shm_mask.unlink()" + System.lineSeparator(); + } this.script = code; } }