Skip to content

Commit

Permalink
first try to see if it works
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Apr 19, 2024
1 parent 80d00e3 commit 40e5541
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 11 deletions.
22 changes: 22 additions & 0 deletions src/main/java/ai/nets/samj/AbstractSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@
import net.imglib2.util.Util;
import net.imglib2.view.Views;

import java.awt.Polygon;
import java.time.LocalDateTime;
import java.util.Arrays;
import java.util.List;

/**
* Class that contains methods that can be sued by SAMJ models
Expand Down Expand Up @@ -270,4 +273,23 @@ protected long[] calculateEncodingNewCoords(int[] boundingBox, long[] imageSize)
posWrtBbox[3] = boundingBox[3] + (long) Math.floor(newSize[1] / 2);
return posWrtBbox;
}

/**
* Method that recalculates the coordinates of the polygons outputed by SAMJ.
*
* This method is usually for big images. In order to create encoding with enough resolution
* to detect small objects compared to the size of the whole image, SAMJ might encode crops of
* the total image, thus the coordinates of the polygons obtained need to be shifted in order
* to match the original image.
* @param polys
* polys obtained by SAMJ on the encoded crop
* @param encodeCoords
* position of the crop in the total image
*/
protected void recalculatePolys(List<Polygon> polys, long[] encodeCoords) {
polys.stream().forEach(pp -> {
pp.xpoints = Arrays.stream(pp.xpoints).map(x -> x + (int) encodeCoords[0]).toArray();
pp.ypoints = Arrays.stream(pp.ypoints).map(y -> y + (int) encodeCoords[0]).toArray();
});
}
}
82 changes: 71 additions & 11 deletions src/main/java/ai/nets/samj/EfficientSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Cast;
import net.imglib2.util.Intervals;
import net.imglib2.view.IntervalView;
import net.imglib2.view.Views;
Expand Down Expand Up @@ -80,7 +81,7 @@ public class EfficientSamJ extends AbstractSamJ implements AutoCloseable {
* Usually the vertex is at 0,0 and the encoded image is all the pixels. This feature is useful for when the image
* is big and reeconding needs to happen while the user pans and zooms in the image.
*/
private long[] encodeCoords;
private long[] encodeCoords = new long[] {0, 0, 0};
/**
* Scale factor of x and y applied to the image that is going to be annotated.
* The image of interest does not need to be encoded normally. However, it is optimal to scale big images
Expand Down Expand Up @@ -116,11 +117,6 @@ public class EfficientSamJ extends AbstractSamJ implements AutoCloseable {
+ "globals()['np'] = np" + System.lineSeparator()
+ "globals()['torch'] = torch" + System.lineSeparator()
+ "globals()['predictor'] = predictor" + System.lineSeparator();
/**
* String containing the Python imports code after it has been formatted with the correct
* paths and names
*/
private String IMPORTS_FORMATED;

/**
* Create an instance of the class to be able to run EfficientSAM in Java.
Expand Down Expand Up @@ -162,7 +158,7 @@ private EfficientSamJ(SamEnvManager manager,
};
python = env.python();
python.debug(debugPrinter::printText);
IMPORTS_FORMATED = String.format(IMPORTS,
String IMPORTS_FORMATED = String.format(IMPORTS,
manager.getEfficientSamEnv() + File.separator + SamEnvManager.ESAM_NAME,
manager.getEfficientSAMSmallWeightsPath());
printScript(IMPORTS_FORMATED + PythonMethods.TRACE_EDGES, "Edges tracing code");
Expand Down Expand Up @@ -301,6 +297,34 @@ else if (task.status == TaskStatus.CRASHED)
}
}

private void reencodeCrop() throws IOException, InterruptedException, RuntimeException {
this.script = "";
sendCropAsNp();
this.script += ""
+ "task.update(str(im.shape))" + System.lineSeparator()
+ "aa = predictor.get_image_embeddings(im[None, ...])";
try {
printScript(script, "Creation of the cropped embeddings");
Task task = python.task(script);
task.waitFor();
if (task.status == TaskStatus.CANCELED)
throw new RuntimeException();
else if (task.status == TaskStatus.FAILED)
throw new RuntimeException();
else if (task.status == TaskStatus.CRASHED)
throw new RuntimeException();
this.shma.close();
} catch (IOException | InterruptedException | RuntimeException e) {
try {
this.shma.close();
} catch (IOException e1) {
throw new IOException(e.toString() + System.lineSeparator() + e1.toString());
}
throw e;
}

}

private List<Polygon> processAndRetrieveContours(HashMap<String, Object> inputs)
throws IOException, RuntimeException, InterruptedException {
Map<String, Object> results = null;
Expand Down Expand Up @@ -596,17 +620,22 @@ public List<Polygon> processBox(int[] boundingBox)
*/
public List<Polygon> processBox(int[] boundingBox, boolean returnAll)
throws IOException, RuntimeException, InterruptedException {
int[] adaptedBoundingBox = boundingBox;
if (needsMoreResolution(boundingBox)) {
long[] cropPosWrtBbox = calculateEncodingNewCoords(boundingBox, this.img.dimensionsAsLongArray());
reencodeCrop(cropPosWrtBbox);
this.encodeCoords = new long[] {boundingBox[0] + cropPosWrtBbox[0], boundingBox[1] + cropPosWrtBbox[1],
boundingBox[2] + cropPosWrtBbox[2], boundingBox[3] + cropPosWrtBbox[3]};
adaptedBoundingBox = new int[] {(int) -cropPosWrtBbox[0], (int) -cropPosWrtBbox[1],
(int) (boundingBox[2] + cropPosWrtBbox[2]), (int) (boundingBox[3] + cropPosWrtBbox[3])};
reencodeCrop();
}
boundingBox = recalculateBbox(boundingBox);
this.script = "";
processBoxWithSAM(returnAll);
HashMap<String, Object> inputs = new HashMap<String, Object>();
inputs.put("input_box", boundingBox);
inputs.put("input_box", adaptedBoundingBox);
printScript(script, "Rectangle inference");
List<Polygon> polys = processAndRetrieveContours(inputs);
recalculatePolys(polys, encodeCoords);
debugPrinter.printText("processBox() obtained " + polys.size() + " polygons");
return polys;
}
Expand Down Expand Up @@ -660,7 +689,38 @@ void sendImgLib2AsNp(RandomAccessibleInterval<T> targetImg) {
// This line wants to recreate the original numpy array. Should look like:
// input0_appose_shm = shared_memory.SharedMemory(name=input0)
// input0 = np.ndarray(size, dtype="float64", buffer=input0_appose_shm.buf).reshape([64, 64])
code += IMPORTS_FORMATED+"im_shm = shared_memory.SharedMemory(name='"
code += "im_shm = shared_memory.SharedMemory(name='"
+ shma.getNameForPython() + "', size=" + shma.getSize()
+ ")" + System.lineSeparator();
int size = 1;
for (long l : targetDims) {size *= l;}
code += "im = np.ndarray(" + size + ", dtype='float32', buffer=im_shm.buf).reshape([";
for (long ll : targetDims)
code += ll + ", ";
code = code.substring(0, code.length() - 2);
code += "])" + System.lineSeparator();
code += "input_h = im.shape[0]" + System.lineSeparator();
code += "input_w = im.shape[1]" + System.lineSeparator();
code += "globals()['input_h'] = input_h" + System.lineSeparator();
code += "globals()['input_w'] = input_w" + System.lineSeparator();
code += "im = torch.from_numpy(np.transpose(im.astype('float32'), (2, 0, 1)))" + System.lineSeparator();
code += "im_shm.unlink()" + System.lineSeparator();
//code += "box_shm.close()" + System.lineSeparator();
this.script += code;
}

private <T extends RealType<T> & NativeType<T>>
void sendCropAsNp() {
long[] intervalSize = new long[] {encodeCoords[3] - encodeCoords[1], encodeCoords[2] - encodeCoords[0], img.dimensionsAsLongArray()[2]};
RandomAccessibleInterval<T> crop = Views.interval( Cast.unchecked(img), new long[] {encodeCoords[1], encodeCoords[0], 0}, intervalSize );

shma = createEfficientSAMInputSHM(reescaleIfNeeded(crop));
adaptImageToModel(crop, shma.getSharedRAI());
String code = "";
// This line wants to recreate the original numpy array. Should look like:
// input0_appose_shm = shared_memory.SharedMemory(name=input0)
// input0 = np.ndarray(size, dtype="float64", buffer=input0_appose_shm.buf).reshape([64, 64])
code += "im_shm = shared_memory.SharedMemory(name='"
+ shma.getNameForPython() + "', size=" + shma.getSize()
+ ")" + System.lineSeparator();
int size = 1;
Expand Down

0 comments on commit 40e5541

Please sign in to comment.