Skip to content

Commit

Permalink
downsample big images to send them more efficiently
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Oct 30, 2024
1 parent 6d49168 commit dcc7219
Showing 1 changed file with 18 additions and 79 deletions.
97 changes: 18 additions & 79 deletions src/main/java/ai/nets/samj/models/AbstractSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ public abstract class AbstractSamJ implements AutoCloseable {

protected static long ENCODE_MARGIN = 64;

protected static long MAX_IMG_SIZE = 2048;
protected static int MAX_IMG_SIZE = 2048;

/**
 * Callback that receives debug messages one at a time; essentially a named
 * shortcut for a {@code String} consumer.
 * <p>
 * Declared as a functional interface so that it can be supplied as a lambda or a
 * method reference, and so that the compiler rejects a second abstract method.
 */
@FunctionalInterface
public interface DebugTextPrinter {
    /**
     * Handles one debug message.
     *
     * @param text the message to print
     */
    void printText(String text);
}
Expand Down Expand Up @@ -127,7 +127,7 @@ public interface DebugTextPrinter { void printText(String text); }
* the object, thus
* The axes are "xyc"
*/
protected double[] scale;
protected int scale;
/**
* TODO this should be false by default, but at the moment IJ is the only consumer
* IMPORTANT (ONLY FOR IMAGEJ ROI MANAGER)
Expand Down Expand Up @@ -344,13 +344,13 @@ else if (cropSize.length == 3 && cropSize[2] != 3)
RandomAccessibleInterval<T> crop =
Views.offsetInterval( Cast.unchecked(img), new long[] {encodeCoords[0], encodeCoords[1], 0}, cropSize );
targetDims = crop.dimensionsAsLongArray();

scale = (int) (Math.min(targetDims[0], targetDims[1]) / MAX_IMG_SIZE);
RandomAccessibleInterval<T> subsampledCrop = Views.subsample(crop,
new long[] {scale, scale, 1});
targetReescaledDims = crop.dimensionsAsLongArray();
createSHMArray(crop);
/*
RandomAccessibleInterval<T> shmCrop2 = Views.subsample(crop,
new long[] {crop.dimension(0) / MAX_IMG_SIZE, crop.dimension(1) / MAX_IMG_SIZE, 1});
createSHMArray(shmCrop2);
*/
createSHMArray(subsampledCrop);

}

private List<Mask> processAndRetrieveContours(HashMap<String, Object> inputs)
Expand Down Expand Up @@ -565,8 +565,8 @@ public List<Mask> processPoints(List<int[]> pointsList, List<int[]> pointsNegLis
private List<int[]> adaptPointPrompts(List<int[]> pointsList) {
pointsList = pointsList.stream().map(pp -> {
int[] newPoint = new int[2];
newPoint[0] = (int) (pp[0] - this.encodeCoords[0]);
newPoint[1] = (int) (pp[1] - this.encodeCoords[1]);
newPoint[0] = (int) ((pp[0] - this.encodeCoords[0]) / scale);
newPoint[1] = (int) ((pp[1] - this.encodeCoords[1]) / scale);
return newPoint;
}).collect(Collectors.toList());
return pointsList;
Expand Down Expand Up @@ -619,8 +619,9 @@ public List<Mask> processBox(int[] boundingBox, boolean returnAll)
reencodeCrop();
}
}
int[] adaptedBoundingBox = new int[] {(int) (boundingBox[0] - encodeCoords[0]), (int) (boundingBox[1] - encodeCoords[1]),
(int) (boundingBox[2] - encodeCoords[0]), (int) (boundingBox[3] - encodeCoords[1])};;
int[] adaptedBoundingBox = new int[] {(int) ((boundingBox[0] - encodeCoords[0]) / scale),
(int) ((boundingBox[1] - encodeCoords[1]) / scale),
(int) ((boundingBox[2] - encodeCoords[0]) / scale), (int) ((boundingBox[3] - encodeCoords[1]) / scale)};;
this.script = "";
processBoxWithSAM(returnAll);
HashMap<String, Object> inputs = new HashMap<String, Object>();
Expand Down Expand Up @@ -806,69 +807,6 @@ private static boolean rectContainsRect(Rectangle outer, Rectangle inner) {
return true;
return false;
}

