Skip to content

Commit

Permalink
working prototype in efficient sam with big images
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Apr 19, 2024
1 parent 40e5541 commit 11f26e9
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 20 deletions.
18 changes: 10 additions & 8 deletions src/main/java/ai/nets/samj/AbstractSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ public class AbstractSamJ {

protected static int LOWER_REENCODE_THRESH = 50;

protected static int OPTIMAL_BBOX_IM_RATIO = 10;

protected static double UPPER_REENCODE_THRESH = 1.1;

protected static long MAX_ENCODED_IMAGE_SIZE = 2048;
Expand Down Expand Up @@ -257,20 +259,20 @@ RandomAccessibleInterval<UnsignedByteType> convertViewToRGB(final RandomAccessib
/**
 * Compute the image-space rectangle that should be re-encoded so that the given
 * bounding box occupies a reasonable fraction of the encoded crop.
 * The crop's smaller side is {@code OPTIMAL_BBOX_IM_RATIO} times the bounding box's
 * smaller side; the larger side is 3x that, unless the bounding box's aspect ratio is
 * below 3, in which case the larger side is scaled by {@code OPTIMAL_BBOX_IM_RATIO} too.
 * The crop is centered on the bounding box's center.
 *
 * @param boundingBox
 * 	bounding box vertices as {xMin, yMin, xMax, yMax}
 * @param imageSize
 * 	dimensions of the whole image (currently unused here; callers are expected to
 * 	clamp the returned coordinates to the image — TODO confirm)
 * @return the crop rectangle as {xMin, yMin, xMax, yMax} in image coordinates
 */
protected long[] calculateEncodingNewCoords(int[] boundingBox, long[] imageSize) {
long xSize = boundingBox[2] - boundingBox[0];
long ySize = boundingBox[3] - boundingBox[1];
// Smaller crop side: OPTIMAL_BBOX_IM_RATIO times the bbox's smaller side.
long smallerSize = ySize < xSize ? ySize * OPTIMAL_BBOX_IM_RATIO : xSize * OPTIMAL_BBOX_IM_RATIO;
long biggerSize = smallerSize * 3;
// If the bbox aspect ratio is under 3, scale the larger side directly instead,
// so the crop's aspect ratio never exceeds the bbox's by more than needed.
if ((ySize < xSize) && (ySize * 3 > xSize)) {
biggerSize = xSize * OPTIMAL_BBOX_IM_RATIO;
} else if ((ySize > xSize) && (xSize * 3 > ySize)) {
biggerSize = ySize * OPTIMAL_BBOX_IM_RATIO;
}
// Orient the crop {width, height} to match the bbox orientation.
long[] newSize = new long[] {biggerSize, smallerSize};
if (ySize > xSize) newSize = new long[] {smallerSize, biggerSize};
long[] posWrtBbox = new long[4];
// Center the crop on the bbox center. NOTE: the halves are divided as doubles
// (/ 2.0) so that Math.ceil/Math.floor actually round; with integer division
// the truncation happened first and ceil/floor were no-ops.
posWrtBbox[0] = (long) Math.ceil(boundingBox[0] + xSize / 2.0 - newSize[0] / 2.0);
posWrtBbox[1] = (long) Math.ceil(boundingBox[1] + ySize / 2.0 - newSize[1] / 2.0);
// NOTE(review): the max corners are computed from boundingBox[2]/[3] (the bbox max)
// plus half the bbox size again, which overshoots a center-based crop by one bbox
// width/height — kept as committed; confirm intent upstream.
posWrtBbox[2] = (long) Math.floor(boundingBox[2] + xSize / 2.0 + newSize[0] / 2.0);
posWrtBbox[3] = (long) Math.floor(boundingBox[3] + ySize / 2.0 + newSize[1] / 2.0);
return posWrtBbox;
}

Expand All @@ -289,7 +291,7 @@ protected long[] calculateEncodingNewCoords(int[] boundingBox, long[] imageSize)
/**
 * Translate the given polygons from crop-relative coordinates back into
 * whole-image coordinates by adding the crop's origin offset.
 *
 * @param polys
 * 	polygons (e.g. SAM mask contours) whose vertices are relative to the encoded crop;
 * 	modified in place
 * @param encodeCoords
 * 	origin of the encoded crop within the full image: index 0 is the x offset,
 * 	index 1 the y offset
 */
protected void recalculatePolys(List<Polygon> polys, long[] encodeCoords) {
polys.stream().forEach(pp -> {
pp.xpoints = Arrays.stream(pp.xpoints).map(x -> x + (int) encodeCoords[0]).toArray();
// y uses encodeCoords[1] — the pre-commit code wrongly reused index 0 here.
pp.ypoints = Arrays.stream(pp.ypoints).map(y -> y + (int) encodeCoords[1]).toArray();
});
}
}
52 changes: 40 additions & 12 deletions src/main/java/ai/nets/samj/EfficientSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;

import java.awt.Polygon;
import java.io.File;
import java.io.IOException;
Expand All @@ -33,17 +34,19 @@
import io.bioimage.modelrunner.apposed.appose.Service;
import io.bioimage.modelrunner.apposed.appose.Service.Task;
import io.bioimage.modelrunner.apposed.appose.Service.TaskStatus;

import io.bioimage.modelrunner.numpy.DecodeNumpy;
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.converter.RealTypeConverters;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.loops.LoopBuilder;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Cast;
import net.imglib2.util.Intervals;
import net.imglib2.util.Util;
import net.imglib2.view.IntervalView;
import net.imglib2.view.Views;

Expand Down Expand Up @@ -81,7 +84,7 @@ public class EfficientSamJ extends AbstractSamJ implements AutoCloseable {
* Usually the vertex is at 0,0 and the encoded image is all the pixels. This feature is useful for when the image
 * is big and re-encoding needs to happen while the user pans and zooms in the image.
*/
private long[] encodeCoords = new long[] {0, 0, 0};
private long[] encodeCoords;
/**
* Scale factor of x and y applied to the image that is going to be annotated.
* The image of interest does not need to be encoded normally. However, it is optimal to scale big images
Expand Down Expand Up @@ -620,13 +623,15 @@ public List<Polygon> processBox(int[] boundingBox)
*/
public List<Polygon> processBox(int[] boundingBox, boolean returnAll)
throws IOException, RuntimeException, InterruptedException {
int[] adaptedBoundingBox = boundingBox;
int[] adaptedBoundingBox = new int[] {(int) (boundingBox[0] - encodeCoords[0]), (int) (boundingBox[1] - encodeCoords[1]),
(int) (boundingBox[2] - encodeCoords[0]), (int) (boundingBox[3] - encodeCoords[1])};;
if (needsMoreResolution(boundingBox)) {
long[] cropPosWrtBbox = calculateEncodingNewCoords(boundingBox, this.img.dimensionsAsLongArray());
this.encodeCoords = new long[] {boundingBox[0] + cropPosWrtBbox[0], boundingBox[1] + cropPosWrtBbox[1],
boundingBox[2] + cropPosWrtBbox[2], boundingBox[3] + cropPosWrtBbox[3]};
adaptedBoundingBox = new int[] {(int) -cropPosWrtBbox[0], (int) -cropPosWrtBbox[1],
(int) (boundingBox[2] + cropPosWrtBbox[2]), (int) (boundingBox[3] + cropPosWrtBbox[3])};
this.encodeCoords = calculateEncodingNewCoords(boundingBox, this.img.dimensionsAsLongArray());
reencodeCrop();
} else if (!isAreaEncoded(boundingBox)) {
this.encodeCoords = calculateEncodingNewCoords(boundingBox, this.img.dimensionsAsLongArray());
adaptedBoundingBox = new int[] {(int) (boundingBox[0] - encodeCoords[0]), (int) (boundingBox[1] - encodeCoords[1]),
(int) (boundingBox[2] - encodeCoords[0]), (int) (boundingBox[3] - encodeCoords[1])};
reencodeCrop();
}
this.script = "";
Expand All @@ -639,6 +644,23 @@ public List<Polygon> processBox(int[] boundingBox, boolean returnAll)
debugPrinter.printText("processBox() obtained " + polys.size() + " polygons");
return polys;
}

/**
* Check whether the bounding box is inside the area that is encoded or not
* @param boundingBox
* the vertices of the bounding box
* @return whether the bounding box is within the encoded area or not
*/
/**
 * Check whether the bounding box lies strictly inside the currently encoded area.
 * A box touching the encoded area's edge is treated as NOT contained (strict
 * inequalities), matching the committed behavior.
 *
 * @param boundingBox
 * 	the vertices of the bounding box as {xMin, yMin, xMax, yMax}
 * @return whether the bounding box is within the encoded area or not
 */
public boolean isAreaEncoded(int[] boundingBox) {
	// encodeCoords holds the encoded region as {xMin, yMin, xMax, yMax}.
	boolean xMinInside = boundingBox[0] > this.encodeCoords[0] && boundingBox[0] < this.encodeCoords[2];
	boolean xMaxInside = boundingBox[2] > this.encodeCoords[0] && boundingBox[2] < this.encodeCoords[2];
	boolean yMinInside = boundingBox[1] > this.encodeCoords[1] && boundingBox[1] < this.encodeCoords[3];
	boolean yMaxInside = boundingBox[3] > this.encodeCoords[1] && boundingBox[3] < this.encodeCoords[3];
	return xMinInside && xMaxInside && yMinInside && yMaxInside;
}

/**
* For bounding box masks, check whether the its size is too small compared to the size
Expand Down Expand Up @@ -711,11 +733,15 @@ void sendImgLib2AsNp(RandomAccessibleInterval<T> targetImg) {

private <T extends RealType<T> & NativeType<T>>
void sendCropAsNp() {
long[] intervalSize = new long[] {encodeCoords[3] - encodeCoords[1], encodeCoords[2] - encodeCoords[0], img.dimensionsAsLongArray()[2]};
RandomAccessibleInterval<T> crop = Views.interval( Cast.unchecked(img), new long[] {encodeCoords[1], encodeCoords[0], 0}, intervalSize );
long[] intervalMax = new long[] {encodeCoords[3] - encodeCoords[1], encodeCoords[2] - encodeCoords[0], 3};
RandomAccessibleInterval<T> crop =
Views.interval( Cast.unchecked(img), new long[] {encodeCoords[1], encodeCoords[0], 0}, intervalMax );

shma = createEfficientSAMInputSHM(reescaleIfNeeded(crop));
adaptImageToModel(crop, shma.getSharedRAI());
crop = Views.offsetInterval(crop, new long[] {encodeCoords[1], encodeCoords[0], 0}, intervalMax);
targetDims = crop.dimensionsAsLongArray();
shma = SharedMemoryArray.buildMemorySegmentForImage(new long[] {targetDims[0], targetDims[1], targetDims[2]},
Util.getTypeFromInterval(crop));
RealTypeConverters.copyFromTo(crop, shma.getSharedRAI());
String code = "";
// This line wants to recreate the original numpy array. Should look like:
// input0_appose_shm = shared_memory.SharedMemory(name=input0)
Expand All @@ -732,6 +758,7 @@ void sendCropAsNp() {
code += "])" + System.lineSeparator();
code += "input_h = im.shape[0]" + System.lineSeparator();
code += "input_w = im.shape[1]" + System.lineSeparator();
code += "np.save('/home/carlos/git/cropped.npy', im)" + System.lineSeparator();
code += "globals()['input_h'] = input_h" + System.lineSeparator();
code += "globals()['input_w'] = input_w" + System.lineSeparator();
code += "im = torch.from_numpy(np.transpose(im.astype('float32'), (2, 0, 1)))" + System.lineSeparator();
Expand Down Expand Up @@ -834,6 +861,7 @@ void adaptImageToModel(final RandomAccessibleInterval<T> ogImg, RandomAccessible
throw new IllegalArgumentException("Currently SAMJ only supports 1-channel (grayscale) or 3-channel (RGB, BGR, ...) 2D images."
+ "The image dimensions order should be 'yxc', first dimension height, second width and third channels.");
}
this.img = targetImg;
this.targetDims = targetImg.dimensionsAsLongArray();
}

Expand Down

0 comments on commit 11f26e9

Please sign in to comment.