Skip to content

Commit

Permalink
start redesigning the specific SAMs
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Apr 23, 2024
1 parent af61255 commit 216148b
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 681 deletions.
79 changes: 43 additions & 36 deletions src/main/java/ai/nets/samj/AbstractSamJ2.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,17 @@
import io.bioimage.modelrunner.apposed.appose.Service;
import io.bioimage.modelrunner.apposed.appose.Service.Task;
import io.bioimage.modelrunner.apposed.appose.Service.TaskStatus;
import io.bioimage.modelrunner.numpy.DecodeNumpy;
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
import io.bioimage.modelrunner.utils.CommonUtils;
import net.imglib2.IterableInterval;
import net.imglib2.RandomAccessible;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.converter.Converter;
import net.imglib2.converter.Converters;
import net.imglib2.converter.RealTypeConverters;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.loops.LoopBuilder;
import net.imglib2.type.NativeType;
import net.imglib2.type.Type;
import net.imglib2.type.numeric.RealType;
Expand Down Expand Up @@ -91,44 +94,49 @@ public interface DebugTextPrinter { void printText(String text); }
/**
* Instance referencing the Python environment that is going to be used to run EfficientSAM
*/
private Environment env;
protected Environment env;
/**
* Instance of {@link Service} that is in charge of opening a Python process and running the
* scripts provided in that Python process in order to be able to use EfficientSAM
*/
private Service python;
protected Service python;
/**
* The script that is going to be run in Python
*/
private String script = "";
protected String script = "";
/**
* Shared memory array used to share between Java and Python the image that is going to be processed by EfficientSAM
*/
private SharedMemoryArray shma;
protected SharedMemoryArray shma;
/**
* Target dimensions of the image that is going to be encoded. If a single-channel 2D image is provided, that image is
* converted into a 3-channel image that EfficientSAM requires
* converted into a 3-channel image that EfficientSAM requires.
* The axes are "xyc"
*/
private long[] targetDims;
protected long[] targetDims;
/**
* Coordinates of the vertex of the crop/zoom of the image of interest that has been encoded.
* It is the closest vertex to the origin.
* Usually the vertex is at 0,0 and the encoded image is all the pixels. This feature is useful for when the image
* is big and re-encoding needs to happen while the user pans and zooms in the image.
* The axes are "xyc"
*/
private long[] encodeCoords;
protected long[] encodeCoords;
/**
* Scale factor of x and y applied to the image that is going to be annotated.
* The image of interest does not need to be encoded normally. However, it is optimal to scale big images
* as the resolution of the segmentation depends on the ratio between the size of the image and the size of
* the object.
* The axes are "xyc"
*/
private double[] scale;
protected double[] scale;
/**
* Complete image being annotated. Usually this image is encoded completely
* but for larger images, zooms of it might be encoded instead of the whole image
* but for larger images, zooms of it might be encoded instead of the whole image.
*
* This image is stored as "xyc"
*/
private RandomAccessibleInterval<?> img;
protected RandomAccessibleInterval<?> img;

/**
* Change the image encoded by the EfficientSAM model
Expand All @@ -143,7 +151,6 @@ public interface DebugTextPrinter { void printText(String text); }
public <T extends RealType<T> & NativeType<T>>
void updateImage(RandomAccessibleInterval<T> rai) throws IOException, RuntimeException, InterruptedException {
// Hand the new image to addImage first; if it throws, the assignment below is never reached.
addImage(rai);
// Keep a reference to the image that was just passed to addImage.
this.img = rai;
}

/**
Expand All @@ -159,12 +166,7 @@ void updateImage(RandomAccessibleInterval<T> rai) throws IOException, RuntimeExc
private <T extends RealType<T> & NativeType<T>>
void addImage(RandomAccessibleInterval<T> rai)
throws IOException, RuntimeException, InterruptedException {
if (rai.dimensionsAsLongArray()[0] * rai.dimensionsAsLongArray()[1] > MAX_ENCODED_AREA_RS * MAX_ENCODED_AREA_RS
|| rai.dimensionsAsLongArray()[0] > MAX_ENCODED_SIDE || rai.dimensionsAsLongArray()[1] > MAX_ENCODED_SIDE) {
this.targetDims = new long[] {0, 0, 0};
this.img = rai;
return;
}
adaptImage(rai);
this.script = "";
sendImgLib2AsNp(rai);
createEncodeImageScript();
Expand All @@ -189,6 +191,8 @@ else if (task.status == TaskStatus.CRASHED)
}
}

protected abstract <T extends RealType<T> & NativeType<T>> void adaptImage(RandomAccessibleInterval<T> rai);

protected abstract void createEncodeImageScript();

private void reencodeCrop() throws IOException, InterruptedException, RuntimeException {
Expand Down Expand Up @@ -741,7 +745,7 @@ public void close() {
private <T extends RealType<T> & NativeType<T>>
void sendImgLib2AsNp(RandomAccessibleInterval<T> targetImg) {
shma = createEfficientSAMInputSHM(reescaleIfNeeded(targetImg));
adaptImageToModel(targetImg, shma.getSharedRAI());
RealTypeConverters.copyFromTo( (RandomAccessible<T>) this.img, shma.getSharedRAI() );
String code = "";
// This line wants to recreate the original numpy array. Should look like:
// input0_appose_shm = shared_memory.SharedMemory(name=input0)
Expand All @@ -751,34 +755,36 @@ void sendImgLib2AsNp(RandomAccessibleInterval<T> targetImg) {
+ ")" + System.lineSeparator();
int size = 1;
for (long l : targetDims) {size *= l;}
code += "im = np.ndarray(" + size + ", dtype='float32', buffer=im_shm.buf).reshape([";
code += "im = np.ndarray(" + size + ", dtype='" + CommonUtils.getDataType((RandomAccessibleInterval<T>) img)
+ "', buffer=im_shm.buf).reshape([";
for (long ll : targetDims)
code += ll + ", ";
code = code.substring(0, code.length() - 2);
code += "])" + System.lineSeparator();
code += "input_h = im.shape[0]" + System.lineSeparator();
code += "input_w = im.shape[1]" + System.lineSeparator();
code += "input_h = im.shape[1]" + System.lineSeparator();
code += "input_w = im.shape[0]" + System.lineSeparator();
code += "globals()['input_h'] = input_h" + System.lineSeparator();
code += "globals()['input_w'] = input_w" + System.lineSeparator();
code += "im = torch.from_numpy(np.transpose(im.astype('float32'), (2, 0, 1)))" + System.lineSeparator();
code += "task.update(str(im.shape))" + System.lineSeparator();
code += "im = torch.from_numpy(np.transpose(im.astype('float32')))" + System.lineSeparator();
code += "task.update('after ' + str(im.shape))" + System.lineSeparator();
code += "im_shm.unlink()" + System.lineSeparator();
//code += "box_shm.close()" + System.lineSeparator();
this.script += code;
}

/**
 * Convenience overload that calls {@link #sendCropAsNp(long[])} with a {@code null} crop size.
 * When it receives {@code null}, that overload computes the crop size from {@code encodeCoords}.
 *
 * @param <T>
 * 	type parameter declared by this method; it is not used by the delegating call
 */
