Skip to content

Commit

Permalink
convert from polygons to masks
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Oct 25, 2024
1 parent 38abb46 commit 1a300eb
Showing 1 changed file with 42 additions and 55 deletions.
97 changes: 42 additions & 55 deletions src/main/java/ai/nets/samj/models/AbstractSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
import java.util.Objects;
import java.util.UUID;
import java.util.stream.Collectors;

import ai.nets.samj.annotation.Mask;

import java.awt.Polygon;
import java.awt.Rectangle;
import java.io.IOException;
Expand Down Expand Up @@ -67,6 +70,8 @@ public abstract class AbstractSamJ implements AutoCloseable {
public static long MAX_ENCODED_SIDE = MAX_ENCODED_AREA_RS * 3;

protected static long ENCODE_MARGIN = 64;

protected static long MAX_IMG_SIZE = 2048;

/**
 * Callback that receives debug text; essentially a named shortcut for a
 * {@code Consumer<String>}, so it can be supplied as a lambda or method reference.
 */
@FunctionalInterface
public interface DebugTextPrinter {
	/**
	 * Handles one piece of debug text.
	 *
	 * @param text
	 * 	the message to print
	 */
	void printText(String text);
}
Expand Down Expand Up @@ -101,6 +106,12 @@ public interface DebugTextPrinter { void printText(String text); }
* The axes are "xyc"
*/
protected long[] targetDims;
/**
* Target dimensions of the image that is going to be encoded after downsampling to send it faster to the other process.
* If a single-channel 2D image is provided, that image is converted into a 3-channel image that EfficientSAM requires.
* The axes are "xyc"
*/
protected long[] targetReescaledDims;
/**
* Coordinates of the vertex of the crop/zoom of the image of interest that has been encoded.
* It is the closest vertex to the origin.
Expand Down Expand Up @@ -332,14 +343,17 @@ else if (cropSize.length == 3 && cropSize[2] != 3)
throw new IllegalArgumentException("The size of the area that wants to be encoded needs to be defined as [width, height].");
RandomAccessibleInterval<T> crop =
Views.offsetInterval( Cast.unchecked(img), new long[] {encodeCoords[0], encodeCoords[1], 0}, cropSize );

//RandomAccessibleInterval<T> crop = Views.offsetInterval(crop, new long[] {encodeCoords[1], encodeCoords[0], 0}, interValSize);
//RandomAccessibleInterval<T> crop = Views.offsetInterval(Cast.unchecked(img), new long[] {encodeCoords[1], encodeCoords[0], 0}, cropSize);
targetDims = crop.dimensionsAsLongArray();
targetReescaledDims = crop.dimensionsAsLongArray();
createSHMArray(crop);
/*
RandomAccessibleInterval<T> shmCrop2 = Views.subsample(crop,
new long[] {crop.dimension(0) / MAX_IMG_SIZE, crop.dimension(1) / MAX_IMG_SIZE, 1});
createSHMArray(shmCrop2);
*/
}

