Skip to content

Commit

Permalink
keep doing
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Dec 2, 2024
1 parent b84afa1 commit 54a9d6e
Show file tree
Hide file tree
Showing 5 changed files with 191 additions and 65 deletions.
5 changes: 3 additions & 2 deletions src/main/java/ai/nets/samj/communication/model/SAM2Tiny.java
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,9 @@ public <T extends RealType<T> & NativeType<T>> void setImage(final RandomAccessi
if (useThisLoggerForIt != null)
this.log = useThisLoggerForIt;
AbstractSamJ.DebugTextPrinter filteringLogger = text -> {
int idx = text.indexOf("contours_x");
if (idx > 0) this.log.info( text.substring(0,idx) );
int idx = text.indexOf("\"responseType\": \"COMPLETION\"");
int idxProgress = text.indexOf("\"message\": \"8f821f82-db6f-42a3-8500-794a5033114e\"");
if (idx > 0) this.log.info( text.substring(0,idx) + "\"responseType\": \"COMPLETION\"}");
else this.log.info( text );
};
if (this.samj == null)
Expand Down
7 changes: 7 additions & 0 deletions src/main/java/ai/nets/samj/communication/model/SAMModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import ai.nets.samj.annotation.Mask;
import ai.nets.samj.install.SamEnvManagerAbstract;
import ai.nets.samj.models.AbstractSamJ;
import ai.nets.samj.models.AbstractSamJ.BatchCallback;
import ai.nets.samj.ui.SAMJLogger;
import net.imglib2.Interval;
import net.imglib2.Localizable;
Expand Down Expand Up @@ -185,6 +186,12 @@ List<Mask> processBatchOfPrompts(List<int[]> points, List<Rectangle> rects, Rand
return samj.processBatchOfPrompts(points, rects, rai, !onlyBiggest);
}

public <T extends RealType<T> & NativeType<T>>
List<Mask> processBatchOfPrompts(List<int[]> points, List<Rectangle> rects, RandomAccessibleInterval<T> rai, BatchCallback callback)
throws IOException, RuntimeException, InterruptedException {
return samj.processBatchOfPrompts(points, rects, rai, !onlyBiggest, callback);
}

/**
* Get a 2D segmentation/annotation using two lists of points as the prompts.
* @param listOfPoints2D
Expand Down
51 changes: 40 additions & 11 deletions src/main/java/ai/nets/samj/gui/MainGUI.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package ai.nets.samj.gui;

import ai.nets.samj.annotation.Mask;
import ai.nets.samj.communication.model.EfficientSAM;
import ai.nets.samj.communication.model.EfficientViTSAML2;
import ai.nets.samj.communication.model.SAM2Large;
Expand All @@ -10,6 +11,7 @@
import ai.nets.samj.gui.ModelSelection.ModelSelectionListener;
import ai.nets.samj.gui.components.ModelDrawerPanel;
import ai.nets.samj.gui.components.ModelDrawerPanel.ModelDrawerPanelListener;
import ai.nets.samj.models.AbstractSamJ.BatchCallback;
import ai.nets.samj.ui.ConsumerInterface;
import ai.nets.samj.utils.Constants;
import net.imglib2.RandomAccessibleInterval;
Expand Down Expand Up @@ -37,6 +39,7 @@ public class MainGUI extends JFrame {
private ImageSelectionListener imageListener;
private ModelSelectionListener modelListener;
private ModelDrawerPanelListener modelDrawerListener;
private BatchCallback batchDrawerCallback;
private ConsumerInterface consumer;

private JCheckBox chkRoiManager = new JCheckBox("Add to RoiManager", true);
Expand Down Expand Up @@ -101,13 +104,7 @@ public MainGUI(List<SAMModel> modelList, ConsumerInterface consumer) {
chkInstant.addActionListener(e -> setInstantPromptsEnabled(this.chkInstant.isSelected()));
chkRoiManager.addActionListener(e -> consumer.enableAddingToRoiManager(chkRoiManager.isSelected()));
retunLargest.addActionListener(e -> cmbModels.getSelectedModel().setReturnOnlyBiggest(retunLargest.isSelected()));
btnBatchSAMize.addActionListener(e -> {
try {
batchSAMize();
} catch (IOException | RuntimeException | InterruptedException e1) {
e1.printStackTrace();
}
});
btnBatchSAMize.addActionListener(e -> batchSAMize());
stopProgressBtn.addActionListener(e -> {
// TODO stopProgress();
});
Expand Down Expand Up @@ -461,10 +458,12 @@ private void toggleDrawer() {
repaint();
}

private < T extends RealType< T > & NativeType< T > > void batchSAMize() throws IOException, RuntimeException, InterruptedException {
RandomAccessibleInterval<T> rai = null;
private < T extends RealType< T > & NativeType< T > > void batchSAMize() {
RandomAccessibleInterval<T> rai;
if (this.consumer.getFocusedImage() != this.cmbImages.getSelectedObject())
rai = this.consumer.getFocusedImageAsRai();
else
rai = null;
List<int[]> pointPrompts = this.consumer.getPointRoisOnFocusImage();
List<Rectangle> rectPrompts = this.consumer.getRectRoisOnFocusImage();
if (pointPrompts.size() == 0 && rectPrompts.size() == 0 && rai == null){
Expand All @@ -475,8 +474,14 @@ private < T extends RealType< T > & NativeType< T > > void batchSAMize() throws
return;
}
this.stopProgressBtn.setEnabled(true);
this.consumer.addPolygonsFromGUI(this.cmbModels.getSelectedModel().processBatchOfPrompts(pointPrompts, rectPrompts, rai));
this.stopProgressBtn.setEnabled(false);
new Thread(() -> {
try {
cmbModels.getSelectedModel().processBatchOfPrompts(pointPrompts, rectPrompts, rai, batchDrawerCallback);
} catch (IOException | RuntimeException | InterruptedException e) {
e.printStackTrace();
}
SwingUtilities.invokeLater(() -> stopProgressBtn.setEnabled(false));
}).start();;
pointPrompts.stream().forEach(pp -> consumer.deletePointRoi(pp));
rectPrompts.stream().forEach(pp -> consumer.deleteRectRoi(pp));
}
Expand Down Expand Up @@ -527,6 +532,30 @@ public void setGUIEnabled(boolean enabled) {
}
}
};

batchDrawerCallback = new BatchCallback() {
private int nRois;

@Override
public void setTotalNumberOfRois(int nRois) {
this.nRois = nRois;
SwingUtilities.invokeLater(() -> {
batchProgress.setValue(0);
});
}

@Override
public void updateProgress(int n) {
SwingUtilities.invokeLater(() -> batchProgress.setValue((int) Math.round(100 * n / (double) nRois) ));
}

@Override
public void drawRoi(List<Mask> masks) {
// TODO Auto-generated method stub

}

};
}

public static void main(String[] args) {
Expand Down
71 changes: 67 additions & 4 deletions src/main/java/ai/nets/samj/models/AbstractSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,20 @@ public abstract class AbstractSamJ implements AutoCloseable {
protected static long ENCODE_MARGIN = 64;

protected static int MAX_IMG_SIZE = 2024;

protected static String UPDATE_ID_N_CONTOURS = UUID.randomUUID().toString();

protected static String UPDATE_ID_CONTOUR = UUID.randomUUID().toString();

/** Essentially, a syntactic-shortcut for a String consumer */
public interface BatchCallback {

void setTotalNumberOfRois(int nRois);

void updateProgress(int n);

void drawRoi(List<Mask> masks);
}

/** Essentially, a syntactic-shortcut for a String consumer */
public interface DebugTextPrinter { void printText(String text); }
Expand Down Expand Up @@ -154,6 +168,8 @@ public interface DebugTextPrinter { void printText(String text); }
*/
protected boolean imageSmall = true;

private int nRoisProcessed;

/**
* List of encodings that are cached to avoid recalculating
*/
Expand Down Expand Up @@ -361,17 +377,22 @@ else if (cropSize.length == 3 && cropSize[2] != 3)

}

