diff --git a/src/main/java/ai/nets/samj/AbstractSamJ2.java b/src/main/java/ai/nets/samj/AbstractSamJ2.java index 94e98d5..87920e8 100644 --- a/src/main/java/ai/nets/samj/AbstractSamJ2.java +++ b/src/main/java/ai/nets/samj/AbstractSamJ2.java @@ -156,8 +156,40 @@ void updateImage(RandomAccessibleInterval rai) throws IOException, RuntimeExc * @throws RuntimeException if there is any error running the Python code * @throws InterruptedException if the process is interrupted */ - public abstract & NativeType> - void addImage(RandomAccessibleInterval rai); + private & NativeType> + void addImage(RandomAccessibleInterval rai) + throws IOException, RuntimeException, InterruptedException { + if (rai.dimensionsAsLongArray()[0] * rai.dimensionsAsLongArray()[1] > MAX_ENCODED_AREA_RS * MAX_ENCODED_AREA_RS + || rai.dimensionsAsLongArray()[0] > MAX_ENCODED_SIDE || rai.dimensionsAsLongArray()[1] > MAX_ENCODED_SIDE) { + this.targetDims = new long[] {0, 0, 0}; + this.img = rai; + return; + } + this.script = ""; + sendImgLib2AsNp(rai); + createEncodeImageScript(); + try { + printScript(script, "Creation of initial embeddings"); + Task task = python.task(script); + task.waitFor(); + if (task.status == TaskStatus.CANCELED) + throw new RuntimeException(); + else if (task.status == TaskStatus.FAILED) + throw new RuntimeException(); + else if (task.status == TaskStatus.CRASHED) + throw new RuntimeException(); + this.shma.close(); + } catch (IOException | InterruptedException | RuntimeException e) { + try { + this.shma.close(); + } catch (IOException e1) { + throw new IOException(e.toString() + System.lineSeparator() + e1.toString()); + } + throw e; + } + } + + protected abstract void createEncodeImageScript(); private void reencodeCrop() throws IOException, InterruptedException, RuntimeException { reencodeCrop(null); @@ -166,9 +198,7 @@ private void reencodeCrop() throws IOException, InterruptedException, RuntimeExc private void reencodeCrop(long[] cropSize) throws IOException, InterruptedException, RuntimeException { this.script = ""; sendCropAsNp(cropSize); - this.script += "" - + "task.update(str(im.shape))" + System.lineSeparator() - + "aa = predictor.get_image_embeddings(im[None, ...])"; + createEncodeImageScript(); try { printScript(script, "Creation of the cropped embeddings"); Task task = python.task(script); @@ -299,58 +329,7 @@ private List processMask(SharedMemoryArray shmArr, boolean returnAll) return polys; } - private void processMasksWithSam(SharedMemoryArray shmArr, boolean returnAll) { - String code = ""; - code += "shm_mask = shared_memory.SharedMemory(name='" + shmArr.getNameForPython() + "')" + System.lineSeparator(); - code += "mask = np.frombuffer(buffer=shm_mask.buf, dtype='" + shmArr.getOriginalDataType() + "').reshape(["; - for (long l : shmArr.getOriginalShape()) - code += l + ","; - code += "])" + System.lineSeparator(); - code += "different_mask_vals = np.unique(mask)" + System.lineSeparator(); - //code += "print(different_mask_vals)" + System.lineSeparator(); - code += "cont_x = []" + System.lineSeparator(); - code += "cont_y = []" + System.lineSeparator(); - code += "for val in different_mask_vals:" + System.lineSeparator() - + " if val < 1:" + System.lineSeparator() - + " continue" + System.lineSeparator() - + " locations = np.where(mask == val)" + System.lineSeparator() - + " input_points_pos = np.zeros((locations[0].shape[0], 2))" + System.lineSeparator() - + " input_labels_pos = np.ones((locations[0].shape[0]))" + System.lineSeparator() - + " locations_neg = np.where((mask != val) & (mask != 0))" + System.lineSeparator() - + " input_points_neg = np.zeros((locations_neg[0].shape[0], 2))" + System.lineSeparator() - + " input_labels_neg = np.zeros((locations_neg[0].shape[0]))" + System.lineSeparator() - + " input_points_pos[:, 0] = locations[0]" + System.lineSeparator() - + " input_points_pos[:, 1] = locations[1]" + System.lineSeparator() - + " input_points_neg[:, 0] = locations_neg[0]" + System.lineSeparator() - + " input_points_neg[:, 1] = locations_neg[1]" + System.lineSeparator() - + " input_points = np.concatenate((input_points_pos.reshape(-1, 2), input_points_neg.reshape(-1, 2)), axis=0)" + System.lineSeparator() - + " input_label = np.concatenate((input_labels_pos, input_labels_neg * 0), axis=0)" + System.lineSeparator() - + " input_points = torch.reshape(torch.tensor(input_points), [1, 1, -1, 2])" + System.lineSeparator() - + " input_label = torch.reshape(torch.tensor(input_label), [1, 1, -1])" + System.lineSeparator() - + " predicted_logits, predicted_iou = predictor.predict_masks(predictor.encoded_images," + System.lineSeparator() - + " input_points," + System.lineSeparator() - + " input_label," + System.lineSeparator() - + " multimask_output=True," + System.lineSeparator() - + " input_h=input_h," + System.lineSeparator() - + " input_w=input_w," + System.lineSeparator() - + " output_h=input_h," + System.lineSeparator() - + " output_w=input_w,)" + System.lineSeparator() - //+ "np.save('/temp/aa.npy', mask)" + System.lineSeparator() - + " sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)" + System.lineSeparator() - + " predicted_iou = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)" + System.lineSeparator() - + " predicted_logits = torch.take_along_dim(predicted_logits, sorted_ids[..., None, None], dim=2)" + System.lineSeparator() - + " mask_val = torch.ge(predicted_logits[0, 0, 0, :, :], 0).cpu().detach().numpy()" + System.lineSeparator() - + " cont_x_val,cont_y_val = get_polygons_from_binary_mask(mask_val, only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator() - + " cont_x += cont_x_val" + System.lineSeparator() - + " cont_y += cont_y_val" + System.lineSeparator() - + "task.update('all contours traced')" + System.lineSeparator() - + "task.outputs['contours_x'] = cont_x" + System.lineSeparator() - + "task.outputs['contours_y'] = cont_y" + System.lineSeparator(); - code += "mask = 0" + System.lineSeparator(); - code += "shm_mask.close()" + System.lineSeparator(); - code += "shm_mask.unlink()" + System.lineSeparator(); - this.script = code; - } + abstract protected void processMasksWithSam(SharedMemoryArray shmArr, boolean returnAll); /** * Method used that runs EfficientSAM using a list of points as the prompt. This method runs @@ -829,73 +808,9 @@ private & NativeType> void sendCropAsNp(long[] cropSiz this.script += code; } - private void processPointsWithSAM(int nPoints, int nNegPoints, boolean returnAll) { - String code = "" + System.lineSeparator() - + "task.update('start predict')" + System.lineSeparator() - + "input_points_list = []" + System.lineSeparator() - + "input_neg_points_list = []" + System.lineSeparator(); - for (int n = 0; n < nPoints; n ++) - code += "input_points_list.append([input_points[" + n + "][0], input_points[" + n + "][1]])" + System.lineSeparator(); - for (int n = 0; n < nNegPoints; n ++) - code += "input_neg_points_list.append([input_neg_points[" + n + "][0], input_neg_points[" + n + "][1]])" + System.lineSeparator(); - code += "" - + "input_points = np.concatenate(" - + "(np.array(input_points_list).reshape(" + nPoints + ", 2), np.array(input_neg_points_list).reshape(" + nNegPoints + ", 2))" - + ", axis=0)" + System.lineSeparator() - + "input_points = torch.reshape(torch.tensor(input_points), [1, 1, -1, 2])" + System.lineSeparator() - + "input_label = np.array([1] * " + (nPoints + nNegPoints) + ")" + System.lineSeparator() - + "input_label[" + nPoints + ":] -= 1" + System.lineSeparator() - + "input_label = torch.reshape(torch.tensor(input_label), [1, 1, -1])" + System.lineSeparator() - + "predicted_logits, predicted_iou = predictor.predict_masks(predictor.encoded_images," + System.lineSeparator() - + " input_points," + System.lineSeparator() - + " input_label," + System.lineSeparator() - + " multimask_output=True," + System.lineSeparator() - + " input_h=input_h," + System.lineSeparator() - + " input_w=input_w," + System.lineSeparator() - + " output_h=input_h," + System.lineSeparator() - + " output_w=input_w,)" + System.lineSeparator() - + "sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)" + System.lineSeparator() - + "predicted_iou = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)" + System.lineSeparator() - + "predicted_logits = torch.take_along_dim(predicted_logits, sorted_ids[..., None, None], dim=2)" + System.lineSeparator() - + "mask = torch.ge(predicted_logits[0, 0, 0, :, :], 0).cpu().detach().numpy()" + System.lineSeparator() - + "task.update('end predict')" + System.lineSeparator() - + "task.update(str(mask.shape))" + System.lineSeparator() - //+ "np.save('/temp/aa.npy', mask)" + System.lineSeparator() - + "contours_x,contours_y = get_polygons_from_binary_mask(mask, only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator() - + "task.update('all contours traced')" + System.lineSeparator() - + "task.outputs['contours_x'] = contours_x" + System.lineSeparator() - + "task.outputs['contours_y'] = contours_y" + System.lineSeparator(); - this.script = code; - } + protected abstract void processPointsWithSAM(int nPoints, int nNegPoints, boolean returnAll); - private void processBoxWithSAM(boolean returnAll) { - String code = "" + System.lineSeparator() - + "task.update('start predict')" + System.lineSeparator() - + "input_box = np.array([[input_box[0], input_box[1]], [input_box[2], input_box[3]]])" + System.lineSeparator() - + "input_box = torch.reshape(torch.tensor(input_box), [1, 1, -1, 2])" + System.lineSeparator() - + "input_label = np.array([2,3])" + System.lineSeparator() - + "input_label = torch.reshape(torch.tensor(input_label), [1, 1, -1])" + System.lineSeparator() - + "predicted_logits, predicted_iou = predictor.predict_masks(predictor.encoded_images," + System.lineSeparator() - + " input_box," + System.lineSeparator() - + " input_label," + System.lineSeparator() - + " multimask_output=True," + System.lineSeparator() - + " input_h=input_h," + System.lineSeparator() - + " input_w=input_w," + System.lineSeparator() - + " output_h=input_h," + System.lineSeparator() - + " output_w=input_w,)" + System.lineSeparator() - + "sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)" + System.lineSeparator() - + "predicted_iou = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)" + System.lineSeparator() - + "predicted_logits = torch.take_along_dim(predicted_logits, sorted_ids[..., None, None], dim=2)" + System.lineSeparator() - + "mask = torch.ge(predicted_logits[0, 0, 0, :, :], 0).cpu().detach().numpy()" + System.lineSeparator() - + "task.update('end predict')" + System.lineSeparator() - + "task.update(str(mask.shape))" + System.lineSeparator() - //+ "np.save('/home/carlos/git/mask.npy', mask)" + System.lineSeparator() - + "contours_x,contours_y = get_polygons_from_binary_mask(mask, only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator() - + "task.update('all contours traced')" + System.lineSeparator() - + "task.outputs['contours_x'] = contours_x" + System.lineSeparator() - + "task.outputs['contours_y'] = contours_y" + System.lineSeparator(); - this.script = code; - } + protected abstract void processBoxWithSAM(boolean returnAll); private static & NativeType> SharedMemoryArray createEfficientSAMInputSHM(final RandomAccessibleInterval inImg) {