/**
 * TODO remove
 * <p>
 * Decides whether the area that is currently encoded can be reused for the given prompts
 * or whether a different area has to be encoded, in which case {@code encodeCoords} is
 * updated and {@code reencodeCrop} is called with the new width and height.
 * <p>
 * Every re-encoding branch first compares the candidate area with the one already encoded
 * and returns without doing anything when both are identical.
 *
 * @param pointsList
 * 	first list of points, forwarded to the area and containment helpers
 * 	(presumably the positive point prompts; confirm against the callers)
 * @param pointsNegList
 * 	second list of points, forwarded to the same helpers
 * 	(presumably the negative point prompts; confirm against the callers)
 * @param rect
 * 	rectangle that the encoded area is compared against
 * @throws IOException
 * 	declared by this method; presumably propagated from {@code reencodeCrop}, TODO confirm
 * @throws InterruptedException
 * 	declared by this method; presumably propagated from {@code reencodeCrop}, TODO confirm
 * @throws RuntimeException
 * 	declared by this method; presumably propagated from {@code reencodeCrop}, TODO confirm
 */
private void evaluateReencodingNeeded2(List<int[]> pointsList, List<int[]> pointsNegList, Rectangle rect)
throws IOException, InterruptedException, RuntimeException {
Rectangle alreadyEncoded = getCurrentlyEncodedArea();
Rectangle neededArea = getApproximateAreaNeeded(pointsList, pointsNegList, rect);
// When the rectangle is exactly the encoded area, the needed area is estimated from the points alone
if (rect.equals(alreadyEncoded)) neededArea = getApproximateAreaNeeded(pointsList, pointsNegList);
ArrayList<int[]> notInRect = getPointsNotInRect(pointsList, pointsNegList, rect);
// Case 1: the encoded area encloses rect, rect spans more than 70% of the encoded area in both
// dimensions, no point was reported outside rect and the needed area is already encoded -> reuse it
if (alreadyEncoded.x <= rect.x && alreadyEncoded.y <= rect.y
&& alreadyEncoded.width + alreadyEncoded.x >= rect.width + rect.x
&& alreadyEncoded.height + alreadyEncoded.y >= rect.height + rect.y
&& alreadyEncoded.width * 0.7 < rect.width && alreadyEncoded.height * 0.7 < rect.height
&& notInRect.size() == 0 && alreadyEncoded.contains(neededArea)) {
return;
} else if (alreadyEncoded.x <= rect.x && alreadyEncoded.y <= rect.y
&& alreadyEncoded.width + alreadyEncoded.x >= rect.width + rect.x
&& alreadyEncoded.height + alreadyEncoded.y >= rect.height + rect.y
&& (alreadyEncoded.width * 0.7 > rect.width || alreadyEncoded.height * 0.7 > rect.height)
&& rect.contains(neededArea)) {
// Case 2: the encoded area encloses rect but rect is smaller than 70% of it in width or height,
// and rect itself contains the needed area -> encode rect plus a 10% margin on each side.
// The origin is clamped at 0 and the size is limited to what remains of the image.
this.encodeCoords = new long[] {(long) Math.max(0, rect.x - rect.width * 0.1),
(long) Math.max(0, rect.y - rect.height * 0.1)};
long[] imgDims = this.img.dimensionsAsLongArray();
long width = (long) Math.min(rect.width * 1.2, imgDims[0] - encodeCoords[0]);
long height = (long) Math.min(rect.height * 1.2, imgDims[1] - encodeCoords[1]);
// Nothing to do if the candidate is exactly what is already encoded
if (alreadyEncoded.x == encodeCoords[0] && alreadyEncoded.y == encodeCoords[1]
&& alreadyEncoded.width == width && alreadyEncoded.height == height)
return;
this.reencodeCrop(new long[] {width, height});
} else if (alreadyEncoded.x <= neededArea.x && alreadyEncoded.y <= neededArea.y
&& alreadyEncoded.width + alreadyEncoded.x >= neededArea.width + neededArea.x
&& alreadyEncoded.height + alreadyEncoded.y >= neededArea.height + neededArea.y
&& alreadyEncoded.width * 0.7 < neededArea.width && alreadyEncoded.height * 0.7 < neededArea.height
&& notInRect.size() == 0) {
// Case 3: same test as case 1 but against the needed area instead of rect -> reuse the encoding
return;
} else if (!alreadyEncoded.equals(rect)) {
// Case 4: encode the bounding box that covers both rect and the needed area,
// with each side limited to the image size
this.encodeCoords = new long[] {Math.min(rect.x, neededArea.x), Math.min(rect.y, neededArea.y)};
long[] imgDims = this.img.dimensionsAsLongArray();
long width = Math.min(imgDims[0], Math.max(rect.x + rect.width, neededArea.x + neededArea.width) - encodeCoords[0]);
long height = Math.min(imgDims[1], Math.max(rect.y + rect.height, neededArea.y + neededArea.height) - encodeCoords[1]);
// Nothing to do if the candidate is exactly what is already encoded
if (alreadyEncoded.x == encodeCoords[0] && alreadyEncoded.y == encodeCoords[1]
&& alreadyEncoded.width == width && alreadyEncoded.height == height)
return;
this.reencodeCrop(new long[] {width, height});
} else {
// Case 5: rect equals the encoded area -> encode an area of the needed size. The origin is
// capped at (image size - area size); encodeCoords is modified in place before the comparison below.
long[] imgDims = this.img.dimensionsAsLongArray();
long width = neededArea.width;
long height = neededArea.height;
this.encodeCoords[0] = Math.min(neededArea.x, imgDims[0] - width);
this.encodeCoords[1] = Math.min(neededArea.y, imgDims[1] - height);
// Nothing to do if the candidate is exactly what is already encoded
if (alreadyEncoded.x == encodeCoords[0] && alreadyEncoded.y == encodeCoords[1]
&& alreadyEncoded.width == width && alreadyEncoded.height == height)
return;
this.reencodeCrop(new long[] {width, height});
}
}