private <T extends RealType<T> & NativeType<T>>
void sendCropAsNp() {
sendCropAsNp(null);
}

private <T extends RealType<T> & NativeType<T>> void sendCropAsNp(long[] cropSize) {
if (cropSize == null)
cropSize = new long[] {encodeCoords[3] - encodeCoords[1], encodeCoords[2] - encodeCoords[0], 3};
//RandomAccessibleInterval<T> crop =
// Views.interval( Cast.unchecked(img), new long[] {encodeCoords[1], encodeCoords[0], 0}, interValSize );
cropSize = new long[] {encodeCoords[2] - encodeCoords[0], encodeCoords[3] - encodeCoords[1], 3};
else if (cropSize.length == 2)
cropSize = new long[] {cropSize[0], cropSize[1], 3};
else
throw new IllegalArgumentException("The size of the area that wants to be encoded needs to be defined as [width, height].");
RandomAccessibleInterval<T> crop =
Views.interval( Cast.unchecked(img), new long[] {encodeCoords[1], encodeCoords[0], 0}, cropSize );

//RandomAccessibleInterval<T> crop = Views.offsetInterval(crop, new long[] {encodeCoords[1], encodeCoords[0], 0}, interValSize);
RandomAccessibleInterval<T> crop = Views.offsetInterval(Cast.unchecked(img), new long[] {encodeCoords[1], encodeCoords[0], 0}, cropSize);
//RandomAccessibleInterval<T> crop = Views.offsetInterval(Cast.unchecked(img), new long[] {encodeCoords[1], encodeCoords[0], 0}, cropSize);
targetDims = crop.dimensionsAsLongArray();
shma = SharedMemoryArray.buildMemorySegmentForImage(new long[] {targetDims[0], targetDims[1], targetDims[2]},
Util.getTypeFromInterval(crop));
Expand All @@ -797,12 +803,12 @@ private <T extends RealType<T> & NativeType<T>> void sendCropAsNp(long[] cropSiz
code += ll + ", ";
code = code.substring(0, code.length() - 2);
code += "])" + System.lineSeparator();
code += "input_h = im.shape[0]" + System.lineSeparator();
code += "input_w = im.shape[1]" + System.lineSeparator();
code += "input_h = im.shape[1]" + System.lineSeparator();
code += "input_w = im.shape[0]" + System.lineSeparator();
//code += "np.save('/home/carlos/git/cropped.npy', im)" + System.lineSeparator();
code += "globals()['input_h'] = input_h" + System.lineSeparator();
code += "globals()['input_w'] = input_w" + System.lineSeparator();
code += "im = torch.from_numpy(np.transpose(im.astype('float32'), (2, 0, 1)))" + System.lineSeparator();
code += "im = torch.from_numpy(np.transpose(im.astype('float32')))" + System.lineSeparator();
code += "im_shm.unlink()" + System.lineSeparator();
//code += "box_shm.close()" + System.lineSeparator();
this.script += code;
Expand All @@ -812,14 +818,15 @@ private <T extends RealType<T> & NativeType<T>> void sendCropAsNp(long[] cropSiz

protected abstract void processBoxWithSAM(boolean returnAll);

private static <T extends RealType<T> & NativeType<T>>
private <T extends RealType<T> & NativeType<T>>
SharedMemoryArray createEfficientSAMInputSHM(final RandomAccessibleInterval<T> inImg) {
long[] dims = inImg.dimensionsAsLongArray();
if ((dims.length != 3 && dims.length != 2) || (dims.length == 3 && dims[2] != 3 && dims[2] != 1)){
throw new IllegalArgumentException("Currently SAMJ only supports 1-channel (grayscale) or 3-channel (RGB, BGR, ...) 2D images."
+ "The image dimensions order should be 'yxc', first dimension height, second width and third channels.");
}
return SharedMemoryArray.buildMemorySegmentForImage(new long[] {dims[0], dims[1], 3}, new FloatType());
return SharedMemoryArray.buildMemorySegmentForImage(new long[] {dims[0], dims[1], 3},
Cast.unchecked(Util.getTypeFromInterval(img)));
}

private <T extends RealType<T> & NativeType<T>>
Expand Down
Loading

0 comments on commit 216148b

Please sign in to comment.