private List<Polygon> processAndRetrieveContours(HashMap<String, Object> inputs)
private List<Mask> processAndRetrieveContours(HashMap<String, Object> inputs)
throws IOException, RuntimeException, InterruptedException {
Map<String, Object> results = null;
try {
Expand Down Expand Up @@ -370,43 +384,16 @@ else if (task.outputs.get("contours_y") == null)
final List<List<Number>> contours_x_container = (List<List<Number>>)results.get("contours_x");
final Iterator<List<Number>> contours_x = contours_x_container.iterator();
final Iterator<List<Number>> contours_y = ((List<List<Number>>)results.get("contours_y")).iterator();
final List<Polygon> polys = new ArrayList<>(contours_x_container.size());
final Iterator<List<Number>> rles = ((List<List<Number>>)results.get("rle")).iterator();
final List<Mask> masks = new ArrayList<Mask>(contours_x_container.size());
Rectangle cropRect = new Rectangle((int) encodeCoords[0], (int) encodeCoords[1], (int) targetDims[0], (int) targetDims[1]);
while (contours_x.hasNext()) {
int[] xArr = contours_x.next().stream().mapToInt(Number::intValue).toArray();
int[] yArr = contours_y.next().stream().mapToInt(Number::intValue).toArray();
polys.add( new Polygon(xArr, yArr, xArr.length) );
long[] rle = rles.next().stream().mapToLong(Number::longValue).toArray();
masks.add(Mask.build(new Polygon(xArr, yArr, xArr.length), rle, cropRect));
}
return polys;
}

/**
* TODO explain what happens with big images
* Method that uses SAM-like models to create a pseudo-segmentation of the image of interest.
* The objectSize argument is used to create a grid of points that cover the original image
* and that will be used as prompts to the model.
* The object masks with more accuracy will be put together to generate a pseudo-segmentation mask.
*
* Imagine that you have a 512x256 image, with objects of an approximate size of 50 pixels. This method
* will create a grid of ceil(512 / 30) x ceil(256 / 30) = 18 x 9 equally spaced points, then each of
* them will be prompted to the SAM model to see if there is any image in the position of the point.
* Finally, all the masks that have been found will be put together creating a pseudo-mask that
* will probably not have detected every object in the image.
* Note that the smaller the objectSize argument, the more prompts will be evaluated and the more time
* the pseudo-segmentation will take, but the more objects will be detected.
*
* @param objectSize
* mean size of the object
* @param returnAll
* whether to return all the ROIS found with one prompt or only the largest one
* @return a pseudo-segmentation mask
* @throws IOException if any of the files needed to run the Python script is missing
* @throws RuntimeException if there is any error running the Python process
* @throws InterruptedException if the process is interrupted
*/
public List<Polygon> noPromptAnnotation(int objectSize, boolean returnAll)
throws IOException, RuntimeException, InterruptedException {
// TODO add something
return null;
return masks;
}

/**
Expand All @@ -422,7 +409,7 @@ public List<Polygon> noPromptAnnotation(int objectSize, boolean returnAll)
* @throws RuntimeException if there is any error running the Python process
* @throws InterruptedException if the process is interrupted
*/
public List<Polygon> processPoints(List<int[]> pointsList)
public List<Mask> processPoints(List<int[]> pointsList)
throws IOException, RuntimeException, InterruptedException{
// Convenience overload: forwards to processPoints(pointsList, returnAll) with returnAll = true.
return processPoints(pointsList, true);
}
Expand All @@ -442,7 +429,7 @@ public List<Polygon> processPoints(List<int[]> pointsList)
* @throws RuntimeException if there is any error running the Python process
* @throws InterruptedException if the process is interrupted
*/
public List<Polygon> processPoints(List<int[]> pointsList, boolean returnAll)
public List<Mask> processPoints(List<int[]> pointsList, boolean returnAll)
throws IOException, RuntimeException, InterruptedException{
Rectangle rect = new Rectangle();
rect.x = -1;
Expand All @@ -452,7 +439,7 @@ public List<Polygon> processPoints(List<int[]> pointsList, boolean returnAll)
return processPoints(pointsList, rect, returnAll);
}

public List<Polygon> processPoints(List<int[]> pointsList, Rectangle encodingArea, boolean returnAll)
public List<Mask> processPoints(List<int[]> pointsList, Rectangle encodingArea, boolean returnAll)
throws IOException, RuntimeException, InterruptedException {
Objects.requireNonNull(encodingArea, "Second argument cannot be null. Use the method "
+ "'processPoints(List<int[]> pointsList, Rectangle zoomedArea, boolean returnAll)'"
Expand All @@ -476,7 +463,7 @@ public List<Polygon> processPoints(List<int[]> pointsList, Rectangle encodingAre
* @throws RuntimeException if there is any error running the Python process
* @throws InterruptedException if the process is interrupted
*/
public List<Polygon> processPoints(List<int[]> pointsList, List<int[]> pointsNegList)
public List<Mask> processPoints(List<int[]> pointsList, List<int[]> pointsNegList)
throws IOException, RuntimeException, InterruptedException {
Rectangle rect = new Rectangle();
rect.x = (int) this.encodeCoords[0];
Expand All @@ -486,7 +473,7 @@ public List<Polygon> processPoints(List<int[]> pointsList, List<int[]> pointsNeg
return processPoints(pointsList, pointsNegList, rect, true);
}

public List<Polygon> processPoints(List<int[]> pointsList, List<int[]> pointsNegList,
public List<Mask> processPoints(List<int[]> pointsList, List<int[]> pointsNegList,
Rectangle zoomedArea)
throws IOException, RuntimeException, InterruptedException {
return processPoints(pointsList, pointsNegList, zoomedArea, true);
Expand All @@ -510,7 +497,7 @@ public List<Polygon> processPoints(List<int[]> pointsList, List<int[]> pointsNeg
* @throws RuntimeException if there is any error running the Python process
* @throws InterruptedException if the process is interrupted
*/
public List<Polygon> processPoints(List<int[]> pointsList, List<int[]> pointsNegList, boolean returnAll)
public List<Mask> processPoints(List<int[]> pointsList, List<int[]> pointsNegList, boolean returnAll)
throws IOException, RuntimeException, InterruptedException {
Rectangle rect = new Rectangle();
rect.x = (int) this.encodeCoords[0];
Expand Down Expand Up @@ -541,7 +528,7 @@ public List<Polygon> processPoints(List<int[]> pointsList, List<int[]> pointsNeg
* @throws RuntimeException if there is any error running the Python process
* @throws InterruptedException if the process is interrupted
*/
public List<Polygon> processPoints(List<int[]> pointsList, List<int[]> pointsNegList,
public List<Mask> processPoints(List<int[]> pointsList, List<int[]> pointsNegList,
Rectangle encodingArea, boolean returnAll)
throws IOException, RuntimeException, InterruptedException {
Objects.requireNonNull(encodingArea, "Third argument cannot be null. Use the method "
Expand All @@ -567,7 +554,7 @@ public List<Polygon> processPoints(List<int[]> pointsList, List<int[]> pointsNeg
inputs.put("input_points", pointsList);
inputs.put("input_neg_points", pointsNegList);
printScript(script, "Points and negative points inference");
List<Polygon> polys = processAndRetrieveContours(inputs);
List<Mask> polys = processAndRetrieveContours(inputs);
recalculatePolys(polys, encodeCoords);
debugPrinter.printText("processPoints() obtained " + polys.size() + " polygons");
return polys;
Expand Down Expand Up @@ -598,7 +585,7 @@ private List<int[]> adaptPointPrompts(List<int[]> pointsList) {
* @throws RuntimeException if there is any error running the Python process
* @throws InterruptedException if the process is interrupted
*/
public List<Polygon> processBox(int[] boundingBox)
public List<Mask> processBox(int[] boundingBox)
throws IOException, RuntimeException, InterruptedException {
// Convenience overload: forwards to processBox(boundingBox, returnAll) with returnAll = true.
return processBox(boundingBox, true);
}
Expand All @@ -618,7 +605,7 @@ public List<Polygon> processBox(int[] boundingBox)
* @throws RuntimeException if there is any error running the Python process
* @throws InterruptedException if the process is interrupted
*/
public List<Polygon> processBox(int[] boundingBox, boolean returnAll)
public List<Mask> processBox(int[] boundingBox, boolean returnAll)
throws IOException, RuntimeException, InterruptedException {
if (!this.imageSmall || this.encodeCoords[0] != 0 || this.encodeCoords[1] != 0
|| targetDims[0] != img.dimensionsAsLongArray()[0] || targetDims[1] != img.dimensionsAsLongArray()[1]) {
Expand All @@ -637,7 +624,7 @@ public List<Polygon> processBox(int[] boundingBox, boolean returnAll)
HashMap<String, Object> inputs = new HashMap<String, Object>();
inputs.put("input_box", adaptedBoundingBox);
printScript(script, "Rectangle inference");
List<Polygon> polys = processAndRetrieveContours(inputs);
List<Mask> polys = processAndRetrieveContours(inputs);
recalculatePolys(polys, encodeCoords);
debugPrinter.printText("processBox() obtained " + polys.size() + " polygons");
return polys;
Expand All @@ -660,7 +647,7 @@ public List<Polygon> processBox(int[] boundingBox, boolean returnAll)
* @throws InterruptedException if the process is interrupted
*/
public <T extends RealType<T> & NativeType<T>>
List<Polygon> processMask(RandomAccessibleInterval<T> img)
List<Mask> processMask(RandomAccessibleInterval<T> img)
throws IOException, RuntimeException, InterruptedException {
// Convenience overload: forwards to processMask(img, returnAll) with returnAll = true.
return processMask(img, true);
}
Expand All @@ -683,7 +670,7 @@ List<Polygon> processMask(RandomAccessibleInterval<T> img)
* @throws InterruptedException if the process is interrupted
*/
public <T extends RealType<T> & NativeType<T>>
List<Polygon> processMask(RandomAccessibleInterval<T> img, boolean returnAll)
List<Mask> processMask(RandomAccessibleInterval<T> img, boolean returnAll)
throws IOException, RuntimeException, InterruptedException {
long[] dims = img.dimensionsAsLongArray();
if (dims.length == 2 && dims[1] == this.shma.getOriginalShape()[1] && dims[0] == this.shma.getOriginalShape()[0]) {
Expand All @@ -701,12 +688,12 @@ List<Polygon> processMask(RandomAccessibleInterval<T> img, boolean returnAll)
}
}

private List<Polygon> processMask(SharedMemoryArray shmArr, boolean returnAll)
private List<Mask> processMask(SharedMemoryArray shmArr, boolean returnAll)
throws IOException, RuntimeException, InterruptedException {
// Runs inference using a pre-computed mask held in shared memory as the prompt.
// Start from an empty script; presumably processMasksWithSam() fills it in — confirm.
this.script = "";
processMasksWithSam(shmArr, returnAll);
printScript(script, "Pre-computed mask inference");
// No extra inputs map is supplied for mask prompts (null is passed).
// NOTE(review): unlike the point and box paths, the result is not passed through
// recalculatePolys(...) here — confirm the contours are already in full-image coordinates.
List<Polygon> polys = processAndRetrieveContours(null);
List<Mask> polys = processAndRetrieveContours(null);
debugPrinter.printText("processMask() obtained " + polys.size() + " polygons");
return polys;
}
Expand Down Expand Up @@ -1050,10 +1037,10 @@ protected long[] calculateEncodingNewCoords(int[] boundingBox, long[] imageSize)
* @param encodeCoords
* position of the crop in the total image
*/
protected void recalculatePolys(List<Polygon> polys, long[] encodeCoords) {
protected void recalculatePolys(List<Mask> polys, long[] encodeCoords) {
// Shift every contour by the crop origin (encodeCoords[0], encodeCoords[1]) by replacing
// the polygon's coordinate arrays with offset copies.
// NOTE(review): java.awt.Polygon caches its bounding box; after writing xpoints/ypoints
// directly, Polygon.invalidate() should be called (or Polygon.translate(dx, dy) used instead),
// otherwise getBounds()/contains() may answer from stale bounds — confirm and fix.
// NOTE(review): only the contour returned by getContour() is shifted here; nothing else
// held by the Mask is touched in this method.
polys.stream().forEach(pp -> {
pp.xpoints = Arrays.stream(pp.xpoints).map(x -> x + (int) encodeCoords[0]).toArray();
pp.ypoints = Arrays.stream(pp.ypoints).map(y -> y + (int) encodeCoords[1]).toArray();
pp.getContour().xpoints = Arrays.stream(pp.getContour().xpoints).map(x -> x + (int) encodeCoords[0]).toArray();
pp.getContour().ypoints = Arrays.stream(pp.getContour().ypoints).map(y -> y + (int) encodeCoords[1]).toArray();
});
}

Expand Down

0 comments on commit 1a300eb

Please sign in to comment.