Skip to content

Commit

Permalink
create more abstract methods
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Apr 22, 2024
1 parent d20b542 commit af61255
Showing 1 changed file with 38 additions and 123 deletions.
161 changes: 38 additions & 123 deletions src/main/java/ai/nets/samj/AbstractSamJ2.java
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,40 @@ void updateImage(RandomAccessibleInterval<T> rai) throws IOException, RuntimeExc
* @throws RuntimeException if there is any error running the Python code
* @throws InterruptedException if the process is interrupted
*/
public abstract <T extends RealType<T> & NativeType<T>>
void addImage(RandomAccessibleInterval<T> rai);
private <T extends RealType<T> & NativeType<T>>
void addImage(RandomAccessibleInterval<T> rai)
throws IOException, RuntimeException, InterruptedException {
if (rai.dimensionsAsLongArray()[0] * rai.dimensionsAsLongArray()[1] > MAX_ENCODED_AREA_RS * MAX_ENCODED_AREA_RS
|| rai.dimensionsAsLongArray()[0] > MAX_ENCODED_SIDE || rai.dimensionsAsLongArray()[1] > MAX_ENCODED_SIDE) {
this.targetDims = new long[] {0, 0, 0};
this.img = rai;
return;
}
this.script = "";
sendImgLib2AsNp(rai);
createEncodeImageScript();
try {
printScript(script, "Creation of initial embeddings");
Task task = python.task(script);
task.waitFor();
if (task.status == TaskStatus.CANCELED)
throw new RuntimeException();
else if (task.status == TaskStatus.FAILED)
throw new RuntimeException();
else if (task.status == TaskStatus.CRASHED)
throw new RuntimeException();
this.shma.close();
} catch (IOException | InterruptedException | RuntimeException e) {
try {
this.shma.close();
} catch (IOException e1) {
throw new IOException(e.toString() + System.lineSeparator() + e1.toString());
}
throw e;
}
}

protected abstract void createEncodeImageScript();