private Rectangle getApproximateAreaNeeded(List<int[]> pointsList, List<int[]> pointsNegList) {
ArrayList<int[]> points = new ArrayList<int[]>();
Expand Down Expand Up @@ -1041,11 +979,12 @@ protected long[] calculateEncodingNewCoords(int[] boundingBox, long[] imageSize)
*/
protected void recalculatePolys(List<Mask> masks, long[] encodeCoords) {
// Rewrites, in place, the contour points and the run-length entries of every mask using the
// encodeCoords offset and the scale field (presumably mapping results from the encoded area
// back to whole-image coordinates; confirm against the callers).
// NOTE(review): this listing comes from a commit view and appears to show the pre-change and the
// post-change version of each statement one after the other; as written, every transformation
// below would be applied twice. Confirm which statement of each pair the real file keeps.
masks.stream().forEach(pp -> {
// Contour, variant without scaling: shift every vertex by the encoding origin
pp.getContour().xpoints = Arrays.stream(pp.getContour().xpoints).map(x -> x + (int) encodeCoords[0]).toArray();
pp.getContour().ypoints = Arrays.stream(pp.getContour().ypoints).map(y -> y + (int) encodeCoords[1]).toArray();
// Contour, variant with scaling: multiply every vertex by scale, then shift by the encoding origin
pp.getContour().xpoints = Arrays.stream(pp.getContour().xpoints).map(x -> x * scale + (int) encodeCoords[0]).toArray();
pp.getContour().ypoints = Arrays.stream(pp.getContour().ypoints).map(y -> y * scale + (int) encodeCoords[1]).toArray();
// Only every second entry (even positions) of the run-length array is rewritten
// (presumably the run start indices, with the odd positions holding run lengths; TODO confirm)
for (int i = 0; i < pp.getRLEMask().length; i += 2) {
// Variant without scaling: split the flat index into column (% width) and row (/ width)
// using targetDims[0] as the width, then add the encoding origin to each
pp.getRLEMask()[i] = encodeCoords[0] + pp.getRLEMask()[i] % this.targetDims[0]
+ (((int) (pp.getRLEMask()[i] / this.targetDims[0])) + encodeCoords[1]) * this.targetDims[0];
// Variant with scaling: same decomposition applied to the flat index multiplied by scale.
// NOTE(review): scaling a flat index is not obviously the same as scaling row and column
// separately; confirm that this is intended
pp.getRLEMask()[i] = encodeCoords[0] + (pp.getRLEMask()[i] * scale) % this.targetDims[0]
+ (((int) ((pp.getRLEMask()[i] * scale) / this.targetDims[0])) + encodeCoords[1]) * this.targetDims[0];
// NOTE(review): the entry has already been remapped above and is multiplied by scale once more
// here, which also scales the added offsets; confirm that this is intended
pp.getRLEMask()[i] *= scale;
}
});
}
Expand Down

0 comments on commit dcc7219

Please sign in to comment.