diff --git a/src/main/java/ai/nets/samj/communication/model/SAMModel.java b/src/main/java/ai/nets/samj/communication/model/SAMModel.java index 7f8eab5..6bbea8f 100644 --- a/src/main/java/ai/nets/samj/communication/model/SAMModel.java +++ b/src/main/java/ai/nets/samj/communication/model/SAMModel.java @@ -175,9 +175,14 @@ public void setReturnOnlyBiggest(boolean onlyBiggest) { this.onlyBiggest = onlyBiggest; } - public List processBatchOfPoints(List points) throws IOException, RuntimeException, InterruptedException { - //return samj.processBathcOfPoints(points, !onlyBiggest); - return null; + public List processBatchOfPoints(List points) throws IOException, RuntimeException, InterruptedException { + return samj.processBatchOfPoints(points, !onlyBiggest); + } + + public & NativeType> + List processBatchOfPrompts(List points, List rects, RandomAccessibleInterval rai) + throws IOException, RuntimeException, InterruptedException { + return samj.processBatchOfPrompts(points, rects, rai, !onlyBiggest); } /** diff --git a/src/main/java/ai/nets/samj/models/AbstractSamJ.java b/src/main/java/ai/nets/samj/models/AbstractSamJ.java index 192d5d7..db612e2 100644 --- a/src/main/java/ai/nets/samj/models/AbstractSamJ.java +++ b/src/main/java/ai/nets/samj/models/AbstractSamJ.java @@ -167,8 +167,8 @@ public interface DebugTextPrinter { void printText(String text); } protected abstract void cellSAM(List grid, boolean returnAll); - protected abstract & NativeType> void - processPromptsBatchWithSAM(List points, List rects, RandomAccessibleInterval rai, boolean returnAll); + protected abstract void + processPromptsBatchWithSAM(List points, List rects, SharedMemoryArray shmArr, boolean returnAll); protected abstract void processPointsWithSAM(int nPoints, int nNegPoints, boolean returnAll); @@ -178,8 +178,6 @@ public interface DebugTextPrinter { void printText(String text); } protected abstract & NativeType> void createEncodeImageScript(); - abstract protected void processMasksWithSam(SharedMemoryArray shmArr, boolean returnAll); - protected abstract & NativeType> void createSHMArray(RandomAccessibleInterval imShared); @Override @@ -418,29 +416,49 @@ List processBatchOfPrompts(List points, List rects, Rand public & NativeType> List processBatchOfPrompts(List pointsList, List rects, RandomAccessibleInterval rai, boolean returnAll) throws IOException, RuntimeException, InterruptedException { - checkPrompts(pointsList, rects, rai); if ((pointsList == null || pointsList.size() == 0) && (rects == null || rects.size() == 0) && (rai == null)) return new ArrayList(); + checkPrompts(pointsList, rects, rai); + // TODO adapt to reencoding for big images, ideally it should process points close together together pointsList = adaptPointPrompts(pointsList); + // TODO adapt rect prompts this.script = ""; - processPromptsBatchWithSAM(pointsList, null, null, returnAll); - printScript(script, "Points and negative points inference"); - List polys = processAndRetrieveContours(null); - recalculatePolys(polys, encodeCoords); - return polys; + SharedMemoryArray maskShma = null; + if (rai != null) + maskShma = SharedMemoryArray.createSHMAFromRAI(rai, false, false); + + try { + processPromptsBatchWithSAM(pointsList, rects, maskShma, returnAll); + printScript(script, "Batch of prompts inference"); + List polys = processAndRetrieveContours(null); + recalculatePolys(polys, encodeCoords); + return polys; + } catch (IOException | RuntimeException | InterruptedException ex) { + maskShma.close(); + throw ex; + } } private & NativeType> void checkPrompts(List pointsList, List rects, RandomAccessibleInterval rai) { + long[] dims; if ((pointsList == null || pointsList.size() == 0) && (rects == null || rects.size() == 0) - && !(rai.getType() instanceof IntegerType)) + && rai != null && !(rai.getType() instanceof IntegerType)) { throw new IllegalArgumentException("The mask provided should be of any integer type."); - else if ((pointsList == null || pointsList.size() == 0) + } else if ((pointsList == null || pointsList.size() == 0) && (rects == null || rects.size() == 0) - && !(rai.getType() instanceof IntegerType)) { - throw new IllegalArgumentException("The mask provided should be of the same size as the image of interest."); + && rai != null) { + dims = rai.dimensionsAsLongArray(); + if ((dims.length == 2 || (dims.length == 3 && dims[2] == 1)) + && dims[1] == this.shma.getOriginalShape()[0] && dims[0] == this.shma.getOriginalShape()[1]) { + rai = Views.permute(rai, 0, 1); + } else if (dims[0] != this.shma.getOriginalShape()[0] && dims[1] != this.shma.getOriginalShape()[1] + || (dims.length == 3 && dims[2] != 1) || dims.length > 3) { + throw new IllegalArgumentException("The provided mask should be a 2d image with just one channel of width " + + this.shma.getOriginalShape()[1] + " and height " + this.shma.getOriginalShape()[0]); + } } } @@ -730,32 +748,9 @@ List processMask(RandomAccessibleInterval img) * @throws InterruptedException if the process in interrupted */ public & NativeType> - List processMask(RandomAccessibleInterval img, boolean returnAll) + List processMask(RandomAccessibleInterval rai, boolean returnAll) throws IOException, RuntimeException, InterruptedException { - long[] dims = img.dimensionsAsLongArray(); - if ((dims.length == 2 || (dims.length == 3 && dims[2] == 1)) - && dims[1] == this.shma.getOriginalShape()[0] && dims[0] == this.shma.getOriginalShape()[1]) { - img = Views.permute(img, 0, 1); - } else if (dims[0] != this.shma.getOriginalShape()[0] && dims[1] != this.shma.getOriginalShape()[1] - || (dims.length == 3 && dims[2] != 1) || dims.length > 3) { - throw new IllegalArgumentException("The provided mask should be a 2d image with just one channel of width " - + this.shma.getOriginalShape()[1] + " and height " + this.shma.getOriginalShape()[0]); - } - SharedMemoryArray maskShma = SharedMemoryArray.createSHMAFromRAI(img, false, false); - try { - return processMask(maskShma, returnAll); - } catch (IOException | RuntimeException | InterruptedException ex) { - maskShma.close(); - throw ex; - } - } - - private List processMask(SharedMemoryArray shmArr, boolean returnAll) - throws IOException, RuntimeException, InterruptedException { - this.script = ""; - processMasksWithSam(shmArr, returnAll); - printScript(script, "Pre-computed mask inference"); - List polys = processAndRetrieveContours(null); + List polys = processBatchOfPrompts(null, null, rai, returnAll); debugPrinter.printText("processMask() obtained " + polys.size() + " polygons"); return polys; } diff --git a/src/main/java/ai/nets/samj/models/EfficientSamJ.java b/src/main/java/ai/nets/samj/models/EfficientSamJ.java index 53db12f..18a2ecf 100644 --- a/src/main/java/ai/nets/samj/models/EfficientSamJ.java +++ b/src/main/java/ai/nets/samj/models/EfficientSamJ.java @@ -226,64 +226,6 @@ protected void createEncodeImageScript() { + "_ = predictor.get_image_embeddings(im[None, ...])" + System.lineSeparator(); } - @Override - protected void processMasksWithSam(SharedMemoryArray shmArr, boolean returnAll) { - String code = ""; - code += "shm_mask = shared_memory.SharedMemory(name='" + shmArr.getNameForPython() + "')" + System.lineSeparator(); - code += "mask = np.frombuffer(buffer=shm_mask.buf, dtype='" + shmArr.getOriginalDataType() + "').reshape(["; - for (long l : shmArr.getOriginalShape()) - code += l + ","; - code += "])" + System.lineSeparator(); - code += "different_mask_vals = np.unique(mask)" + System.lineSeparator(); - //code += "print(different_mask_vals)" + System.lineSeparator(); - code += "cont_x = []" + System.lineSeparator(); - code += "cont_y = []" + System.lineSeparator(); - code += "rle_masks = []" + System.lineSeparator(); - code += "for val in different_mask_vals:" + System.lineSeparator() - + " if val < 1:" + System.lineSeparator() - + " continue" + System.lineSeparator() - + " locations = np.where(mask == val)" + System.lineSeparator() - + " input_points_pos = np.zeros((locations[0].shape[0], 2))" + System.lineSeparator() - + " input_labels_pos = np.ones((locations[0].shape[0]))" + System.lineSeparator() - + " locations_neg = np.where((mask != val) & (mask != 0))" + System.lineSeparator() - + " input_points_neg = np.zeros((locations_neg[0].shape[0], 2))" + System.lineSeparator() - + " input_labels_neg = np.zeros((locations_neg[0].shape[0]))" + System.lineSeparator() - + " input_points_pos[:, 0] = locations[0]" + System.lineSeparator() - + " input_points_pos[:, 1] = locations[1]" + System.lineSeparator() - + " input_points_neg[:, 0] = locations_neg[0]" + System.lineSeparator() - + " input_points_neg[:, 1] = locations_neg[1]" + System.lineSeparator() - + " input_points = np.concatenate((input_points_pos.reshape(-1, 2), input_points_neg.reshape(-1, 2)), axis=0)" + System.lineSeparator() - + " input_label = np.concatenate((input_labels_pos, input_labels_neg * 0), axis=0)" + System.lineSeparator() - + " input_points = torch.reshape(torch.tensor(input_points), [1, 1, -1, 2])" + System.lineSeparator() - + " input_label = torch.reshape(torch.tensor(input_label), [1, 1, -1])" + System.lineSeparator() - + " predicted_logits, predicted_iou = predictor.predict_masks(predictor.encoded_images," + System.lineSeparator() - + " input_points," + System.lineSeparator() - + " input_label," + System.lineSeparator() - + " multimask_output=True," + System.lineSeparator() - + " input_h=input_h," + System.lineSeparator() - + " input_w=input_w," + System.lineSeparator() - + " output_h=input_h," + System.lineSeparator() - + " output_w=input_w,)" + System.lineSeparator() - //+ "np.save('/temp/aa.npy', mask)" + System.lineSeparator() - + " sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)" + System.lineSeparator() - + " predicted_iou = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)" + System.lineSeparator() - + " predicted_logits = torch.take_along_dim(predicted_logits, sorted_ids[..., None, None], dim=2)" + System.lineSeparator() - + " mask_val = torch.ge(predicted_logits[0, 0, 0, :, :], 0).cpu().detach().numpy()" + System.lineSeparator() - + (this.isIJROIManager ? "mask_val[1:, 1:] += mask_val[:-1, :-1]" : "") + System.lineSeparator() - + " cont_x_val,cont_y_val,rle_val = get_polygons_from_binary_mask(mask_val, only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator() - + " cont_x += cont_x_val" + System.lineSeparator() - + " cont_y += cont_y_val" + System.lineSeparator() - + " rle_masks += rle_val" + System.lineSeparator() - + "task.update('all contours traced')" + System.lineSeparator() - + "task.outputs['contours_x'] = cont_x" + System.lineSeparator() - + "task.outputs['contours_y'] = cont_y" + System.lineSeparator() - + "task.outputs['rle'] = rle_masks" + System.lineSeparator(); - code += "mask = 0" + System.lineSeparator(); - code += "shm_mask.close()" + System.lineSeparator(); - code += "shm_mask.unlink()" + System.lineSeparator(); - this.script = code; - } - @Override protected void processPointsWithSAM(int nPoints, int nNegPoints, boolean returnAll) { String code = "" + System.lineSeparator() @@ -445,9 +387,61 @@ public String deleteEncodingScript(String encodingName) { } @Override - protected & NativeType> void processPromptsBatchWithSAM(List points, - List rects, RandomAccessibleInterval rai, boolean returnAll) { - // TODO Auto-generated method stub - + protected void processPromptsBatchWithSAM(List points, List rects, SharedMemoryArray shmArr, + boolean returnAll) { + String code = ""; + code += "shm_mask = shared_memory.SharedMemory(name='" + shmArr.getNameForPython() + "')" + System.lineSeparator(); + code += "mask = np.frombuffer(buffer=shm_mask.buf, dtype='" + shmArr.getOriginalDataType() + "').reshape(["; + for (long l : shmArr.getOriginalShape()) + code += l + ","; + code += "])" + System.lineSeparator(); + code += "different_mask_vals = np.unique(mask)" + System.lineSeparator(); + //code += "print(different_mask_vals)" + System.lineSeparator(); + code += "cont_x = []" + System.lineSeparator(); + code += "cont_y = []" + System.lineSeparator(); + code += "rle_masks = []" + System.lineSeparator(); + code += "for val in different_mask_vals:" + System.lineSeparator() + + " if val < 1:" + System.lineSeparator() + + " continue" + System.lineSeparator() + + " locations = np.where(mask == val)" + System.lineSeparator() + + " input_points_pos = np.zeros((locations[0].shape[0], 2))" + System.lineSeparator() + + " input_labels_pos = np.ones((locations[0].shape[0]))" + System.lineSeparator() + + " locations_neg = np.where((mask != val) & (mask != 0))" + System.lineSeparator() + + " input_points_neg = np.zeros((locations_neg[0].shape[0], 2))" + System.lineSeparator() + + " input_labels_neg = np.zeros((locations_neg[0].shape[0]))" + System.lineSeparator() + + " input_points_pos[:, 0] = locations[0]" + System.lineSeparator() + + " input_points_pos[:, 1] = locations[1]" + System.lineSeparator() + + " input_points_neg[:, 0] = locations_neg[0]" + System.lineSeparator() + + " input_points_neg[:, 1] = locations_neg[1]" + System.lineSeparator() + + " input_points = np.concatenate((input_points_pos.reshape(-1, 2), input_points_neg.reshape(-1, 2)), axis=0)" + System.lineSeparator() + + " input_label = np.concatenate((input_labels_pos, input_labels_neg * 0), axis=0)" + System.lineSeparator() + + " input_points = torch.reshape(torch.tensor(input_points), [1, 1, -1, 2])" + System.lineSeparator() + + " input_label = torch.reshape(torch.tensor(input_label), [1, 1, -1])" + System.lineSeparator() + + " predicted_logits, predicted_iou = predictor.predict_masks(predictor.encoded_images," + System.lineSeparator() + + " input_points," + System.lineSeparator() + + " input_label," + System.lineSeparator() + + " multimask_output=True," + System.lineSeparator() + + " input_h=input_h," + System.lineSeparator() + + " input_w=input_w," + System.lineSeparator() + + " output_h=input_h," + System.lineSeparator() + + " output_w=input_w,)" + System.lineSeparator() + //+ "np.save('/temp/aa.npy', mask)" + System.lineSeparator() + + " sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)" + System.lineSeparator() + + " predicted_iou = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)" + System.lineSeparator() + + " predicted_logits = torch.take_along_dim(predicted_logits, sorted_ids[..., None, None], dim=2)" + System.lineSeparator() + + " mask_val = torch.ge(predicted_logits[0, 0, 0, :, :], 0).cpu().detach().numpy()" + System.lineSeparator() + + (this.isIJROIManager ? "mask_val[1:, 1:] += mask_val[:-1, :-1]" : "") + System.lineSeparator() + + " cont_x_val,cont_y_val,rle_val = get_polygons_from_binary_mask(mask_val, only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator() + + " cont_x += cont_x_val" + System.lineSeparator() + + " cont_y += cont_y_val" + System.lineSeparator() + + " rle_masks += rle_val" + System.lineSeparator() + + "task.update('all contours traced')" + System.lineSeparator() + + "task.outputs['contours_x'] = cont_x" + System.lineSeparator() + + "task.outputs['contours_y'] = cont_y" + System.lineSeparator() + + "task.outputs['rle'] = rle_masks" + System.lineSeparator(); + code += "mask = 0" + System.lineSeparator(); + code += "shm_mask.close()" + System.lineSeparator(); + code += "shm_mask.unlink()" + System.lineSeparator(); + this.script = code; } } diff --git a/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java b/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java index 0b191da..0c63fcd 100644 --- a/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java +++ b/src/main/java/ai/nets/samj/models/EfficientViTSamJ.java @@ -333,54 +333,6 @@ protected & NativeType> void createSHMArray(RandomAcce adaptImageToModel(imageToBeSent, shma.getSharedRAI()); } - @Override - protected void processMasksWithSam(SharedMemoryArray shmArr, boolean returnAll) { - String code = ""; - code += "shm_mask = shared_memory.SharedMemory(name='" + shmArr.getNameForPython() + "')" + System.lineSeparator(); - code += "mask = np.frombuffer(buffer=shm_mask.buf, dtype='" + shmArr.getOriginalDataType() + "').reshape(["; - for (long l : shmArr.getOriginalShape()) - code += l + ","; - code += "])" + System.lineSeparator(); - code += "different_mask_vals = np.unique(mask)" + System.lineSeparator(); - code += "contours_x = []" + System.lineSeparator(); - code += "contours_y = []" + System.lineSeparator(); - code += "rle_masks = []" + System.lineSeparator(); - code += "for val in different_mask_vals:" + System.lineSeparator() - + " if val < 1:" + System.lineSeparator() - + " continue" + System.lineSeparator() - + " locations = np.where(mask == val)" + System.lineSeparator() - + " input_points_pos = np.zeros((locations[0].shape[0], 2))" + System.lineSeparator() - + " input_labels_pos = np.ones((locations[0].shape[0]))" + System.lineSeparator() - + " locations_neg = np.where((mask != val) & (mask != 0))" + System.lineSeparator() - + " input_points_neg = np.zeros((locations_neg[0].shape[0], 2))" + System.lineSeparator() - + " input_labels_neg = np.zeros((locations_neg[0].shape[0]))" + System.lineSeparator() - + " input_points_pos[:, 0] = locations[0]" + System.lineSeparator() - + " input_points_pos[:, 1] = locations[1]" + System.lineSeparator() - + " input_points_neg[:, 0] = locations_neg[0]" + System.lineSeparator() - + " input_points_neg[:, 1] = locations_neg[1]" + System.lineSeparator() - + " input_points = np.concatenate((input_points_pos.reshape(-1, 2), input_points_neg.reshape(-1, 2)), axis=0)" + System.lineSeparator() - + " input_label = np.concatenate((input_labels_pos, input_labels_neg), axis=0)" + System.lineSeparator() - + " mask_val, _, _ = predictor.predict(" + System.lineSeparator() - + " point_coords=input_points," + System.lineSeparator() - + " point_labels=input_label," + System.lineSeparator() - + " multimask_output=False," + System.lineSeparator() - + " box=None,)" + System.lineSeparator() - //+ "np.save('/temp/aa.npy', mask)" + System.lineSeparator() - + (this.isIJROIManager ? "mask_val[0, 1:, 1:] += mask_val[0, :-1, :-1]" : "") + System.lineSeparator() - + " contours_x_val,contours_y_val,rle_val = get_polygons_from_binary_mask(mask_val[0], only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator() - + " contours_x += contours_x_val" + System.lineSeparator() - + " contours_y += contours_y_val" + System.lineSeparator() - + " rle_masks += rle_val" + System.lineSeparator() - + "task.update('all contours traced')" + System.lineSeparator() - + "task.outputs['contours_x'] = contours_x" + System.lineSeparator() - + "task.outputs['contours_y'] = contours_y" + System.lineSeparator() - + "task.outputs['rle'] = rle_masks" + System.lineSeparator(); - code += "mask = 0" + System.lineSeparator(); - code += "shm_mask.close()" + System.lineSeparator(); - code += "shm_mask.unlink()" + System.lineSeparator(); - this.script = code; - } - @Override protected void processPointsWithSAM(int nPoints, int nNegPoints, boolean returnAll) { String code = "" + System.lineSeparator() @@ -510,9 +462,51 @@ public String deleteEncodingScript(String encodingName) { } @Override - protected & NativeType> void processPromptsBatchWithSAM(List points, - List rects, RandomAccessibleInterval rai, boolean returnAll) { - // TODO Auto-generated method stub - + protected void processPromptsBatchWithSAM(List points, List rects, SharedMemoryArray shmArr, + boolean returnAll) { + String code = ""; + code += "shm_mask = shared_memory.SharedMemory(name='" + shmArr.getNameForPython() + "')" + System.lineSeparator(); + code += "mask = np.frombuffer(buffer=shm_mask.buf, dtype='" + shmArr.getOriginalDataType() + "').reshape(["; + for (long l : shmArr.getOriginalShape()) + code += l + ","; + code += "])" + System.lineSeparator(); + code += "different_mask_vals = np.unique(mask)" + System.lineSeparator(); + code += "contours_x = []" + System.lineSeparator(); + code += "contours_y = []" + System.lineSeparator(); + code += "rle_masks = []" + System.lineSeparator(); + code += "for val in different_mask_vals:" + System.lineSeparator() + + " if val < 1:" + System.lineSeparator() + + " continue" + System.lineSeparator() + + " locations = np.where(mask == val)" + System.lineSeparator() + + " input_points_pos = np.zeros((locations[0].shape[0], 2))" + System.lineSeparator() + + " input_labels_pos = np.ones((locations[0].shape[0]))" + System.lineSeparator() + + " locations_neg = np.where((mask != val) & (mask != 0))" + System.lineSeparator() + + " input_points_neg = np.zeros((locations_neg[0].shape[0], 2))" + System.lineSeparator() + + " input_labels_neg = np.zeros((locations_neg[0].shape[0]))" + System.lineSeparator() + + " input_points_pos[:, 0] = locations[0]" + System.lineSeparator() + + " input_points_pos[:, 1] = locations[1]" + System.lineSeparator() + + " input_points_neg[:, 0] = locations_neg[0]" + System.lineSeparator() + + " input_points_neg[:, 1] = locations_neg[1]" + System.lineSeparator() + + " input_points = np.concatenate((input_points_pos.reshape(-1, 2), input_points_neg.reshape(-1, 2)), axis=0)" + System.lineSeparator() + + " input_label = np.concatenate((input_labels_pos, input_labels_neg), axis=0)" + System.lineSeparator() + + " mask_val, _, _ = predictor.predict(" + System.lineSeparator() + + " point_coords=input_points," + System.lineSeparator() + + " point_labels=input_label," + System.lineSeparator() + + " multimask_output=False," + System.lineSeparator() + + " box=None,)" + System.lineSeparator() + //+ "np.save('/temp/aa.npy', mask)" + System.lineSeparator() + + (this.isIJROIManager ? "mask_val[0, 1:, 1:] += mask_val[0, :-1, :-1]" : "") + System.lineSeparator() + + " contours_x_val,contours_y_val,rle_val = get_polygons_from_binary_mask(mask_val[0], only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator() + + " contours_x += contours_x_val" + System.lineSeparator() + + " contours_y += contours_y_val" + System.lineSeparator() + + " rle_masks += rle_val" + System.lineSeparator() + + "task.update('all contours traced')" + System.lineSeparator() + + "task.outputs['contours_x'] = contours_x" + System.lineSeparator() + + "task.outputs['contours_y'] = contours_y" + System.lineSeparator() + + "task.outputs['rle'] = rle_masks" + System.lineSeparator(); + code += "mask = 0" + System.lineSeparator(); + code += "shm_mask.close()" + System.lineSeparator(); + code += "shm_mask.unlink()" + System.lineSeparator(); + this.script = code; } } diff --git a/src/main/java/ai/nets/samj/models/Sam2.java b/src/main/java/ai/nets/samj/models/Sam2.java index 9b1c210..ecdcd4a 100644 --- a/src/main/java/ai/nets/samj/models/Sam2.java +++ b/src/main/java/ai/nets/samj/models/Sam2.java @@ -316,64 +316,6 @@ protected & NativeType> void createSHMArray(RandomAcce adaptImageToModel(imageToBeSent, shma.getSharedRAI()); } - @Override - protected void processMasksWithSam(SharedMemoryArray shmArr, boolean returnAll) { - String code = ""; - code += "shm_mask = shared_memory.SharedMemory(name='" + shmArr.getNameForPython() + "')" + System.lineSeparator(); - code += "mask = np.frombuffer(buffer=shm_mask.buf, dtype='" + shmArr.getOriginalDataType() + "').reshape(["; - for (long l : shmArr.getOriginalShape()) - code += l + ","; - code += "])" + System.lineSeparator(); - code += "" - + "point_prompts = []" + System.lineSeparator() - + "point_labels = []" + System.lineSeparator() - + "labeled_array, num_features = label(mask)" + System.lineSeparator() - + "contours_x = []" + System.lineSeparator() - + "contours_y = []" + System.lineSeparator() - + "rle_masks = []" + System.lineSeparator() - // TODO right now is geetting the mask after each prompt - // TODO test processing first every prompt and then getting the masks - + "" + System.lineSeparator() - + "for n_feat in range(num_features):" + System.lineSeparator() - + " inds = np.where(labeled_array == n_feat)" + System.lineSeparator() - + " n_points = np.min([3, inds[0].shape[0]])" + System.lineSeparator() - + " random_positions = np.random.choice(inds[0].shape[0], n_points, replace=False)" + System.lineSeparator() - + " for pp in range(n_points):" + System.lineSeparator() - + " point_prompts += [[inds[0][random_positions[pp]], inds[1][random_positions[pp]]]]" + System.lineSeparator() - + " point_labels += [n_feat]" + System.lineSeparator() - + "" + System.lineSeparator() - + " mask, _, _ = predictor.predict(" + System.lineSeparator() - + " point_coords=point_prompts," + System.lineSeparator() - + " point_labels=point_labels," + System.lineSeparator() - + " multimask_output=False," + System.lineSeparator() - + " box=None,)" + System.lineSeparator() - + "" + System.lineSeparator() - + "" + System.lineSeparator() - + "" + System.lineSeparator() - // TODO + "for b in range(num_features):" + System.lineSeparator() - // TODO + " mm = mask[b]" + System.lineSeparator() - + (this.isIJROIManager ? " mask += mask[0, :-1, :-1]" : "") + System.lineSeparator() - + " c_x, c_y, r_m = get_polygons_from_binary_mask(mask[0], only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator() - + " contours_x += c_x" + System.lineSeparator() - + " contours_y += c_Y" + System.lineSeparator() - + " rle_masks += r_m" + System.lineSeparator() - + "" + System.lineSeparator() - + "" + System.lineSeparator() - // TODO remove + "import matplotlib.pyplot as plt" + System.lineSeparator() - // TODO remove + "plt.imsave('/tmp/aa.jpg', mask[0], cmap='gray')" + System.lineSeparator() - //+ (this.isIJROIManager ? "mask[0, :, 1:] += mask[0, :, :-1]" : "") + System.lineSeparator() - //+ "np.save('/home/carlos/git/aa.npy', mask)" + System.lineSeparator() - + "contours_x, contours_y, rle_masks = get_polygons_from_binary_mask(mask[0], only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator() - + "task.update('all contours traced')" + System.lineSeparator() - + "task.outputs['contours_x'] = contours_x" + System.lineSeparator() - + "task.outputs['contours_y'] = contours_y" + System.lineSeparator() - + "task.outputs['rle'] = rle_masks" + System.lineSeparator(); - code += "mask = 0" + System.lineSeparator(); - code += "shm_mask.close()" + System.lineSeparator(); - code += "shm_mask.unlink()" + System.lineSeparator(); - this.script = code; - } - @Override protected void processPointsWithSAM(int nPoints, int nNegPoints, boolean returnAll) { String code = "" + System.lineSeparator() @@ -504,9 +446,82 @@ public String deleteEncodingScript(String encodingName) { } @Override - protected & NativeType> void processPromptsBatchWithSAM(List points, - List rects, RandomAccessibleInterval rai, boolean returnAll) { - // TODO Auto-generated method stub - + protected void processPromptsBatchWithSAM(List points, + List rects, SharedMemoryArray shmArr, boolean returnAll) { + String code = ""; + code += "shm_mask = shared_memory.SharedMemory(name='" + shmArr.getNameForPython() + "')" + System.lineSeparator(); + code += "mask = np.frombuffer(buffer=shm_mask.buf, dtype='" + shmArr.getOriginalDataType() + "').reshape(["; + for (long l : shmArr.getOriginalShape()) + code += l + ","; + code += "])" + System.lineSeparator(); + code += "" + + "point_prompts = []" + System.lineSeparator() + + "point_labels = []" + System.lineSeparator() + + "mask_prompts = []" + System.lineSeparator() + + "mask_labels = []" + System.lineSeparator() + + "labeled_array, num_features = label(mask)" + System.lineSeparator() + + "contours_x = []" + System.lineSeparator() + + "contours_y = []" + System.lineSeparator() + + "rle_masks = []" + System.lineSeparator() + // TODO right now is geetting the mask after each prompt + // TODO test processing first every prompt and then getting the masks + + "" + System.lineSeparator() + + "for n_feat in range(num_features):" + System.lineSeparator() + + " inds = np.where(labeled_array == n_feat)" + System.lineSeparator() + + " n_points = np.min([3, inds[0].shape[0]])" + System.lineSeparator() + + " random_positions = np.random.choice(inds[0].shape[0], n_points, replace=False)" + System.lineSeparator() + + " for pp in range(n_points):" + System.lineSeparator() + + " point_prompts += [[inds[0][random_positions[pp]], inds[1][random_positions[pp]]]]" + System.lineSeparator() + + " point_labels += [n_feat]" + System.lineSeparator() + + " mask, _, _ = predictor.predict(" + System.lineSeparator() + + " point_coords=point_prompts," + System.lineSeparator() + + " point_labels=point_labels," + System.lineSeparator() + + " multimask_output=False," + System.lineSeparator() + + " box=None,)" + System.lineSeparator() + + (this.isIJROIManager ? " mask += mask[0, :-1, :-1]" : "") + System.lineSeparator() + + " c_x, c_y, r_m = get_polygons_from_binary_mask(mask[0], only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator() + + " contours_x += c_x" + System.lineSeparator() + + " contours_y += c_Y" + System.lineSeparator() + + " rle_masks += r_m" + System.lineSeparator() + + "" + System.lineSeparator() + + "for p_prompt in range(len(point_prompts)):" + System.lineSeparator() + + " mask, _, _ = predictor.predict(" + System.lineSeparator() + + " point_coords=np.array(p_prompt).reshape(1, 2)," + System.lineSeparator() + + " point_labels=np.array([1])," + System.lineSeparator() + + " multimask_output=False," + System.lineSeparator() + + " box=None,)" + System.lineSeparator() + + (this.isIJROIManager ? " mask += mask[0, :-1, :-1]" : "") + System.lineSeparator() + + " c_x, c_y, r_m = get_polygons_from_binary_mask(mask[0], only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator() + + " contours_x += c_x" + System.lineSeparator() + + " contours_y += c_Y" + System.lineSeparator() + + " rle_masks += r_m" + System.lineSeparator() + + "" + System.lineSeparator() + + "" + System.lineSeparator() + + "for rect_prompt in range(len(rect_prompts)):" + System.lineSeparator() + + " input_box = np.array([[input_box[0], input_box[1]], [input_box[2], input_box[3]]])" + System.lineSeparator() + + " mask, _, _ = predictor.predict(" + System.lineSeparator() + + " point_coords=None," + System.lineSeparator() + + " point_labels=np.array([1])," + System.lineSeparator() + + " multimask_output=False," + System.lineSeparator() + + " box=input_box,)" + System.lineSeparator() + + (this.isIJROIManager ? " mask += mask[0, :-1, :-1]" : "") + System.lineSeparator() + + " c_x, c_y, r_m = get_polygons_from_binary_mask(mask[0], only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator() + + " contours_x += c_x" + System.lineSeparator() + + " contours_y += c_Y" + System.lineSeparator() + + " rle_masks += r_m" + System.lineSeparator() + + "" + System.lineSeparator() + + "" + System.lineSeparator() + // TODO remove + "import matplotlib.pyplot as plt" + System.lineSeparator() + // TODO remove + "plt.imsave('/tmp/aa.jpg', mask[0], cmap='gray')" + System.lineSeparator() + //+ (this.isIJROIManager ? "mask[0, :, 1:] += mask[0, :, :-1]" : "") + System.lineSeparator() + //+ "np.save('/home/carlos/git/aa.npy', mask)" + System.lineSeparator() + + "task.update('all contours traced')" + System.lineSeparator() + + "task.outputs['contours_x'] = contours_x" + System.lineSeparator() + + "task.outputs['contours_y'] = contours_y" + System.lineSeparator() + + "task.outputs['rle'] = rle_masks" + System.lineSeparator(); + code += "mask = 0" + System.lineSeparator(); + code += "shm_mask.close()" + System.lineSeparator(); + code += "shm_mask.unlink()" + System.lineSeparator(); + this.script = code; } }