private void reencodeCrop() throws IOException, InterruptedException, RuntimeException {
reencodeCrop(null);
Expand All @@ -166,9 +198,7 @@ private void reencodeCrop() throws IOException, InterruptedException, RuntimeExc
private void reencodeCrop(long[] cropSize) throws IOException, InterruptedException, RuntimeException {
this.script = "";
sendCropAsNp(cropSize);
this.script += ""
+ "task.update(str(im.shape))" + System.lineSeparator()
+ "aa = predictor.get_image_embeddings(im[None, ...])";
createEncodeImageScript();
try {
printScript(script, "Creation of the cropped embeddings");
Task task = python.task(script);
Expand Down Expand Up @@ -299,58 +329,7 @@ private List<Polygon> processMask(SharedMemoryArray shmArr, boolean returnAll)
return polys;
}

private void processMasksWithSam(SharedMemoryArray shmArr, boolean returnAll) {
String code = "";
code += "shm_mask = shared_memory.SharedMemory(name='" + shmArr.getNameForPython() + "')" + System.lineSeparator();
code += "mask = np.frombuffer(buffer=shm_mask.buf, dtype='" + shmArr.getOriginalDataType() + "').reshape([";
for (long l : shmArr.getOriginalShape())
code += l + ",";
code += "])" + System.lineSeparator();
code += "different_mask_vals = np.unique(mask)" + System.lineSeparator();
//code += "print(different_mask_vals)" + System.lineSeparator();
code += "cont_x = []" + System.lineSeparator();
code += "cont_y = []" + System.lineSeparator();
code += "for val in different_mask_vals:" + System.lineSeparator()
+ " if val < 1:" + System.lineSeparator()
+ " continue" + System.lineSeparator()
+ " locations = np.where(mask == val)" + System.lineSeparator()
+ " input_points_pos = np.zeros((locations[0].shape[0], 2))" + System.lineSeparator()
+ " input_labels_pos = np.ones((locations[0].shape[0]))" + System.lineSeparator()
+ " locations_neg = np.where((mask != val) & (mask != 0))" + System.lineSeparator()
+ " input_points_neg = np.zeros((locations_neg[0].shape[0], 2))" + System.lineSeparator()
+ " input_labels_neg = np.zeros((locations_neg[0].shape[0]))" + System.lineSeparator()
+ " input_points_pos[:, 0] = locations[0]" + System.lineSeparator()
+ " input_points_pos[:, 1] = locations[1]" + System.lineSeparator()
+ " input_points_neg[:, 0] = locations_neg[0]" + System.lineSeparator()
+ " input_points_neg[:, 1] = locations_neg[1]" + System.lineSeparator()
+ " input_points = np.concatenate((input_points_pos.reshape(-1, 2), input_points_neg.reshape(-1, 2)), axis=0)" + System.lineSeparator()
+ " input_label = np.concatenate((input_labels_pos, input_labels_neg * 0), axis=0)" + System.lineSeparator()
+ " input_points = torch.reshape(torch.tensor(input_points), [1, 1, -1, 2])" + System.lineSeparator()
+ " input_label = torch.reshape(torch.tensor(input_label), [1, 1, -1])" + System.lineSeparator()
+ " predicted_logits, predicted_iou = predictor.predict_masks(predictor.encoded_images," + System.lineSeparator()
+ " input_points," + System.lineSeparator()
+ " input_label," + System.lineSeparator()
+ " multimask_output=True," + System.lineSeparator()
+ " input_h=input_h," + System.lineSeparator()
+ " input_w=input_w," + System.lineSeparator()
+ " output_h=input_h," + System.lineSeparator()
+ " output_w=input_w,)" + System.lineSeparator()
//+ "np.save('/temp/aa.npy', mask)" + System.lineSeparator()
+ " sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)" + System.lineSeparator()
+ " predicted_iou = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)" + System.lineSeparator()
+ " predicted_logits = torch.take_along_dim(predicted_logits, sorted_ids[..., None, None], dim=2)" + System.lineSeparator()
+ " mask_val = torch.ge(predicted_logits[0, 0, 0, :, :], 0).cpu().detach().numpy()" + System.lineSeparator()
+ " cont_x_val,cont_y_val = get_polygons_from_binary_mask(mask_val, only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator()
+ " cont_x += cont_x_val" + System.lineSeparator()
+ " cont_y += cont_y_val" + System.lineSeparator()
+ "task.update('all contours traced')" + System.lineSeparator()
+ "task.outputs['contours_x'] = cont_x" + System.lineSeparator()
+ "task.outputs['contours_y'] = cont_y" + System.lineSeparator();
code += "mask = 0" + System.lineSeparator();
code += "shm_mask.close()" + System.lineSeparator();
code += "shm_mask.unlink()" + System.lineSeparator();
this.script = code;
}
abstract protected void processMasksWithSam(SharedMemoryArray shmArr, boolean returnAll);

/**
* Method used that runs EfficientSAM using a list of points as the prompt. This method runs
Expand Down Expand Up @@ -829,73 +808,9 @@ private <T extends RealType<T> & NativeType<T>> void sendCropAsNp(long[] cropSiz
this.script += code;
}

private void processPointsWithSAM(int nPoints, int nNegPoints, boolean returnAll) {
String code = "" + System.lineSeparator()
+ "task.update('start predict')" + System.lineSeparator()
+ "input_points_list = []" + System.lineSeparator()
+ "input_neg_points_list = []" + System.lineSeparator();
for (int n = 0; n < nPoints; n ++)
code += "input_points_list.append([input_points[" + n + "][0], input_points[" + n + "][1]])" + System.lineSeparator();
for (int n = 0; n < nNegPoints; n ++)
code += "input_neg_points_list.append([input_neg_points[" + n + "][0], input_neg_points[" + n + "][1]])" + System.lineSeparator();
code += ""
+ "input_points = np.concatenate("
+ "(np.array(input_points_list).reshape(" + nPoints + ", 2), np.array(input_neg_points_list).reshape(" + nNegPoints + ", 2))"
+ ", axis=0)" + System.lineSeparator()
+ "input_points = torch.reshape(torch.tensor(input_points), [1, 1, -1, 2])" + System.lineSeparator()
+ "input_label = np.array([1] * " + (nPoints + nNegPoints) + ")" + System.lineSeparator()
+ "input_label[" + nPoints + ":] -= 1" + System.lineSeparator()
+ "input_label = torch.reshape(torch.tensor(input_label), [1, 1, -1])" + System.lineSeparator()
+ "predicted_logits, predicted_iou = predictor.predict_masks(predictor.encoded_images," + System.lineSeparator()
+ " input_points," + System.lineSeparator()
+ " input_label," + System.lineSeparator()
+ " multimask_output=True," + System.lineSeparator()
+ " input_h=input_h," + System.lineSeparator()
+ " input_w=input_w," + System.lineSeparator()
+ " output_h=input_h," + System.lineSeparator()
+ " output_w=input_w,)" + System.lineSeparator()
+ "sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)" + System.lineSeparator()
+ "predicted_iou = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)" + System.lineSeparator()
+ "predicted_logits = torch.take_along_dim(predicted_logits, sorted_ids[..., None, None], dim=2)" + System.lineSeparator()
+ "mask = torch.ge(predicted_logits[0, 0, 0, :, :], 0).cpu().detach().numpy()" + System.lineSeparator()
+ "task.update('end predict')" + System.lineSeparator()
+ "task.update(str(mask.shape))" + System.lineSeparator()
//+ "np.save('/temp/aa.npy', mask)" + System.lineSeparator()
+ "contours_x,contours_y = get_polygons_from_binary_mask(mask, only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator()
+ "task.update('all contours traced')" + System.lineSeparator()
+ "task.outputs['contours_x'] = contours_x" + System.lineSeparator()
+ "task.outputs['contours_y'] = contours_y" + System.lineSeparator();
this.script = code;
}
protected abstract void processPointsWithSAM(int nPoints, int nNegPoints, boolean returnAll);

private void processBoxWithSAM(boolean returnAll) {
String code = "" + System.lineSeparator()
+ "task.update('start predict')" + System.lineSeparator()
+ "input_box = np.array([[input_box[0], input_box[1]], [input_box[2], input_box[3]]])" + System.lineSeparator()
+ "input_box = torch.reshape(torch.tensor(input_box), [1, 1, -1, 2])" + System.lineSeparator()
+ "input_label = np.array([2,3])" + System.lineSeparator()
+ "input_label = torch.reshape(torch.tensor(input_label), [1, 1, -1])" + System.lineSeparator()
+ "predicted_logits, predicted_iou = predictor.predict_masks(predictor.encoded_images," + System.lineSeparator()
+ " input_box," + System.lineSeparator()
+ " input_label," + System.lineSeparator()
+ " multimask_output=True," + System.lineSeparator()
+ " input_h=input_h," + System.lineSeparator()
+ " input_w=input_w," + System.lineSeparator()
+ " output_h=input_h," + System.lineSeparator()
+ " output_w=input_w,)" + System.lineSeparator()
+ "sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)" + System.lineSeparator()
+ "predicted_iou = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)" + System.lineSeparator()
+ "predicted_logits = torch.take_along_dim(predicted_logits, sorted_ids[..., None, None], dim=2)" + System.lineSeparator()
+ "mask = torch.ge(predicted_logits[0, 0, 0, :, :], 0).cpu().detach().numpy()" + System.lineSeparator()
+ "task.update('end predict')" + System.lineSeparator()
+ "task.update(str(mask.shape))" + System.lineSeparator()
//+ "np.save('/home/carlos/git/mask.npy', mask)" + System.lineSeparator()
+ "contours_x,contours_y = get_polygons_from_binary_mask(mask, only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator()
+ "task.update('all contours traced')" + System.lineSeparator()
+ "task.outputs['contours_x'] = contours_x" + System.lineSeparator()
+ "task.outputs['contours_y'] = contours_y" + System.lineSeparator();
this.script = code;
}
protected abstract void processBoxWithSAM(boolean returnAll);

private static <T extends RealType<T> & NativeType<T>>
SharedMemoryArray createEfficientSAMInputSHM(final RandomAccessibleInterval<T> inImg) {
Expand Down

0 comments on commit af61255

Please sign in to comment.