Skip to content

Commit

Permalink
enable downsampling of images to make faster communication
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Oct 30, 2024
1 parent dcc7219 commit 2a52e4b
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 24 deletions.
32 changes: 19 additions & 13 deletions src/main/java/ai/nets/samj/models/AbstractSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ public void setUsinIJRoiManager(boolean isUsingIJRoiManager) {
* @param designationOfTheScript
* the name (or some string to design) of the text that is going to be printed
*/
public void printScript(final String script, final String designationOfTheScript) {
public <T extends RealType<T> & NativeType<T>> void printScript(final String script, final String designationOfTheScript) {
if (!isDebugging) return;
debugPrinter.printText("START: =========== "+designationOfTheScript+" ===========");
debugPrinter.printText(LocalDateTime.now().toString());
Expand Down Expand Up @@ -272,6 +272,8 @@ void setImage(RandomAccessibleInterval<T> rai) throws IOException, RuntimeExcept
this.targetDims = new long[] {0, 0, 0};
this.imageSmall = false;
return;
} else {
scale = 1;
}
this.script = "";
sendImgLib2AsNp();
Expand Down Expand Up @@ -329,7 +331,7 @@ else if (task.status == TaskStatus.CRASHED)

protected <T extends RealType<T> & NativeType<T>>
void sendImgLib2AsNp() {
createSHMArray((RandomAccessibleInterval<T>) this.img);
createSHMArray(Cast.unchecked(this.img));
}

private <T extends RealType<T> & NativeType<T>> void sendCropAsNp(long[] cropSize) {
Expand All @@ -346,13 +348,19 @@ else if (cropSize.length == 3 && cropSize[2] != 3)
targetDims = crop.dimensionsAsLongArray();

scale = (int) (Math.min(targetDims[0], targetDims[1]) / MAX_IMG_SIZE);
RandomAccessibleInterval<T> subsampledCrop = Views.subsample(crop,
new long[] {scale, scale, 1});
targetReescaledDims = crop.dimensionsAsLongArray();
createSHMArray(subsampledCrop);
scale = Math.max(scale, 1);
if (scale == 1) {
createSHMArray(crop);
} else {
RandomAccessibleInterval<T> subsampledCrop = Views.subsample(crop,
new long[] {scale, scale, 1});
targetReescaledDims = subsampledCrop.dimensionsAsLongArray();
createSHMArray(subsampledCrop);
}

}

@SuppressWarnings("unchecked")
private List<Mask> processAndRetrieveContours(HashMap<String, Object> inputs)
throws IOException, RuntimeException, InterruptedException {
Map<String, Object> results = null;
Expand Down Expand Up @@ -388,7 +396,6 @@ else if (task.outputs.get("rle") == null)
final Iterator<List<Number>> contours_y = ((List<List<Number>>)results.get("contours_y")).iterator();
final Iterator<List<Number>> rles = ((List<List<Number>>)results.get("rle")).iterator();
final List<Mask> masks = new ArrayList<Mask>(contours_x_container.size());
Rectangle cropRect = new Rectangle((int) encodeCoords[0], (int) encodeCoords[1], (int) targetDims[0], (int) targetDims[1]);
while (contours_x.hasNext()) {
int[] xArr = contours_x.next().stream().mapToInt(Number::intValue).toArray();
int[] yArr = contours_y.next().stream().mapToInt(Number::intValue).toArray();
Expand Down Expand Up @@ -565,8 +572,8 @@ public List<Mask> processPoints(List<int[]> pointsList, List<int[]> pointsNegLis
private List<int[]> adaptPointPrompts(List<int[]> pointsList) {
pointsList = pointsList.stream().map(pp -> {
int[] newPoint = new int[2];
newPoint[0] = (int) ((pp[0] - this.encodeCoords[0]) / scale);
newPoint[1] = (int) ((pp[1] - this.encodeCoords[1]) / scale);
newPoint[0] = (int) Math.ceil((pp[0] - this.encodeCoords[0]) / (double) scale);
newPoint[1] = (int) Math.ceil((pp[1] - this.encodeCoords[1]) / (double) scale);
return newPoint;
}).collect(Collectors.toList());
return pointsList;
Expand Down Expand Up @@ -619,9 +626,9 @@ public List<Mask> processBox(int[] boundingBox, boolean returnAll)
reencodeCrop();
}
}
int[] adaptedBoundingBox = new int[] {(int) ((boundingBox[0] - encodeCoords[0]) / scale),
(int) ((boundingBox[1] - encodeCoords[1]) / scale),
(int) ((boundingBox[2] - encodeCoords[0]) / scale), (int) ((boundingBox[3] - encodeCoords[1]) / scale)};;
int[] adaptedBoundingBox = new int[] {(int) Math.ceil((boundingBox[0] - encodeCoords[0]) / (double) scale),
(int) Math.ceil((boundingBox[1] - encodeCoords[1]) / (double) scale),
(int) Math.ceil((boundingBox[2] - encodeCoords[0]) / (double) scale), (int) Math.ceil((boundingBox[3] - encodeCoords[1]) / (double) scale)};;
this.script = "";
processBoxWithSAM(returnAll);
HashMap<String, Object> inputs = new HashMap<String, Object>();
Expand Down Expand Up @@ -761,7 +768,6 @@ private void evaluateReencodingNeeded(List<int[]> pointsList, List<int[]> points
Rectangle alreadyEncoded = getCurrentlyEncodedArea();
Rectangle neededArea = getApproximateAreaNeeded(pointsList, pointsNegList, rect);
if (rect.equals(alreadyEncoded)) neededArea = getApproximateAreaNeeded(pointsList, pointsNegList);
ArrayList<int[]> notInRect = getPointsNotInRect(pointsList, pointsNegList, rect);

if (rectContainsRect(alreadyEncoded, neededArea)
&& alreadyEncoded.width * 0.7 < extendedRect.width && alreadyEncoded.height * 0.7 < extendedRect.height) {
Expand Down
5 changes: 0 additions & 5 deletions src/main/java/ai/nets/samj/models/EfficientSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,6 @@ protected void createEncodeImageScript() {
code += "im_shm = shared_memory.SharedMemory(name='"
+ shma.getNameForPython() + "', size=" + shma.getSize()
+ ")" + System.lineSeparator();
code += "task.update('here')" + System.lineSeparator();
int size = 1;
for (long l : targetDims) {size *= l;}
code += "im = np.ndarray(" + size + ", dtype='" + CommonUtils.getDataTypeFromRAI(Cast.unchecked(shma.getSharedRAI()))
Expand All @@ -210,19 +209,15 @@ protected void createEncodeImageScript() {
code += ll + ", ";
code = code.substring(0, code.length() - 2);
code += "])" + System.lineSeparator();
code += "task.update('here')" + System.lineSeparator();
//code += "np.save('/home/carlos/git/crop.npy', im)" + System.lineSeparator();
code += "input_h = im.shape[1]" + System.lineSeparator();
code += "input_w = im.shape[0]" + System.lineSeparator();
code += "globals()['input_h'] = input_h" + System.lineSeparator();
code += "globals()['input_w'] = input_w" + System.lineSeparator();
code += "task.update('here')" + System.lineSeparator();
//code += "task.update(str(im.shape))" + System.lineSeparator();
code += "im = torch.from_numpy(np.transpose(im, (2, 1, 0)))" + System.lineSeparator();
code += "task.update('here')" + System.lineSeparator();
//code += "task.update('after ' + str(im.shape))" + System.lineSeparator();
code += "im_shm.unlink()" + System.lineSeparator();
code += "task.update('here')" + System.lineSeparator();
this.script += code;
this.script += ""
+ "_ = predictor.get_image_embeddings(im[None, ...])" + System.lineSeparator();
Expand Down
17 changes: 11 additions & 6 deletions src/main/java/ai/nets/samj/models/Sam2.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package ai.nets.samj.models;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import ai.nets.samj.install.Sam2EnvManager;
Expand All @@ -33,6 +34,7 @@

import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
import io.bioimage.modelrunner.utils.CommonUtils;
import net.imglib2.Cursor;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.converter.RealTypeConverters;
import net.imglib2.img.array.ArrayImgs;
Expand Down Expand Up @@ -286,13 +288,15 @@ protected void createEncodeImageScript() {
script += "im_shm = shared_memory.SharedMemory(name='"
+ shma.getNameForPython() + "', size=" + shma.getSize()
+ ")" + System.lineSeparator();
int size = 1;
for (long l : targetDims) {size *= l;}
int size = (int) targetDims[2];
for (int i = 0; i < targetDims.length - 1; i ++) {
size *= Math.ceil(targetDims[i] / (double) scale);
}
script += "im = np.ndarray(" + size + ", dtype='" + CommonUtils.getDataTypeFromRAI(Cast.unchecked(shma.getSharedRAI()))
+ "', buffer=im_shm.buf).reshape([";
for (long ll : targetDims)
script += ll + ", ";
script = script.substring(0, script.length() - 2);
for (int i = 0; i < targetDims.length - 1; i ++)
script += (int) Math.ceil(targetDims[i] / (double) scale) + ", ";
script += targetDims[2];
script += "])" + System.lineSeparator();
script += "im = np.transpose(im, (1, 0, 2))" + System.lineSeparator();
//code += "np.save('/home/carlos/git/aa.npy', im)" + System.lineSeparator();
Expand Down Expand Up @@ -436,9 +440,10 @@ public static List<String> getListOfSupportedVariants(){
private <T extends RealType<T> & NativeType<T>>
void adaptImageToModel(RandomAccessibleInterval<T> ogImg, RandomAccessibleInterval<T> targetImg) {
if (ogImg.numDimensions() == 3 && ogImg.dimensionsAsLongArray()[2] == 3) {
for (int i = 0; i < 3; i ++)
for (int i = 0; i < 3; i ++) {
RealTypeConverters.copyFromTo( ImgLib2Utils.convertViewToRGB(Views.hyperSlice(ogImg, 2, i), this.debugPrinter),
Views.hyperSlice(targetImg, 2, i) );
}
} else if (ogImg.numDimensions() == 3 && ogImg.dimensionsAsLongArray()[2] == 1) {
debugPrinter.printText("CONVERTED 1 CHANNEL IMAGE INTO 3 TO BE FEEDED TO SAMJ");
IntervalView<UnsignedByteType> resIm =
Expand Down

0 comments on commit 2a52e4b

Please sign in to comment.