private List<Mask> processAndRetrieveContours(HashMap<String, Object> inputs, Object callback)
private List<Mask> processAndRetrieveContours(HashMap<String, Object> inputs, BatchCallback callback)
throws IOException, RuntimeException, InterruptedException {
Map<String, Object> results = null;
try {
Task task = python.task(script, inputs);
nRoisProcessed = 1;
task.listen(event -> {
switch (event.responseType) {
case UPDATE:
if (!task.message.equals("new input"))
if (!task.message.equals(UPDATE_ID_CONTOUR) && !task.message.equals(UPDATE_ID_N_CONTOURS))
break;
Object numer = task.outputs;
else if (task.message.equals(UPDATE_ID_CONTOUR)) {
callback.updateProgress(nRoisProcessed ++);
} else if (task.message.equals(UPDATE_ID_N_CONTOURS)) {
callback.setTotalNumberOfRois(Integer.parseInt((String) task.outputs.get("n")));
}
break;
default:
break;
Expand All @@ -392,6 +413,7 @@ else if (task.outputs.get("contours_y") == null)
throw new RuntimeException();
else if (task.outputs.get("rle") == null)
throw new RuntimeException();
callback.updateProgress(Integer.parseInt((String) task.outputs.get("n")));
results = task.outputs;
} catch (IOException | InterruptedException | RuntimeException e) {
try {
Expand Down Expand Up @@ -461,6 +483,47 @@ else if (task.outputs.get("rle") == null)
return masks;
}

public <T extends RealType<T> & NativeType<T>>
List<Mask> processBatchOfPrompts(List<int[]> points, List<Rectangle> rects, RandomAccessibleInterval<T> rai, BatchCallback callback)
throws IOException, RuntimeException, InterruptedException {
return processBatchOfPrompts(points, rects, rai, true, callback);
}

public <T extends RealType<T> & NativeType<T>>
List<Mask> processBatchOfPrompts(List<int[]> pointsList, List<Rectangle> rects,
RandomAccessibleInterval<T> rai, boolean returnAll, BatchCallback callback)
throws IOException, RuntimeException, InterruptedException {
if ((pointsList == null || pointsList.size() == 0) && (rects == null || rects.size() == 0) && (rai == null))
return new ArrayList<Mask>();
checkPrompts(pointsList, rects, rai);

// TODO adapt to reencoding for big images, ideally it should process points close together together
pointsList = adaptPointPrompts(pointsList);
// TODO adapt rect prompts
this.script = "";
SharedMemoryArray maskShma = null;
if (rai != null)
maskShma = SharedMemoryArray.createSHMAFromRAI(rai, false, false);

try {
HashMap<String, Object> inputs = new HashMap<String, Object>();
inputs.put("point_prompts", pointsList == null ? new ArrayList<int[]>() : pointsList);
List<int[]> rectPrompts = new ArrayList<int[]>();
if (rects != null && rects.size() > 0)
rectPrompts = rects.stream().map(rr -> new int[] {rr.x, rr.y, rr.x + rr.width, rr.y + rr.height})
.collect(Collectors.toList());
inputs.put("rect_prompts", rectPrompts);
processPromptsBatchWithSAM(maskShma, returnAll);
printScript(script, "Batch of prompts inference");
List<Mask> polys = processAndRetrieveContours(inputs, callback);
return polys;
} catch (IOException | RuntimeException | InterruptedException ex) {
if (maskShma != null)
maskShma.close();
throw ex;
}
}

public <T extends RealType<T> & NativeType<T>>
List<Mask> processBatchOfPrompts(List<int[]> points, List<Rectangle> rects, RandomAccessibleInterval<T> rai)
throws IOException, RuntimeException, InterruptedException {
Expand Down Expand Up @@ -492,7 +555,7 @@ List<Mask> processBatchOfPrompts(List<int[]> pointsList, List<Rectangle> rects,
inputs.put("rect_prompts", rectPrompts);
processPromptsBatchWithSAM(maskShma, returnAll);
printScript(script, "Batch of prompts inference");
List<Mask> polys = processAndRetrieveContours(inputs, null);
List<Mask> polys = processAndRetrieveContours(inputs);
recalculatePolys(polys, encodeCoords);
return polys;
} catch (IOException | RuntimeException | InterruptedException ex) {
Expand Down
Loading

0 comments on commit 54a9d6e

Please sign in to comment.