Skip to content

Commit

Permalink
update JDLL version, leave responsability of creating the code for
Browse files Browse the repository at this point in the history
encodings on SAM specific model
  • Loading branch information
carlosuc3m committed Apr 23, 2024
1 parent b78afd5 commit 7ad34f5
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 32 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@

<!-- NB: Deploy releases to the SciJava Maven repository. -->
<releaseProfiles>sign,deploy-to-scijava</releaseProfiles>
<dl-modelrunner.version>0.5.5</dl-modelrunner.version>
<dl-modelrunner.version>0.5.8-SNAPSHOT</dl-modelrunner.version>
</properties>

<dependencies>
Expand Down
29 changes: 2 additions & 27 deletions src/main/java/ai/nets/samj/models/AbstractSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -268,31 +268,6 @@ else if (task.status == TaskStatus.CRASHED)
protected <T extends RealType<T> & NativeType<T>>
void sendImgLib2AsNp() {
createSHMArray((RandomAccessibleInterval<T>) this.img);
String code = "";
// This line wants to recreate the original numpy array. Should look like:
// input0_appose_shm = shared_memory.SharedMemory(name=input0)
// input0 = np.ndarray(size, dtype="float64", buffer=input0_appose_shm.buf).reshape([64, 64])
code += "im_shm = shared_memory.SharedMemory(name='"
+ shma.getNameForPython() + "', size=" + shma.getSize()
+ ")" + System.lineSeparator();
int size = 1;
for (long l : targetDims) {size *= l;}
code += "im = np.ndarray(" + size + ", dtype='" + CommonUtils.getDataType(Util.getTypeFromInterval(shma.getSharedRAI()))
+ "', buffer=im_shm.buf).reshape([";
for (long ll : targetDims)
code += ll + ", ";
code = code.substring(0, code.length() - 2);
code += "])" + System.lineSeparator();
code += "input_h = im.shape[1]" + System.lineSeparator();
code += "input_w = im.shape[0]" + System.lineSeparator();
code += "globals()['input_h'] = input_h" + System.lineSeparator();
code += "globals()['input_w'] = input_w" + System.lineSeparator();
code += "task.update(str(im.shape))" + System.lineSeparator();
code += "im = torch.from_numpy(np.transpose(im))" + System.lineSeparator();
code += "task.update('after ' + str(im.shape))" + System.lineSeparator();
code += "im_shm.unlink()" + System.lineSeparator();
//code += "box_shm.close()" + System.lineSeparator();
this.script += code;
}

