Skip to content

Commit

Permalink
keep iterating to allow batches of prompts
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Nov 30, 2024
1 parent 55bc76c commit 027ad06
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 218 deletions.
11 changes: 8 additions & 3 deletions src/main/java/ai/nets/samj/communication/model/SAMModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,14 @@ public void setReturnOnlyBiggest(boolean onlyBiggest) {
this.onlyBiggest = onlyBiggest;
}

public List<Mask> processBatchOfPoints(List<long[]> points) throws IOException, RuntimeException, InterruptedException {
//return samj.processBathcOfPoints(points, !onlyBiggest);
return null;
/**
 * Processes a batch of point prompts by delegating to {@code samj}.
 *
 * @param points
 * 	the point prompts, forwarded unchanged to {@code samj.processBatchOfPoints}
 * @return the masks returned by {@code samj}
 * @throws IOException propagated from the delegate
 * @throws RuntimeException propagated from the delegate
 * @throws InterruptedException propagated from the delegate
 */
public List<Mask> processBatchOfPoints(List<int[]> points) throws IOException, RuntimeException, InterruptedException {
    // The delegate receives the negation of onlyBiggest
    // (presumably a "return all masks" flag; confirm against the delegate's signature).
    final boolean returnAll = !onlyBiggest;
    return samj.processBatchOfPoints(points, returnAll);
}

/**
 * Processes a batch of prompts (points, rectangles and/or a mask image) by delegating to {@code samj}.
 *
 * @param <T>
 * 	the pixel type of the mask image
 * @param points
 * 	point prompts, forwarded unchanged
 * @param rects
 * 	rectangle prompts, forwarded unchanged
 * @param rai
 * 	mask image, forwarded unchanged
 * @return the masks returned by {@code samj.processBatchOfPrompts}
 * @throws IOException propagated from the delegate
 * @throws RuntimeException propagated from the delegate
 * @throws InterruptedException propagated from the delegate
 */
public <T extends RealType<T> & NativeType<T>>
List<Mask> processBatchOfPrompts(List<int[]> points, List<Rectangle> rects, RandomAccessibleInterval<T> rai)
        throws IOException, RuntimeException, InterruptedException {
    // The delegate's last argument is the negation of onlyBiggest.
    final boolean returnAll = !onlyBiggest;
    return samj.processBatchOfPrompts(points, rects, rai, returnAll);
}

/**
Expand Down
73 changes: 34 additions & 39 deletions src/main/java/ai/nets/samj/models/AbstractSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ public interface DebugTextPrinter { void printText(String text); }

protected abstract void cellSAM(List<int[]> grid, boolean returnAll);

protected abstract <T extends RealType<T> & NativeType<T>> void
processPromptsBatchWithSAM(List<int[]> points, List<Rectangle> rects, RandomAccessibleInterval<T> rai, boolean returnAll);
/**
 * Prepares the inference for a batch of prompts.
 * The caller, {@code processBatchOfPrompts}, resets {@code script} to an empty string before this call and
 * prints {@code script} right after it, so implementations are expected to leave the code to run in
 * {@code script}.
 *
 * @param points
 * 	point prompts as received from the caller (already passed through {@code adaptPointPrompts})
 * @param rects
 * 	rectangle prompts as received from the caller
 * @param shmArr
 * 	shared memory array created from the mask image; the caller passes {@code null} when no mask image
 * 	was provided
 * @param returnAll
 * 	flag forwarded from the caller (presumably whether every contour or only the biggest one is wanted; confirm)
 */
protected abstract void
processPromptsBatchWithSAM(List<int[]> points, List<Rectangle> rects, SharedMemoryArray shmArr, boolean returnAll);

protected abstract void processPointsWithSAM(int nPoints, int nNegPoints, boolean returnAll);

Expand All @@ -178,8 +178,6 @@ public interface DebugTextPrinter { void printText(String text); }

protected abstract <T extends RealType<T> & NativeType<T>> void createEncodeImageScript();

abstract protected void processMasksWithSam(SharedMemoryArray shmArr, boolean returnAll);

protected abstract <T extends RealType<T> & NativeType<T>> void createSHMArray(RandomAccessibleInterval<T> imShared);

@Override
Expand Down Expand Up @@ -418,29 +416,49 @@ List<Mask> processBatchOfPrompts(List<int[]> points, List<Rectangle> rects, Rand
/**
 * Runs the inference for a batch of prompts (points, rectangles and/or a mask image) and returns
 * the contours obtained.
 *
 * @param <T>
 * 	the pixel type of the mask image
 * @param pointsList
 * 	point prompts, can be null or empty
 * @param rects
 * 	rectangle prompts, can be null or empty
 * @param rai
 * 	mask image, can be null. When provided it is copied into a shared memory array
 * @param returnAll
 * 	flag forwarded to {@code processPromptsBatchWithSAM}
 * @return the masks retrieved after the inference, or an empty list when no prompt of any kind is provided
 * @throws IOException if the inference fails with an I/O error
 * @throws RuntimeException if the prompts are not valid or the inference fails
 * @throws InterruptedException if the process is interrupted
 */
public <T extends RealType<T> & NativeType<T>>
List<Mask> processBatchOfPrompts(List<int[]> pointsList, List<Rectangle> rects, RandomAccessibleInterval<T> rai, boolean returnAll)
        throws IOException, RuntimeException, InterruptedException {
    boolean noPoints = pointsList == null || pointsList.isEmpty();
    boolean noRects = rects == null || rects.isEmpty();
    // Nothing to do without prompts. This must come before checkPrompts, which inspects the mask.
    if (noPoints && noRects && rai == null)
        return new ArrayList<Mask>();
    checkPrompts(pointsList, rects, rai);

    // TODO adapt to reencoding for big images, ideally points that are close together should be processed together
    // NOTE(review): pointsList can still be null here (e.g. mask-only calls); confirm adaptPointPrompts accepts null.
    pointsList = adaptPointPrompts(pointsList);
    // TODO adapt rect prompts
    this.script = "";
    SharedMemoryArray maskShma = null;
    if (rai != null)
        maskShma = SharedMemoryArray.createSHMAFromRAI(rai, false, false);

    try {
        processPromptsBatchWithSAM(pointsList, rects, maskShma, returnAll);
        printScript(script, "Batch of prompts inference");
        List<Mask> polys = processAndRetrieveContours(null);
        recalculatePolys(polys, encodeCoords);
        return polys;
    } catch (IOException | RuntimeException | InterruptedException ex) {
        // maskShma only exists when a mask image was provided. Closing it unconditionally
        // would throw a NullPointerException that hides the original failure.
        if (maskShma != null)
            maskShma.close();
        throw ex;
    }
}

private <T extends RealType<T> & NativeType<T>>
void checkPrompts(List<int[]> pointsList, List<Rectangle> rects, RandomAccessibleInterval<T> rai) {
long[] dims;
if ((pointsList == null || pointsList.size() == 0)
&& (rects == null || rects.size() == 0)
&& !(rai.getType() instanceof IntegerType))
&& rai != null && !(rai.getType() instanceof IntegerType)) {
throw new IllegalArgumentException("The mask provided should be of any integer type.");
else if ((pointsList == null || pointsList.size() == 0)
} else if ((pointsList == null || pointsList.size() == 0)
&& (rects == null || rects.size() == 0)
&& !(rai.getType() instanceof IntegerType)) {
throw new IllegalArgumentException("The mask provided should be of the same size as the image of interest.");
&& rai != null) {
dims = rai.dimensionsAsLongArray();
if ((dims.length == 2 || (dims.length == 3 && dims[2] == 1))
&& dims[1] == this.shma.getOriginalShape()[0] && dims[0] == this.shma.getOriginalShape()[1]) {
rai = Views.permute(rai, 0, 1);
} else if (dims[0] != this.shma.getOriginalShape()[0] && dims[1] != this.shma.getOriginalShape()[1]
|| (dims.length == 3 && dims[2] != 1) || dims.length > 3) {
throw new IllegalArgumentException("The provided mask should be a 2d image with just one channel of width "
+ this.shma.getOriginalShape()[1] + " and height " + this.shma.getOriginalShape()[0]);
}
}
}

Expand Down Expand Up @@ -730,32 +748,9 @@ List<Mask> processMask(RandomAccessibleInterval<T> img)
* @throws InterruptedException if the process in interrupted
*/
public <T extends RealType<T> & NativeType<T>>
List<Mask> processMask(RandomAccessibleInterval<T> img, boolean returnAll)
List<Mask> processMask(RandomAccessibleInterval<T> rai, boolean returnAll)
throws IOException, RuntimeException, InterruptedException {
// Checks that the mask is a 2d image (optionally with a third dimension of size 1) whose size matches
// this.shma.getOriginalShape(); when the first two axes arrive swapped they are permuted.
long[] dims = img.dimensionsAsLongArray();
if ((dims.length == 2 || (dims.length == 3 && dims[2] == 1))
&& dims[1] == this.shma.getOriginalShape()[0] && dims[0] == this.shma.getOriginalShape()[1]) {
img = Views.permute(img, 0, 1);
// NOTE(review): the "&&" below rejects the mask only when both dimensions differ; a mask with a single
// wrong dimension passes this check. Confirm whether "||" was intended.
} else if (dims[0] != this.shma.getOriginalShape()[0] && dims[1] != this.shma.getOriginalShape()[1]
|| (dims.length == 3 && dims[2] != 1) || dims.length > 3) {
throw new IllegalArgumentException("The provided mask should be a 2d image with just one channel of width "
+ this.shma.getOriginalShape()[1] + " and height " + this.shma.getOriginalShape()[0]);
}
// The mask is copied into a shared memory array, which is closed here only if the processing throws.
SharedMemoryArray maskShma = SharedMemoryArray.createSHMAFromRAI(img, false, false);
try {
return processMask(maskShma, returnAll);
} catch (IOException | RuntimeException | InterruptedException ex) {
maskShma.close();
throw ex;
}
}

// Resets the script, lets processMasksWithSam prepare the pre-computed mask inference, prints the script
// and retrieves the contours.
private List<Mask> processMask(SharedMemoryArray shmArr, boolean returnAll)
throws IOException, RuntimeException, InterruptedException {
this.script = "";
processMasksWithSam(shmArr, returnAll);
printScript(script, "Pre-computed mask inference");
List<Mask> polys = processAndRetrieveContours(null);
// Delegates to processBatchOfPrompts with no point and no rectangle prompts, only the mask image.
List<Mask> polys = processBatchOfPrompts(null, null, rai, returnAll);
debugPrinter.printText("processMask() obtained " + polys.size() + " polygons");
return polys;
}
Expand Down
118 changes: 56 additions & 62 deletions src/main/java/ai/nets/samj/models/EfficientSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -226,64 +226,6 @@ protected void createEncodeImageScript() {
+ "_ = predictor.get_image_embeddings(im[None, ...])" + System.lineSeparator();
}

/**
 * Builds the Python script that segments every labelled object of a pre-computed mask and stores it
 * in {@code this.script}.
 * <p>
 * The script attaches to the shared memory block of {@code shmArr}, views it as a NumPy array with the
 * data type and shape of {@code shmArr} and loops over the distinct values of the mask. For every value
 * greater than or equal to 1, the pixels with that value are used as positive point prompts and the pixels
 * with any other non-zero value as negative point prompts for {@code predictor.predict_masks}. The candidate with the
 * highest predicted IoU is thresholded at 0 and traced with {@code get_polygons_from_binary_mask}.
 * The contours and RLE masks of all the values are accumulated and published in {@code task.outputs};
 * finally the script closes and unlinks the shared memory block.
 *
 * @param shmArr
 * 	shared memory array that contains the mask
 * @param returnAll
 * 	when false the script requests {@code only_biggest=True} from {@code get_polygons_from_binary_mask}
 */
@Override
protected void processMasksWithSam(SharedMemoryArray shmArr, boolean returnAll) {
    final String nl = System.lineSeparator();
    String code = "";
    code += "shm_mask = shared_memory.SharedMemory(name='" + shmArr.getNameForPython() + "')" + nl;
    code += "mask = np.frombuffer(buffer=shm_mask.buf, dtype='" + shmArr.getOriginalDataType() + "').reshape([";
    for (long l : shmArr.getOriginalShape())
        code += l + ",";
    code += "])" + nl;
    code += "different_mask_vals = np.unique(mask)" + nl;
    code += "cont_x = []" + nl;
    code += "cont_y = []" + nl;
    code += "rle_masks = []" + nl;
    // Every statement of the Python loop body carries the same two-space indentation and nested
    // statements four spaces, otherwise the generated script is not valid Python.
    code += "for val in different_mask_vals:" + nl
        + "  if val < 1:" + nl
        + "    continue" + nl
        + "  locations = np.where(mask == val)" + nl
        + "  input_points_pos = np.zeros((locations[0].shape[0], 2))" + nl
        + "  input_labels_pos = np.ones((locations[0].shape[0]))" + nl
        + "  locations_neg = np.where((mask != val) & (mask != 0))" + nl
        + "  input_points_neg = np.zeros((locations_neg[0].shape[0], 2))" + nl
        + "  input_labels_neg = np.zeros((locations_neg[0].shape[0]))" + nl
        + "  input_points_pos[:, 0] = locations[0]" + nl
        + "  input_points_pos[:, 1] = locations[1]" + nl
        + "  input_points_neg[:, 0] = locations_neg[0]" + nl
        + "  input_points_neg[:, 1] = locations_neg[1]" + nl
        + "  input_points = np.concatenate((input_points_pos.reshape(-1, 2), input_points_neg.reshape(-1, 2)), axis=0)" + nl
        + "  input_label = np.concatenate((input_labels_pos, input_labels_neg * 0), axis=0)" + nl
        + "  input_points = torch.reshape(torch.tensor(input_points), [1, 1, -1, 2])" + nl
        + "  input_label = torch.reshape(torch.tensor(input_label), [1, 1, -1])" + nl
        + "  predicted_logits, predicted_iou = predictor.predict_masks(predictor.encoded_images," + nl
        + "    input_points," + nl
        + "    input_label," + nl
        + "    multimask_output=True," + nl
        + "    input_h=input_h," + nl
        + "    input_w=input_w," + nl
        + "    output_h=input_h," + nl
        + "    output_w=input_w,)" + nl
        + "  sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)" + nl
        + "  predicted_iou = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)" + nl
        + "  predicted_logits = torch.take_along_dim(predicted_logits, sorted_ids[..., None, None], dim=2)" + nl
        + "  mask_val = torch.ge(predicted_logits[0, 0, 0, :, :], 0).cpu().detach().numpy()" + nl
        // This optional line has to stay inside the loop; without indentation it ends the loop body
        // and the indented lines that follow become a syntax error.
        + (this.isIJROIManager ? "  mask_val[1:, 1:] += mask_val[:-1, :-1]" : "") + nl
        + "  cont_x_val,cont_y_val,rle_val = get_polygons_from_binary_mask(mask_val, only_biggest=" + (!returnAll ? "True" : "False") + ")" + nl
        + "  cont_x += cont_x_val" + nl
        + "  cont_y += cont_y_val" + nl
        + "  rle_masks += rle_val" + nl
        + "task.update('all contours traced')" + nl
        + "task.outputs['contours_x'] = cont_x" + nl
        + "task.outputs['contours_y'] = cont_y" + nl
        + "task.outputs['rle'] = rle_masks" + nl;
    // Drop the reference to the NumPy view before closing the shared memory block.
    code += "mask = 0" + nl;
    code += "shm_mask.close()" + nl;
    code += "shm_mask.unlink()" + nl;
    this.script = code;
}

@Override
protected void processPointsWithSAM(int nPoints, int nNegPoints, boolean returnAll) {
String code = "" + System.lineSeparator()
Expand Down Expand Up @@ -445,9 +387,61 @@ public String deleteEncodingScript(String encodingName) {
}

@Override
protected <T extends RealType<T> & NativeType<T>> void processPromptsBatchWithSAM(List<int[]> points,
List<Rectangle> rects, RandomAccessibleInterval<T> rai, boolean returnAll) {
// TODO Auto-generated method stub

/**
 * Builds the Python script that segments every labelled object of the mask prompt and stores it
 * in {@code this.script}.
 * <p>
 * The script attaches to the shared memory block of {@code shmArr}, views it as a NumPy array with the
 * data type and shape of {@code shmArr} and loops over the distinct values of the mask. For every value
 * greater than or equal to 1, the pixels with that value are used as positive point prompts and the pixels
 * with any other non-zero value as negative point prompts for {@code predictor.predict_masks}. The candidate with the
 * highest predicted IoU is thresholded at 0 and traced with {@code get_polygons_from_binary_mask}.
 * The contours and RLE masks of all the values are accumulated and published in {@code task.outputs};
 * finally the script closes and unlinks the shared memory block.
 *
 * @param points
 * 	point prompts; not used by this implementation yet
 * @param rects
 * 	rectangle prompts; not used by this implementation yet
 * @param shmArr
 * 	shared memory array that contains the mask, cannot be null
 * @param returnAll
 * 	when false the script requests {@code only_biggest=True} from {@code get_polygons_from_binary_mask}
 * @throws IllegalArgumentException if {@code shmArr} is null
 */
protected void processPromptsBatchWithSAM(List<int[]> points, List<Rectangle> rects, SharedMemoryArray shmArr,
        boolean returnAll) {
    // The caller passes null when no mask image was provided. Only the mask is used to build the
    // script, so fail with a clear message instead of a bare NullPointerException.
    if (shmArr == null) {
        throw new IllegalArgumentException("A mask is required, point and rectangle prompts "
                + "are not supported by this batch implementation yet.");
    }
    final String nl = System.lineSeparator();
    String code = "";
    code += "shm_mask = shared_memory.SharedMemory(name='" + shmArr.getNameForPython() + "')" + nl;
    code += "mask = np.frombuffer(buffer=shm_mask.buf, dtype='" + shmArr.getOriginalDataType() + "').reshape([";
    for (long l : shmArr.getOriginalShape())
        code += l + ",";
    code += "])" + nl;
    code += "different_mask_vals = np.unique(mask)" + nl;
    code += "cont_x = []" + nl;
    code += "cont_y = []" + nl;
    code += "rle_masks = []" + nl;
    // Every statement of the Python loop body carries the same two-space indentation and nested
    // statements four spaces, otherwise the generated script is not valid Python.
    code += "for val in different_mask_vals:" + nl
        + "  if val < 1:" + nl
        + "    continue" + nl
        + "  locations = np.where(mask == val)" + nl
        + "  input_points_pos = np.zeros((locations[0].shape[0], 2))" + nl
        + "  input_labels_pos = np.ones((locations[0].shape[0]))" + nl
        + "  locations_neg = np.where((mask != val) & (mask != 0))" + nl
        + "  input_points_neg = np.zeros((locations_neg[0].shape[0], 2))" + nl
        + "  input_labels_neg = np.zeros((locations_neg[0].shape[0]))" + nl
        + "  input_points_pos[:, 0] = locations[0]" + nl
        + "  input_points_pos[:, 1] = locations[1]" + nl
        + "  input_points_neg[:, 0] = locations_neg[0]" + nl
        + "  input_points_neg[:, 1] = locations_neg[1]" + nl
        + "  input_points = np.concatenate((input_points_pos.reshape(-1, 2), input_points_neg.reshape(-1, 2)), axis=0)" + nl
        + "  input_label = np.concatenate((input_labels_pos, input_labels_neg * 0), axis=0)" + nl
        + "  input_points = torch.reshape(torch.tensor(input_points), [1, 1, -1, 2])" + nl
        + "  input_label = torch.reshape(torch.tensor(input_label), [1, 1, -1])" + nl
        + "  predicted_logits, predicted_iou = predictor.predict_masks(predictor.encoded_images," + nl
        + "    input_points," + nl
        + "    input_label," + nl
        + "    multimask_output=True," + nl
        + "    input_h=input_h," + nl
        + "    input_w=input_w," + nl
        + "    output_h=input_h," + nl
        + "    output_w=input_w,)" + nl
        + "  sorted_ids = torch.argsort(predicted_iou, dim=-1, descending=True)" + nl
        + "  predicted_iou = torch.take_along_dim(predicted_iou, sorted_ids, dim=2)" + nl
        + "  predicted_logits = torch.take_along_dim(predicted_logits, sorted_ids[..., None, None], dim=2)" + nl
        + "  mask_val = torch.ge(predicted_logits[0, 0, 0, :, :], 0).cpu().detach().numpy()" + nl
        // This optional line has to stay inside the loop; without indentation it ends the loop body
        // and the indented lines that follow become a syntax error.
        + (this.isIJROIManager ? "  mask_val[1:, 1:] += mask_val[:-1, :-1]" : "") + nl
        + "  cont_x_val,cont_y_val,rle_val = get_polygons_from_binary_mask(mask_val, only_biggest=" + (!returnAll ? "True" : "False") + ")" + nl
        + "  cont_x += cont_x_val" + nl
        + "  cont_y += cont_y_val" + nl
        + "  rle_masks += rle_val" + nl
        + "task.update('all contours traced')" + nl
        + "task.outputs['contours_x'] = cont_x" + nl
        + "task.outputs['contours_y'] = cont_y" + nl
        + "task.outputs['rle'] = rle_masks" + nl;
    // Drop the reference to the NumPy view before closing the shared memory block.
    code += "mask = 0" + nl;
    code += "shm_mask.close()" + nl;
    code += "shm_mask.unlink()" + nl;
    this.script = code;
}
}
Loading

0 comments on commit 027ad06

Please sign in to comment.