From c09a6b314053457786d994be9fcb1d51604aaaa0 Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Thu, 28 Nov 2024 14:55:53 +0100 Subject: [PATCH] adapt to use mask as input --- src/main/java/ai/nets/samj/models/Sam2.java | 71 ++++++++++----------- 1 file changed, 35 insertions(+), 36 deletions(-) diff --git a/src/main/java/ai/nets/samj/models/Sam2.java b/src/main/java/ai/nets/samj/models/Sam2.java index 75822e4..ad4eb59 100644 --- a/src/main/java/ai/nets/samj/models/Sam2.java +++ b/src/main/java/ai/nets/samj/models/Sam2.java @@ -20,7 +20,6 @@ package ai.nets.samj.models; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import ai.nets.samj.install.Sam2EnvManager; @@ -34,7 +33,6 @@ import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray; import io.bioimage.modelrunner.utils.CommonUtils; -import net.imglib2.Cursor; import net.imglib2.RandomAccessibleInterval; import net.imglib2.converter.RealTypeConverters; import net.imglib2.img.array.ArrayImgs; @@ -78,6 +76,7 @@ public class Sam2 extends AbstractSamJ { + "measure.label(np.ones((10, 10)), connectivity=1)" + System.lineSeparator() + "import torch" + System.lineSeparator() + "from scipy.ndimage import binary_fill_holes" + System.lineSeparator() + + "from scipy.ndimage import label" + System.lineSeparator() + "import sys" + System.lineSeparator() + "import os" + System.lineSeparator() + "from multiprocessing import shared_memory" + System.lineSeparator() @@ -323,40 +322,40 @@ protected void processMasksWithSam(SharedMemoryArray shmArr, boolean returnAll) for (long l : shmArr.getOriginalShape()) code += l + ","; code += "])" + System.lineSeparator(); - code += "different_mask_vals = np.unique(mask)" + System.lineSeparator(); - code += "contours_x = []" + System.lineSeparator(); - code += "contours_y = []" + System.lineSeparator(); - code += "rle_masks = []" + System.lineSeparator(); - code += "for val in different_mask_vals:" + System.lineSeparator() - + " if val < 1:" + System.lineSeparator() - + " continue" + System.lineSeparator() - + " locations = np.where(mask == val)" + System.lineSeparator() - + " input_points_pos = np.zeros((locations[0].shape[0], 2))" + System.lineSeparator() - + " input_labels_pos = np.ones((locations[0].shape[0]))" + System.lineSeparator() - + " locations_neg = np.where((mask != val) & (mask != 0))" + System.lineSeparator() - + " input_points_neg = np.zeros((locations_neg[0].shape[0], 2))" + System.lineSeparator() - + " input_labels_neg = np.zeros((locations_neg[0].shape[0]))" + System.lineSeparator() - + " input_points_pos[:, 0] = locations[0]" + System.lineSeparator() - + " input_points_pos[:, 1] = locations[1]" + System.lineSeparator() - + " input_points_neg[:, 0] = locations_neg[0]" + System.lineSeparator() - + " input_points_neg[:, 1] = locations_neg[1]" + System.lineSeparator() - + " input_points = np.concatenate((input_points_pos.reshape(-1, 2), input_points_neg.reshape(-1, 2)), axis=0)" + System.lineSeparator() - + " input_label = np.concatenate((input_labels_pos, input_labels_neg), axis=0)" + System.lineSeparator() - + " mask_val, _, _ = predictor.predict(" + System.lineSeparator() - + " point_coords=input_points," + System.lineSeparator() - + " point_labels=input_label," + System.lineSeparator() - + " multimask_output=False," + System.lineSeparator() - + " box=None,)" + System.lineSeparator() - //+ "np.save('/temp/aa.npy', mask)" + System.lineSeparator() - + (this.isIJROIManager ? "mask_val[0, 1:, 1:] += mask_val[0, :-1, :-1]" : "") + System.lineSeparator() - + " contours_x_val,contours_y_val, rles_vals = get_polygons_from_binary_mask(mask_val[0], only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator() - + " contours_x += contours_x_val" + System.lineSeparator() - + " contours_y += contours_y_val" + System.lineSeparator() - + " rle_masks += rles_vals" + System.lineSeparator() - + "task.update('all contours traced')" + System.lineSeparator() - + "task.outputs['contours_x'] = contours_x" + System.lineSeparator() - + "task.outputs['contours_y'] = contours_y" + System.lineSeparator() - + "task.outputs['rle'] = rle_masks" + System.lineSeparator(); + code += "" + + "point_prompts = []" + System.lineSeparator() + + "point_labels = []" + System.lineSeparator() + + "labeled_array, num_features = label(mask)" + System.lineSeparator() + + "for n_feat in range(num_features):" + System.lineSeparator() + + " inds = np.where(labeled_array == n_feat)" + System.lineSeparator() + + " n_points = np.min([3, inds[0].shape[0]])" + System.lineSeparator() + + " random_positions = np.random.choice(inds[0].shape[0], n_points, replace=False)" + System.lineSeparator() + + " for pp in range(n_points):" + System.lineSeparator() + + " point_prompts += [[inds[0][random_posiitons[pp]], inds[1][random_posiitons[pp]]]]" + System.lineSeparator() + + " point_labels += [n_feat]" + System.lineSeparator() + + "" + System.lineSeparator() + + "input_points = np.array(input_points_list)" + System.lineSeparator() + + "input_label = np.array(point_labels)" + System.lineSeparator() + + "mask, _, _ = predictor.predict(" + System.lineSeparator() + + " point_coords=input_points," + System.lineSeparator() + + " point_labels=input_label," + System.lineSeparator() + + " multimask_output=False," + System.lineSeparator() + + " box=None,)" + System.lineSeparator() + + "" + System.lineSeparator() + + "for b in range(num_features):" + + " " + System.lineSeparator() + + "" + System.lineSeparator() + + "" + System.lineSeparator() + // TODO remove + "import matplotlib.pyplot as plt" + System.lineSeparator() + // TODO remove + "plt.imsave('/tmp/aa.jpg', mask[0], cmap='gray')" + System.lineSeparator() + + (this.isIJROIManager ? "mask[0, 1:, 1:] += mask[0, :-1, :-1]" : "") + System.lineSeparator() + //+ (this.isIJROIManager ? "mask[0, :, 1:] += mask[0, :, :-1]" : "") + System.lineSeparator() + //+ "np.save('/home/carlos/git/aa.npy', mask)" + System.lineSeparator() + + "contours_x, contours_y, rle_masks = get_polygons_from_binary_mask(mask[0], only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator() + + "task.update('all contours traced')" + System.lineSeparator() + + "task.outputs['contours_x'] = contours_x" + System.lineSeparator() + + "task.outputs['contours_y'] = contours_y" + System.lineSeparator() + + "task.outputs['rle'] = rle_masks" + System.lineSeparator(); code += "mask = 0" + System.lineSeparator(); code += "shm_mask.close()" + System.lineSeparator(); code += "shm_mask.unlink()" + System.lineSeparator();