Skip to content

Commit

Permalink
add api calls to save encodings and reload them
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Jul 16, 2024
1 parent 199656a commit 65ef7c4
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 2 deletions.
70 changes: 70 additions & 0 deletions src/main/java/ai/nets/samj/models/AbstractSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.stream.Collectors;
import java.awt.Polygon;
import java.awt.Rectangle;
Expand Down Expand Up @@ -141,6 +142,17 @@ public interface DebugTextPrinter { void printText(String text); }
*/
protected boolean imageSmall = true;

/**
* List of encodings that are cached to avoid recalculating
*/
List<String> savedEncodings = new ArrayList<String>();

protected abstract String persistEncodingScript(String encodingName);

protected abstract String selectEncodingScript(String encodingName);

protected abstract String deleteEncodingScript(String encodingName);

protected abstract void cellSAM(List<int[]> grid, boolean returnAll);

protected abstract void processPointsWithSAM(int nPoints, int nNegPoints, boolean returnAll);
Expand Down Expand Up @@ -1038,4 +1050,62 @@ protected void recalculatePolys(List<Polygon> polys, long[] encodeCoords) {
pp.ypoints = Arrays.stream(pp.ypoints).map(y -> y + (int) encodeCoords[1]).toArray();
});
}

public String persistEncoding() throws IOException, InterruptedException {
String uuid = UUID.randomUUID().toString();
String saveEncodings = persistEncodingScript(uuid);
try {
Task task = python.task(saveEncodings);
task.waitFor();
if (task.status == TaskStatus.CANCELED)
throw new RuntimeException();
else if (task.status == TaskStatus.FAILED)
throw new RuntimeException();
else if (task.status == TaskStatus.CRASHED)
throw new RuntimeException();
} catch (IOException | InterruptedException | RuntimeException e) {
throw e;
}
this.savedEncodings.add(uuid);
return uuid;
}

public void selectEncoding(String encodingName) throws IOException, InterruptedException {
if (!this.savedEncodings.contains(encodingName))
throw new IllegalArgumentException("No saved encoding found with name: " + encodingName);
String setEncoding = selectEncodingScript(encodingName);
try {
Task task = python.task(setEncoding);
task.waitFor();
if (task.status == TaskStatus.CANCELED)
throw new RuntimeException();
else if (task.status == TaskStatus.FAILED)
throw new RuntimeException();
else if (task.status == TaskStatus.CRASHED)
throw new RuntimeException();
} catch (IOException | InterruptedException | RuntimeException e) {
throw e;
}

}


public void deleteEncoding(String encodingName) throws IOException, InterruptedException {
if (!this.savedEncodings.contains(encodingName))
return;
String returnEncoding = deleteEncodingScript(encodingName);
try {
Task task = python.task(returnEncoding);
task.waitFor();
if (task.status == TaskStatus.CANCELED)
throw new RuntimeException();
else if (task.status == TaskStatus.FAILED)
throw new RuntimeException();
else if (task.status == TaskStatus.CRASHED)
throw new RuntimeException();
} catch (IOException | InterruptedException | RuntimeException e) {
throw e;
}
this.savedEncodings.remove(encodingName);
}
}
27 changes: 26 additions & 1 deletion src/main/java/ai/nets/samj/models/EfficientSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import io.bioimage.modelrunner.apposed.appose.Environment;
Expand All @@ -48,6 +49,12 @@
* @author vladimir Ulman
*/
public class EfficientSamJ extends AbstractSamJ {

/**
* List of encodings that are cached to avoid recalculating
*/
List<String> savedEncodings = new ArrayList<String>();

/**
* All the Python imports and configurations needed to start using EfficientSAM.
*/
Expand All @@ -66,6 +73,8 @@ public class EfficientSamJ extends AbstractSamJ {
+ "" + System.lineSeparator()
+ "predictor = build_efficient_sam(encoder_patch_embed_dim=384,encoder_num_heads=6,checkpoint=r'%s',).eval()" + System.lineSeparator()
+ "task.update('created predictor')" + System.lineSeparator()
+ "encodings_map = {}" + System.lineSeparator()
+ "globals()['encodings_map'] = encodings_map" + System.lineSeparator()
+ "globals()['shared_memory'] = shared_memory" + System.lineSeparator()
+ "globals()['measure'] = measure" + System.lineSeparator()
+ "globals()['np'] = np" + System.lineSeparator()
Expand Down Expand Up @@ -226,7 +235,7 @@ protected void createEncodeImageScript() {
this.script += code;
this.script += ""
+ "task.update(str(im.shape))" + System.lineSeparator()
+ "aa = predictor.get_image_embeddings(im[None, ...])";
+ "predictor.get_image_embeddings(im[None, ...])";
}

@Override
Expand Down Expand Up @@ -424,4 +433,20 @@ protected <T extends RealType<T> & NativeType<T>> void createSHMArray(RandomAcce
shma = SharedMemoryArray.create(new long[] {dims[0], dims[1], dims[2]}, new FloatType(), false, false);
adaptImageToModel(imageToBeSent, shma.getSharedRAI());
}

@Override
public String persistEncodingScript(String encodingName) {
return "encodings_map['" + encodingName + "'] = predictor.encoded_images";
}

@Override
public String selectEncodingScript(String encodingName) {
return "predictor.encoded_images = encodings_map['" + encodingName + "']";

}

@Override
public String deleteEncodingScript(String encodingName) {
return "del encodings_map['" + encodingName + "']";
}
}
24 changes: 23 additions & 1 deletion src/main/java/ai/nets/samj/models/EfficientViTSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
*/
package ai.nets.samj.models;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.stream.Collectors;
Expand All @@ -40,7 +41,6 @@
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Cast;
import net.imglib2.util.Intervals;
import net.imglib2.view.IntervalView;
Expand All @@ -53,6 +53,11 @@
* @author Vladimir Ulman
*/
public class EfficientViTSamJ extends AbstractSamJ {

/**
* List of encodings that are cached to avoid recalculating
*/
List<String> savedEncodings = new ArrayList<String>();
/**
* Map that associates the key for each of the existing EfficientViTSAM models to its complete name
*/
Expand Down Expand Up @@ -95,6 +100,8 @@ public class EfficientViTSamJ extends AbstractSamJ {
+ "model.load_state_dict(weight)" + System.lineSeparator()
+ "predictor = EfficientViTSamPredictor(model)" + System.lineSeparator()
+ "task.update('created predictor')" + System.lineSeparator()
+ "encodings_map = {}" + System.lineSeparator()
+ "globals()['encodings_map'] = encodings_map" + System.lineSeparator()
+ "globals()['shared_memory'] = shared_memory" + System.lineSeparator()
+ "globals()['measure'] = measure" + System.lineSeparator()
+ "globals()['np'] = np" + System.lineSeparator()
Expand Down Expand Up @@ -499,4 +506,19 @@ protected void cellSAM(List<int[]> grid, boolean returnAll) {
// TODO Auto-generated method stub

}

@Override
public String persistEncodingScript(String encodingName) {
return "encodings_map['" + encodingName + "'] = predictor.features";
}

@Override
public String selectEncodingScript(String encodingName) {
return "predictor.features = encodings_map['" + encodingName + "']";
}

@Override
public String deleteEncodingScript(String encodingName) {
return "del encodings_map['" + encodingName + "']";
}
}

0 comments on commit 65ef7c4

Please sign in to comment.