private <T extends RealType<T> & NativeType<T>> void sendCropAsNp(long[] cropSize) {
Expand All @@ -318,7 +293,7 @@ else if (cropSize.length == 2)
+ ")" + System.lineSeparator();
int size = 1;
for (long l : targetDims) {size *= l;}
code += "im = np.ndarray(" + size + ", dtype='" + CommonUtils.getDataType(Util.getTypeFromInterval(crop)) + "', buffer=im_shm.buf).reshape([";
code += "im = np.ndarray(" + size + ", dtype='" + CommonUtils.getDataTypeFromRAI(crop) + "', buffer=im_shm.buf).reshape([";
for (long ll : targetDims)
code += ll + ", ";
code = code.substring(0, code.length() - 2);
Expand Down Expand Up @@ -637,7 +612,7 @@ List<Polygon> processMask(RandomAccessibleInterval<T> img, boolean returnAll)
throw new IllegalArgumentException("The provided mask should be a 2d image with just one channel of width "
+ this.shma.getOriginalShape()[1] + " and height " + this.shma.getOriginalShape()[0]);
}
SharedMemoryArray maskShma = SharedMemoryArray.buildSHMA(img);
SharedMemoryArray maskShma = SharedMemoryArray.createSHMAFromRAI(img, false, false);
try {
return processMask(maskShma, returnAll);
} catch (IOException | RuntimeException | InterruptedException ex) {
Expand Down
31 changes: 29 additions & 2 deletions src/main/java/ai/nets/samj/models/EfficientSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@
import io.bioimage.modelrunner.apposed.appose.Service.Task;
import io.bioimage.modelrunner.apposed.appose.Service.TaskStatus;
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
import io.bioimage.modelrunner.utils.CommonUtils;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.converter.RealTypeConverters;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Cast;
import net.imglib2.util.Intervals;
import net.imglib2.view.IntervalView;
import net.imglib2.view.Views;
Expand Down Expand Up @@ -196,7 +198,32 @@ else if (task.status == TaskStatus.CRASHED)

@Override
protected void createEncodeImageScript() {
this.script = ""
String code = "";
// This line wants to recreate the original numpy array. Should look like:
// input0_appose_shm = shared_memory.SharedMemory(name=input0)
// input0 = np.ndarray(size, dtype="float64", buffer=input0_appose_shm.buf).reshape([64, 64])
code += "im_shm = shared_memory.SharedMemory(name='"
+ shma.getNameForPython() + "', size=" + shma.getSize()
+ ")" + System.lineSeparator();
int size = 1;
for (long l : targetDims) {size *= l;}
code += "im = np.ndarray(" + size + ", dtype='" + CommonUtils.getDataTypeFromRAI(Cast.unchecked(shma.getSharedRAI()))
+ "', buffer=im_shm.buf).reshape([";
for (long ll : targetDims)
code += ll + ", ";
code = code.substring(0, code.length() - 2);
code += "])" + System.lineSeparator();
code += "input_h = im.shape[1]" + System.lineSeparator();
code += "input_w = im.shape[0]" + System.lineSeparator();
code += "globals()['input_h'] = input_h" + System.lineSeparator();
code += "globals()['input_w'] = input_w" + System.lineSeparator();
code += "task.update(str(im.shape))" + System.lineSeparator();
code += "im = torch.from_numpy(np.transpose(im))" + System.lineSeparator();
code += "task.update('after ' + str(im.shape))" + System.lineSeparator();
code += "im_shm.unlink()" + System.lineSeparator();
//code += "box_shm.close()" + System.lineSeparator();
this.script += code;
this.script += ""
+ "task.update(str(im.shape))" + System.lineSeparator()
+ "aa = predictor.get_image_embeddings(im[None, ...])";
}
Expand Down Expand Up @@ -379,7 +406,7 @@ private <T extends RealType<T> & NativeType<T>> void checkImageIsFine(RandomAcce
protected <T extends RealType<T> & NativeType<T>> void createSHMArray(RandomAccessibleInterval<T> imShared) {
RandomAccessibleInterval<T> imageToBeSent = ImgLib2Utils.reescaleIfNeeded(imShared);
long[] dims = imageToBeSent.dimensionsAsLongArray();
shma = SharedMemoryArray.buildMemorySegmentForImage(new long[] {dims[0], dims[1], dims[2]}, new FloatType());
shma = SharedMemoryArray.create(new long[] {dims[0], dims[1], dims[2]}, new FloatType(), false, false);
adaptImageToModel(imageToBeSent, shma.getSharedRAI());
}
}
22 changes: 20 additions & 2 deletions src/main/java/ai/nets/samj/models/EfficientViTSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,15 @@
import io.bioimage.modelrunner.apposed.appose.Service.TaskStatus;

import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
import io.bioimage.modelrunner.utils.CommonUtils;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.converter.RealTypeConverters;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Cast;
import net.imglib2.util.Intervals;
import net.imglib2.view.IntervalView;
import net.imglib2.view.Views;
Expand Down Expand Up @@ -304,7 +307,22 @@ protected <T extends RealType<T> & NativeType<T>> void setImageOfInterest(Random

@Override
protected void createEncodeImageScript() {
this.script += ""
script = "";
script += "im_shm = shared_memory.SharedMemory(name='"
+ shma.getNameForPython() + "', size=" + shma.getSize()
+ ")" + System.lineSeparator();
int size = 1;
for (long l : targetDims) {size *= l;}
script += "im = np.ndarray(" + size + ", dtype='" + "', buffer=im_shm.buf).reshape([";
for (long ll : targetDims)
script += ll + ", ";
script = script.substring(0, script.length() - 2);
script += "])" + System.lineSeparator();
script += "im = np.transpose(im, (1, 0, 2))" + System.lineSeparator();
//code += "np.save('/home/carlos/git/aa.npy', im)" + System.lineSeparator();
script += "im_shm.unlink()" + System.lineSeparator();
//code += "box_shm.close()" + System.lineSeparator();
script += ""
+ "task.update(str(im.shape))" + System.lineSeparator()
+ "predictor.set_image(im)";
}
Expand All @@ -313,7 +331,7 @@ protected void createEncodeImageScript() {
protected <T extends RealType<T> & NativeType<T>> void createSHMArray(RandomAccessibleInterval<T> imShared) {
RandomAccessibleInterval<T> imageToBeSent = ImgLib2Utils.reescaleIfNeeded(imShared);
long[] dims = imageToBeSent.dimensionsAsLongArray();
shma = SharedMemoryArray.buildMemorySegmentForImage(new long[] {dims[0], dims[1], dims[2]}, new UnsignedByteType());
shma = SharedMemoryArray.create(new long[] {dims[0], dims[1], dims[2]}, new UnsignedByteType(), false, false);
adaptImageToModel(imageToBeSent, shma.getSharedRAI());
}

Expand Down

0 comments on commit 7ad34f5

Please sign in to comment.