Skip to content

Commit

Permalink
iterate to integrate batches of prompts of any combination
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Nov 30, 2024
1 parent 3000929 commit 55bc76c
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 30 deletions.
53 changes: 39 additions & 14 deletions src/main/java/ai/nets/samj/models/AbstractSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.IntegerType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.util.Cast;
import net.imglib2.view.Views;
Expand Down Expand Up @@ -166,6 +167,9 @@ public interface DebugTextPrinter { void printText(String text); }

protected abstract void cellSAM(List<int[]> grid, boolean returnAll);

protected abstract <T extends RealType<T> & NativeType<T>> void
processPromptsBatchWithSAM(List<int[]> points, List<Rectangle> rects, RandomAccessibleInterval<T> rai, boolean returnAll);

protected abstract void processPointsWithSAM(int nPoints, int nNegPoints, boolean returnAll);

protected abstract void processBoxWithSAM(boolean returnAll);
Expand Down Expand Up @@ -405,28 +409,49 @@ else if (task.outputs.get("rle") == null)
return masks;
}

public List<Mask> processBatchOfPoints(List<long[]> points) {
return processBatchOfPoints(points, true);
public <T extends RealType<T> & NativeType<T>>
List<Mask> processBatchOfPrompts(List<int[]> points, List<Rectangle> rects, RandomAccessibleInterval<T> rai)
throws IOException, RuntimeException, InterruptedException {
return processBatchOfPrompts(points, rects, rai, true);
}

public List<Mask> processBatchOfPoints(List<long[]> pointsList, boolean returnAll)
public <T extends RealType<T> & NativeType<T>>
List<Mask> processBatchOfPrompts(List<int[]> pointsList, List<Rectangle> rects, RandomAccessibleInterval<T> rai, boolean returnAll)
throws IOException, RuntimeException, InterruptedException {
if (pointsList == null || pointsList.size() == 0)
checkPrompts(pointsList, rects, rai);
if ((pointsList == null || pointsList.size() == 0) && (rects == null || rects.size() == 0) && (rai == null))
return new ArrayList<Mask>();
// TODO add logic to reeencode
// TODO the idea is that the


pointsList = adaptPointPrompts(pointsList);
pointsNegList = adaptPointPrompts(pointsNegList);
this.script = "";
processPointsWithSAM(pointsList.size(), pointsNegList.size(), returnAll);
HashMap<String, Object> inputs = new HashMap<String, Object>();
inputs.put("input_points", pointsList);
inputs.put("input_neg_points", pointsNegList);
processPromptsBatchWithSAM(pointsList, null, null, returnAll);
printScript(script, "Points and negative points inference");
List<Mask> polys = processAndRetrieveContours(inputs);
List<Mask> polys = processAndRetrieveContours(null);
recalculatePolys(polys, encodeCoords);
debugPrinter.printText("processPoints() obtained " + polys.size() + " polygons");
return polys;
}

private <T extends RealType<T> & NativeType<T>>
void checkPrompts(List<int[]> pointsList, List<Rectangle> rects, RandomAccessibleInterval<T> rai) {
if ((pointsList == null || pointsList.size() == 0)
&& (rects == null || rects.size() == 0)
&& !(rai.getType() instanceof IntegerType))
throw new IllegalArgumentException("The mask provided should be of any integer type.");
else if ((pointsList == null || pointsList.size() == 0)
&& (rects == null || rects.size() == 0)
&& !(rai.getType() instanceof IntegerType)) {
throw new IllegalArgumentException("The mask provided should be of the same size as the image of interest.");
}
}

public List<Mask> processBatchOfPoints(List<int[]> points) throws IOException, RuntimeException, InterruptedException {
return processBatchOfPoints(points, true);
}

public List<Mask> processBatchOfPoints(List<int[]> pointsList, boolean returnAll)
throws IOException, RuntimeException, InterruptedException {
List<Mask> polys = processBatchOfPrompts(pointsList, null, null, returnAll);
debugPrinter.printText("processBatchOfPoints() obtained " + polys.size() + " polygons");
return polys;
}

Expand Down
8 changes: 8 additions & 0 deletions src/main/java/ai/nets/samj/models/EfficientSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
*/
package ai.nets.samj.models;

import java.awt.Rectangle;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -442,4 +443,11 @@ public String selectEncodingScript(String encodingName) {
public String deleteEncodingScript(String encodingName) {
return "del encodings_map['" + encodingName + "']";
}

@Override
protected <T extends RealType<T> & NativeType<T>> void processPromptsBatchWithSAM(List<int[]> points,
List<Rectangle> rects, RandomAccessibleInterval<T> rai, boolean returnAll) {
// TODO Auto-generated method stub

}
}
12 changes: 8 additions & 4 deletions src/main/java/ai/nets/samj/models/EfficientViTSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import ai.nets.samj.install.EfficientViTSamEnvManager;
import ai.nets.samj.install.SamEnvManagerAbstract;

import java.awt.Rectangle;
import java.io.File;
import java.io.IOException;

Expand Down Expand Up @@ -403,11 +404,7 @@ protected void processPointsWithSAM(int nPoints, int nNegPoints, boolean returnA
+ " box=None,)" + System.lineSeparator()
+ "task.update('end predict')" + System.lineSeparator()
+ "task.update(str(mask.shape))" + System.lineSeparator()
// TODO remove + "import matplotlib.pyplot as plt" + System.lineSeparator()
// TODO remove + "plt.imsave('/tmp/aa.jpg', mask[0], cmap='gray')" + System.lineSeparator()
+ (this.isIJROIManager ? "mask[0, 1:, 1:] += mask[0, :-1, :-1]" : "") + System.lineSeparator()
//+ (this.isIJROIManager ? "mask[0, :, 1:] += mask[0, :, :-1]" : "") + System.lineSeparator()
//+ "np.save('/home/carlos/git/aa.npy', mask)" + System.lineSeparator()
+ "contours_x, contours_y, rle_masks = get_polygons_from_binary_mask(mask[0], only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator()
+ "task.update('all contours traced')" + System.lineSeparator()
+ "task.outputs['contours_x'] = contours_x" + System.lineSeparator()
Expand Down Expand Up @@ -511,4 +508,11 @@ public String selectEncodingScript(String encodingName) {
public String deleteEncodingScript(String encodingName) {
return "del encodings_map['" + encodingName + "']";
}

@Override
protected <T extends RealType<T> & NativeType<T>> void processPromptsBatchWithSAM(List<int[]> points,
List<Rectangle> rects, RandomAccessibleInterval<T> rai, boolean returnAll) {
// TODO Auto-generated method stub

}
}
31 changes: 19 additions & 12 deletions src/main/java/ai/nets/samj/models/Sam2.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import ai.nets.samj.install.Sam2EnvManager;
import ai.nets.samj.install.SamEnvManagerAbstract;

import java.awt.Rectangle;
import java.io.IOException;

import io.bioimage.modelrunner.apposed.appose.Environment;
Expand Down Expand Up @@ -327,6 +328,12 @@ protected void processMasksWithSam(SharedMemoryArray shmArr, boolean returnAll)
+ "point_prompts = []" + System.lineSeparator()
+ "point_labels = []" + System.lineSeparator()
+ "labeled_array, num_features = label(mask)" + System.lineSeparator()
+ "contours_x = []" + System.lineSeparator()
+ "contours_y = []" + System.lineSeparator()
+ "rle_masks = []" + System.lineSeparator()
// TODO right now is geetting the mask after each prompt
// TODO test processing first every prompt and then getting the masks
+ "" + System.lineSeparator()
+ "for n_feat in range(num_features):" + System.lineSeparator()
+ " inds = np.where(labeled_array == n_feat)" + System.lineSeparator()
+ " n_points = np.min([3, inds[0].shape[0]])" + System.lineSeparator()
Expand All @@ -335,21 +342,18 @@ protected void processMasksWithSam(SharedMemoryArray shmArr, boolean returnAll)
+ " point_prompts += [[inds[0][random_positions[pp]], inds[1][random_positions[pp]]]]" + System.lineSeparator()
+ " point_labels += [n_feat]" + System.lineSeparator()
+ "" + System.lineSeparator()
+ "mask, _, _ = predictor.predict(" + System.lineSeparator()
+ " mask, _, _ = predictor.predict(" + System.lineSeparator()
+ " point_coords=point_prompts," + System.lineSeparator()
+ " point_labels=point_labels," + System.lineSeparator()
+ " multimask_output=False," + System.lineSeparator()
+ " box=None,)" + System.lineSeparator()
+ "contours_x = []" + System.lineSeparator()
+ "contours_y = []" + System.lineSeparator()
+ "rle_masks = []" + System.lineSeparator()
+ "" + System.lineSeparator()
+ "" + System.lineSeparator()
+ "" + System.lineSeparator()
+ "for b in range(num_features):" + System.lineSeparator()
+ " mm = mask[b]" + System.lineSeparator()
+ (this.isIJROIManager ? " mm += mm[:-1, :-1]" : "") + System.lineSeparator()
+ " c_x, c_y, r_m = get_polygons_from_binary_mask(mm, only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator()
// TODO + "for b in range(num_features):" + System.lineSeparator()
// TODO + " mm = mask[b]" + System.lineSeparator()
+ (this.isIJROIManager ? " mask += mask[0, :-1, :-1]" : "") + System.lineSeparator()
+ " c_x, c_y, r_m = get_polygons_from_binary_mask(mask[0], only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator()
+ " contours_x += c_x" + System.lineSeparator()
+ " contours_y += c_Y" + System.lineSeparator()
+ " rle_masks += r_m" + System.lineSeparator()
Expand Down Expand Up @@ -393,11 +397,7 @@ protected void processPointsWithSAM(int nPoints, int nNegPoints, boolean returnA
+ " box=None,)" + System.lineSeparator()
+ "task.update('end predict')" + System.lineSeparator()
+ "task.update(str(mask.shape))" + System.lineSeparator()
// TODO remove + "import matplotlib.pyplot as plt" + System.lineSeparator()
// TODO remove + "plt.imsave('/tmp/aa.jpg', mask[0], cmap='gray')" + System.lineSeparator()
+ (this.isIJROIManager ? "mask[0, 1:, 1:] += mask[0, :-1, :-1]" : "") + System.lineSeparator()
//+ (this.isIJROIManager ? "mask[0, :, 1:] += mask[0, :, :-1]" : "") + System.lineSeparator()
//+ "np.save('/home/carlos/git/aa.npy', mask)" + System.lineSeparator()
+ "contours_x, contours_y, rle_masks = get_polygons_from_binary_mask(mask[0], only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator()
+ "task.update('all contours traced')" + System.lineSeparator()
+ "task.outputs['contours_x'] = contours_x" + System.lineSeparator()
Expand Down Expand Up @@ -502,4 +502,11 @@ public String selectEncodingScript(String encodingName) {
public String deleteEncodingScript(String encodingName) {
return "del encodings_map['" + encodingName + "']";
}

@Override
protected <T extends RealType<T> & NativeType<T>> void processPromptsBatchWithSAM(List<int[]> points,
List<Rectangle> rects, RandomAccessibleInterval<T> rai, boolean returnAll) {
// TODO Auto-generated method stub

}
}

0 comments on commit 55bc76c

Please sign in to comment.