From 4c8250cfef4cd246399ec7b29bb9e07885a15ce6 Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Tue, 3 Dec 2024 20:47:34 +0100 Subject: [PATCH] add batch processing in efficientsam --- .../ai/nets/samj/models/EfficientSamJ.java | 198 +++++++++++++----- 1 file changed, 146 insertions(+), 52 deletions(-) diff --git a/src/main/java/ai/nets/samj/models/EfficientSamJ.java b/src/main/java/ai/nets/samj/models/EfficientSamJ.java index c820e42..fc42f66 100644 --- a/src/main/java/ai/nets/samj/models/EfficientSamJ.java +++ b/src/main/java/ai/nets/samj/models/EfficientSamJ.java @@ -388,59 +388,153 @@ public String deleteEncodingScript(String encodingName) { @Override protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr, boolean returnAll) { - String code = ""; - code += "shm_mask = shared_memory.SharedMemory(name='" + shmArr.getNameForPython() + "')" + System.lineSeparator(); - code += "mask = np.frombuffer(buffer=shm_mask.buf, dtype='" + shmArr.getOriginalDataType() + "').reshape(["; - for (long l : shmArr.getOriginalShape()) - code += l + ","; - code += "])" + System.lineSeparator(); - code += "different_mask_vals = np.unique(mask)" + System.lineSeparator(); - //code += "print(different_mask_vals)" + System.lineSeparator(); - code += "cont_x = []" + System.lineSeparator(); - code += "cont_y = []" + System.lineSeparator(); - code += "rle_masks = []" + System.lineSeparator(); - code += "for val in different_mask_vals:" + System.lineSeparator() - + " if val < 1:" + System.lineSeparator() - + " continue" + System.lineSeparator() - + " locations = np.where(mask == val)" + System.lineSeparator() - + " input_points_pos = np.zeros((locations[0].shape[0], 2))" + System.lineSeparator() - + " input_labels_pos = np.ones((locations[0].shape[0]))" + System.lineSeparator() - + " locations_neg = np.where((mask != val) & (mask != 0))" + System.lineSeparator() - + " input_points_neg = np.zeros((locations_neg[0].shape[0], 2))" + System.lineSeparator() - + " input_labels_neg = np.zeros((locations_neg[0].shape[0]))" + System.lineSeparator() - + " input_points_pos[:, 0] = locations[0]" + System.lineSeparator() - + " input_points_pos[:, 1] = locations[1]" + System.lineSeparator() - + " input_points_neg[:, 0] = locations_neg[0]" + System.lineSeparator() - + " input_points_neg[:, 1] = locations_neg[1]" + System.lineSeparator() - + " input_points = np.concatenate((input_points_pos.reshape(-1, 2), input_points_neg.reshape(-1, 2)), axis=0)" + System.lineSeparator() - + " input_label = np.concatenate((input_labels_pos, input_labels_neg * 0), axis=0)" + System.lineSeparator() - + " input_points = torch.reshape(torch.tensor(input_points), [1, 1, -1, 2])" + System.lineSeparator() - + " input_label = torch.reshape(torch.tensor(input_label), [1, 1, -1])" + System.lineSeparator() - + " predicted_logits, predicted_iou = predictor.predict_masks(predictor.encoded_images," + System.lineSeparator() - + " input_points," + System.lineSeparator() - + " input_label," + System.lineSeparator() - + " multimask_output=True," + System.lineSeparator() - + " input_h=input_h," + System.lineSeparator() - + " input_w=input_w," + System.lineSeparator() - + " output_h=input_h," + System.lineSeparator() - + " output_w=input_w,)" + System.lineSeparator() - //+ "np.save('/temp/aa.npy', mask)" + System.lineSeparator() - + " sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)" + System.lineSeparator() - + " predicted_iou = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)" + System.lineSeparator() - + " predicted_logits = torch.take_along_dim(predicted_logits, sorted_ids[..., None, None], dim=2)" + System.lineSeparator() - + " mask_val = torch.ge(predicted_logits[0, 0, 0, :, :], 0).cpu().detach().numpy()" + System.lineSeparator() - + (this.isIJROIManager ? "mask_val[1:, 1:] += mask_val[:-1, :-1]" : "") + System.lineSeparator() - + " cont_x_val,cont_y_val,rle_val = get_polygons_from_binary_mask(mask_val, only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator() - + " cont_x += cont_x_val" + System.lineSeparator() - + " cont_y += cont_y_val" + System.lineSeparator() - + " rle_masks += rle_val" + System.lineSeparator() - + "task.update('all contours traced')" + System.lineSeparator() - + "task.outputs['contours_x'] = cont_x" + System.lineSeparator() - + "task.outputs['contours_y'] = cont_y" + System.lineSeparator() - + "task.outputs['rle'] = rle_masks" + System.lineSeparator(); + String code = "" + + "num_threads = 3" + System.lineSeparator() + + "finished_threads = []" + System.lineSeparator() + + "from concurrent.futures import ThreadPoolExecutor, as_completed" + System.lineSeparator() + + "import threading" + System.lineSeparator() + + "lock = threading.Lock()" + System.lineSeparator() + + "def respond_in_thread(task, args, inds, lock, finished_threads):" + System.lineSeparator() + + " task._respond(ResponseType.UPDATE, args)" + System.lineSeparator() + + " with lock:" + System.lineSeparator() + + " finished_threads.extend(inds)" + System.lineSeparator() + + "" + System.lineSeparator() + + "def cancel_unstarted_tasks(futures):" + System.lineSeparator() + + " for future in futures:" + System.lineSeparator() + + " if not future.running() and not future.done():" + System.lineSeparator() + + " future.cancel()" + System.lineSeparator() + + "" + System.lineSeparator() + + "num_features = 0" + System.lineSeparator(); + if (shmArr != null) { + code += "" + + "shm_mask = shared_memory.SharedMemory(name='" + shmArr.getNameForPython() + "')" + System.lineSeparator() + + "mask_batch = np.frombuffer(buffer=shm_mask.buf, dtype='" + shmArr.getOriginalDataType() + "').reshape(["; + for (long l : shmArr.getOriginalShape()) + code += l + ","; + code += "])" + System.lineSeparator(); + code += "labeled_array, num_features = label(mask_batch)" + System.lineSeparator(); + } + code += "" + + "contours_x = []" + System.lineSeparator() + + "contours_y = []" + System.lineSeparator() + + "rle_masks = []" + System.lineSeparator() + + "ntot = num_features + len(point_prompts) + len(rect_prompts)" + System.lineSeparator() + + "args = {\"outputs\": {'n': str(ntot)}, \"message\": '" + AbstractSamJ.UPDATE_ID_N_CONTOURS + "'}" + System.lineSeparator() + + "task._respond(ResponseType.UPDATE, args)" + System.lineSeparator() + // TODO right now is geetting the mask after each prompt + // TODO test processing first every prompt and then getting the masks + + "with ThreadPoolExecutor(max_workers=num_threads) as executor:" + System.lineSeparator() + + " futures = []" + System.lineSeparator() + + " n_objects = 0" + System.lineSeparator() + + " for n_feat in range(num_features):" + System.lineSeparator() + + " extracted_point_prompts = []" + System.lineSeparator() + + " extracted_point_labels = []" + System.lineSeparator() + + " inds = np.where(labeled_array == n_feat)" + System.lineSeparator() + + " n_points = np.min([3, inds[0].shape[0]])" + System.lineSeparator() + + " random_positions = np.random.choice(inds[0].shape[0], n_points, replace=False)" + System.lineSeparator() + + " for pp in range(n_points):" + System.lineSeparator() + + " extracted_point_prompts += [[inds[0][random_positions[pp]], inds[1][random_positions[pp]]]]" + System.lineSeparator() + + " extracted_point_labels += [n_feat]" + System.lineSeparator() + + " ip = torch.reshape(torch.tensor(np.array(extracted_point_prompts).reshape(1, 2)), [1, 1, -1, 2])" + System.lineSeparator() + + " il = torch.reshape(torch.tensor(np.array(extracted_point_labels)), [1, 1, -1])" + System.lineSeparator() + + " predicted_logits, predicted_iou = predictor.predict_masks(predictor.encoded_images," + System.lineSeparator() + + " ip," + System.lineSeparator() + + " il," + System.lineSeparator() + + " multimask_output=True," + System.lineSeparator() + + " input_h=input_h," + System.lineSeparator() + + " input_w=input_w," + System.lineSeparator() + + " output_h=input_h," + System.lineSeparator() + + " output_w=input_w,)" + System.lineSeparator() + + " sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)" + System.lineSeparator() + + " predicted_iou = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)" + System.lineSeparator() + + " predicted_logits = torch.take_along_dim(predicted_logits, sorted_ids[..., None, None], dim=2)" + System.lineSeparator() + + " mask = torch.ge(predicted_logits[0, 0, 0, :, :], 0).cpu().detach().numpy()" + System.lineSeparator() + + (this.isIJROIManager ? " mask[1:, 1:] += mask[:-1, :-1]" : "") + System.lineSeparator() + + " c_x, c_y, r_m = get_polygons_from_binary_mask(mask, only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator() + + " contours_x += c_x" + System.lineSeparator() + + " contours_y += c_y" + System.lineSeparator() + + " rle_masks += r_m" + System.lineSeparator() + + " args = {\"outputs\": {'temp_x': c_x, 'temp_y': c_y, 'temp_mask': r_m}, \"message\": '" + AbstractSamJ.UPDATE_ID_CONTOUR + "'}" + System.lineSeparator() + + " it_list = list(range(n_objects, n_objects := n_objects + len(r_m)))" + System.lineSeparator() + + " future = executor.submit(respond_in_thread, task, args, it_list, lock, finished_threads)" + System.lineSeparator() + + " futures.append(future)" + System.lineSeparator() + // TODO + " task._respond(ResponseType.UPDATE, args)" + System.lineSeparator() + + "" + System.lineSeparator() + + " for p_prompt in point_prompts:" + System.lineSeparator() + + " ip = torch.reshape(torch.tensor(np.array(p_prompt).reshape(1, 2)), [1, 1, -1, 2])" + System.lineSeparator() + + " il = torch.reshape(torch.tensor(np.array([1])), [1, 1, -1])" + System.lineSeparator() + + " predicted_logits, predicted_iou = predictor.predict_masks(predictor.encoded_images," + System.lineSeparator() + + " ip," + System.lineSeparator() + + " il," + System.lineSeparator() + + " multimask_output=True," + System.lineSeparator() + + " input_h=input_h," + System.lineSeparator() + + " input_w=input_w," + System.lineSeparator() + + " output_h=input_h," + System.lineSeparator() + + " output_w=input_w,)" + System.lineSeparator() + + " sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)" + System.lineSeparator() + + " predicted_iou = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)" + System.lineSeparator() + + " predicted_logits = torch.take_along_dim(predicted_logits, sorted_ids[..., None, None], dim=2)" + System.lineSeparator() + + " mask = torch.ge(predicted_logits[0, 0, 0, :, :], 0).cpu().detach().numpy()" + System.lineSeparator() + + (this.isIJROIManager ? " mask[1:, 1:] += mask[:-1, :-1]" : "") + System.lineSeparator() + + " c_x, c_y, r_m = get_polygons_from_binary_mask(mask, only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator() + + " contours_x += c_x" + System.lineSeparator() + + " contours_y += c_y" + System.lineSeparator() + + " rle_masks += r_m" + System.lineSeparator() + + " args = {\"outputs\": {'point': p_prompt, 'temp_x': c_x, 'temp_y': c_y, 'temp_mask': r_m}, \"message\": '" + AbstractSamJ.UPDATE_ID_CONTOUR + "'}" + System.lineSeparator() + + " it_list = list(range(n_objects, n_objects := n_objects + len(r_m)))" + System.lineSeparator() + + " future = executor.submit(respond_in_thread, task, args, it_list, lock, finished_threads)" + System.lineSeparator() + + " futures.append(future)" + System.lineSeparator() + // TODO + " task._respond(ResponseType.UPDATE, args)" + System.lineSeparator() + + "" + System.lineSeparator() + + "" + System.lineSeparator() + + " for rect_prompt in rect_prompts:" + System.lineSeparator() + + " input_box = np.array([[rect_prompt[0], rect_prompt[1]], [rect_prompt[2], rect_prompt[3]]])" + System.lineSeparator() + + " input_box = torch.reshape(torch.tensor(input_box), [1, 1, -1, 2])" + System.lineSeparator() + + " input_label = np.array([2,3])" + System.lineSeparator() + + " input_label = torch.reshape(torch.tensor(input_label), [1, 1, -1])" + System.lineSeparator() + + " predicted_logits, predicted_iou = predictor.predict_masks(predictor.encoded_images," + System.lineSeparator() + + " input_box," + System.lineSeparator() + + " input_label," + System.lineSeparator() + + " multimask_output=True," + System.lineSeparator() + + " input_h=input_h," + System.lineSeparator() + + " input_w=input_w," + System.lineSeparator() + + " output_h=input_h," + System.lineSeparator() + + " output_w=input_w,)" + System.lineSeparator() + + " sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)" + System.lineSeparator() + + " predicted_iou = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)" + System.lineSeparator() + + " predicted_logits = torch.take_along_dim(predicted_logits, sorted_ids[..., None, None], dim=2)" + System.lineSeparator() + + " mask = torch.ge(predicted_logits[0, 0, 0, :, :], 0).cpu().detach().numpy()" + System.lineSeparator() + + (this.isIJROIManager ? " mask[1:, 1:] += mask[:-1, :-1]" : "") + System.lineSeparator() + + " c_x, c_y, r_m = get_polygons_from_binary_mask(mask, only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator() + + " contours_x += c_x" + System.lineSeparator() + + " contours_y += c_y" + System.lineSeparator() + + " rle_masks += r_m" + System.lineSeparator() + + " args = {\"outputs\": {'rect': rect_prompt, 'temp_x': c_x, 'temp_y': c_y, 'temp_mask': r_m}, \"message\": '" + AbstractSamJ.UPDATE_ID_CONTOUR + "'}" + System.lineSeparator() + + " it_list = list(range(n_objects, n_objects := n_objects + len(r_m)))" + System.lineSeparator() + + " future = executor.submit(respond_in_thread, task, args, it_list, lock, finished_threads)" + System.lineSeparator() + + " futures.append(future)" + System.lineSeparator() + // TODO + " task._respond(ResponseType.UPDATE, args)" + System.lineSeparator() + + "" + System.lineSeparator() + + "" + System.lineSeparator() + + " finished_threads.sort()" + System.lineSeparator() + + " cancel_unstarted_tasks(futures)" + System.lineSeparator() + + " for i, future in enumerate(futures[::-1]):" + System.lineSeparator() + + " if not future.cancelled():" + System.lineSeparator() + + " future.result()" + System.lineSeparator() + + " for i in finished_threads[::-1]:" + System.lineSeparator() + + " contours_x.pop(i)" + System.lineSeparator() + + " contours_y.pop(i)" + System.lineSeparator() + + " rle_masks.pop(i)" + System.lineSeparator() + + "" + System.lineSeparator() + + "task.update('all contours traced')" + System.lineSeparator() + + "task.outputs['contours_x'] = contours_x" + System.lineSeparator() + + "task.outputs['contours_y'] = contours_y" + System.lineSeparator() + + "task.outputs['rle'] = rle_masks" + System.lineSeparator(); code += "mask = 0" + System.lineSeparator(); - code += "shm_mask.close()" + System.lineSeparator(); - code += "shm_mask.unlink()" + System.lineSeparator(); + if (shmArr != null) { + code += "shm_mask.close()" + System.lineSeparator(); + code += "shm_mask.unlink()" + System.lineSeparator(); + } this.script = code; } }