Skip to content

Commit

Permalink
try to get masks in batchsamize
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Nov 30, 2024
1 parent d501ae7 commit 5dca5cb
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 41 deletions.
40 changes: 14 additions & 26 deletions src/main/java/ai/nets/samj/gui/MainGUI.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,17 @@
import ai.nets.samj.communication.model.SAM2Tiny;
import ai.nets.samj.communication.model.SAMModel;
import ai.nets.samj.gui.ImageSelection.ImageSelectionListener;
import ai.nets.samj.gui.ModelSelection.ModelSelctionListener;
import ai.nets.samj.gui.ModelSelection.ModelSelectionListener;
import ai.nets.samj.gui.components.ComboBoxItem;
import ai.nets.samj.gui.components.ModelDrawerPanel;
import ai.nets.samj.gui.components.ModelDrawerPanel.ModelDrawerPanelListener;
import ai.nets.samj.ui.ConsumerInterface;
import ai.nets.samj.utils.Constants;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.IntegerType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.util.Cast;

import javax.swing.*;
Expand All @@ -35,15 +37,14 @@ public class MainGUI extends JFrame {
private boolean isDrawerOpen = false;
private final List<SAMModel> modelList;
private ImageSelectionListener imageListener;
private ModelSelctionListener modelListener;
private ModelSelectionListener modelListener;
private ModelDrawerPanelListener modelDrawerListener;
private ConsumerInterface consumer;

private JCheckBox chkRoiManager = new JCheckBox("Add to RoiManager", true);
private JCheckBox retunLargest = new JCheckBox("Only return largest ROI", true);
private JSwitchButton chkInstant = new JSwitchButton("LIVE", "OFF");
private JButton go = new JButton("Go");
private JComboBox<ComboBoxItem> cmbPresets = new JComboBox<ComboBoxItem>();
private JButton btnBatchSAMize = new JButton("Batch SAMize");
private JButton close = new JButton("Close");
private JButton help = new JButton("Help");
Expand Down Expand Up @@ -353,13 +354,7 @@ private JPanel createSecondComponent() {
gbc0.gridy = 1;
gbc0.anchor = GridBagConstraints.CENTER;
gbc0.fill = GridBagConstraints.BOTH;
gbc0.weighty = 0.4;
card2.add(cmbPresets, gbc0);

gbc0.gridy = 2;
gbc0.anchor = GridBagConstraints.CENTER;
gbc0.fill = GridBagConstraints.BOTH;
gbc0.weighty = 0.4;
gbc0.weighty = 0.8;
card2.add(btnBatchSAMize, gbc0);

cardPanel.add(card1, MANUAL_STR);
Expand All @@ -372,7 +367,8 @@ private JPanel createSecondComponent() {
});

radioButton2.addActionListener(e -> {
updatePresetsCard();
CardLayout cl = (CardLayout) (cardPanel.getLayout());
cl.show(cardPanel, PRESET_STR);
});

GridBagConstraints gbc = new GridBagConstraints();
Expand Down Expand Up @@ -440,22 +436,14 @@ private void toggleDrawer() {
}

private < T extends RealType< T > & NativeType< T > > void batchSAMize() throws IOException, RuntimeException, InterruptedException {
RandomAccessibleInterval<T> rai = Cast.unchecked(((ComboBoxItem) cmbPresets.getSelectedItem()).getImageAsImgLib2());
RandomAccessibleInterval<T> rai = this.consumer.getFocusedImageAsRai();
List<Object> prompts = new ArrayList<Object>();
if (prompts.size() == 0 && !(rai.getType() instanceof IntegerType)){
// TODO add label that is displayed when there are no prompts selected
return;
}
this.consumer.addPolygonsFromGUI(this.cmbModels.getSelectedModel().fetch2dSegmentationFromMask(rai));
}

private void updatePresetsCard() {
CardLayout cl = (CardLayout) (cardPanel.getLayout());
cl.show(cardPanel, PRESET_STR);

List<ComboBoxItem> openSeqs = consumer.getListOfOpenImages();
ComboBoxItem[] objects = new ComboBoxItem[openSeqs.size()];
for (int i = 0; i < objects.length; i ++) objects[i] = openSeqs.get(i);
DefaultComboBoxModel<ComboBoxItem> comboBoxModel = new DefaultComboBoxModel<ComboBoxItem>(objects);
this.cmbPresets.setModel(comboBoxModel);

btnBatchSAMize.setEnabled(openSeqs.size() != 0);
}

private void createListeners() {
imageListener = new ImageSelectionListener() {
Expand All @@ -472,7 +460,7 @@ public void imageActionsOnImageChanged() {
go.setEnabled(cmbImages.getSelectedObject() != null);
}
};
modelListener = new ModelSelctionListener() {
modelListener = new ModelSelectionListener() {

@Override
public void changeDrawerPanel() {
Expand Down
8 changes: 4 additions & 4 deletions src/main/java/ai/nets/samj/gui/ModelSelection.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ public class ModelSelection extends ComboBoxButtonComp<String> implements PopupM

private SAMModel selected;

private final ModelSelctionListener listener;
private final ModelSelectionListener listener;

private final List<SAMModel> models;


private static final long serialVersionUID = 2478618937640492286L;

private ModelSelection(List<SAMModel> models, ModelSelctionListener listener) {
private ModelSelection(List<SAMModel> models, ModelSelectionListener listener) {
super(new JComboBox<String>());
this.listener = listener;
this.models = models;
Expand All @@ -37,7 +37,7 @@ private ModelSelection(List<SAMModel> models, ModelSelctionListener listener) {
selected = models.get(cmbBox.getSelectedIndex());
}

protected static ModelSelection create(List<SAMModel> models, ModelSelctionListener listener) {
protected static ModelSelection create(List<SAMModel> models, ModelSelectionListener listener) {
return new ModelSelection(models, listener);
}

Expand Down Expand Up @@ -89,7 +89,7 @@ public void popupMenuWillBecomeInvisible(PopupMenuEvent e) {
public void popupMenuCanceled(PopupMenuEvent e) {
}

public interface ModelSelctionListener {
public interface ModelSelectionListener {

void changeDrawerPanel();

Expand Down
6 changes: 4 additions & 2 deletions src/main/java/ai/nets/samj/models/AbstractSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -683,9 +683,11 @@ List<Mask> processMask(RandomAccessibleInterval<T> img)
List<Mask> processMask(RandomAccessibleInterval<T> img, boolean returnAll)
throws IOException, RuntimeException, InterruptedException {
long[] dims = img.dimensionsAsLongArray();
if (dims.length == 2 && dims[1] == this.shma.getOriginalShape()[1] && dims[0] == this.shma.getOriginalShape()[0]) {
if ((dims.length == 2 || (dims.length == 3 && dims[2] == 1))
&& dims[1] == this.shma.getOriginalShape()[0] && dims[0] == this.shma.getOriginalShape()[1]) {
img = Views.permute(img, 0, 1);
} else if (dims.length != 2 && dims[0] != this.shma.getOriginalShape()[1] && dims[1] != this.shma.getOriginalShape()[0]) {
} else if (dims[0] != this.shma.getOriginalShape()[0] && dims[1] != this.shma.getOriginalShape()[1]
|| (dims.length == 3 && dims[2] != 1) || dims.length > 3) {
throw new IllegalArgumentException("The provided mask should be a 2d image with just one channel of width "
+ this.shma.getOriginalShape()[1] + " and height " + this.shma.getOriginalShape()[0]);
}
Expand Down
19 changes: 10 additions & 9 deletions src/main/java/ai/nets/samj/models/Sam2.java
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ public class Sam2 extends AbstractSamJ {
+ "globals()['measure'] = measure" + System.lineSeparator()
+ "globals()['np'] = np" + System.lineSeparator()
+ "globals()['torch'] = torch" + System.lineSeparator()
+ "globals()['label'] = label" + System.lineSeparator()
+ "globals()['binary_fill_holes'] = binary_fill_holes" + System.lineSeparator()
+ "globals()['predictor'] = predictor" + System.lineSeparator();
/**
Expand Down Expand Up @@ -331,31 +332,31 @@ protected void processMasksWithSam(SharedMemoryArray shmArr, boolean returnAll)
+ " n_points = np.min([3, inds[0].shape[0]])" + System.lineSeparator()
+ " random_positions = np.random.choice(inds[0].shape[0], n_points, replace=False)" + System.lineSeparator()
+ " for pp in range(n_points):" + System.lineSeparator()
+ " point_prompts += [np.array([inds[0][random_posiitons[pp]], inds[1][random_posiitons[pp]]])]" + System.lineSeparator()
+ " point_labels += [np.array(n_feat)]" + System.lineSeparator()
+ " point_prompts += [[inds[0][random_positions[pp]], inds[1][random_positions[pp]]]]" + System.lineSeparator()
+ " point_labels += [n_feat]" + System.lineSeparator()
+ "" + System.lineSeparator()
+ "mask, _, _ = predictor.predict_batch(" + System.lineSeparator()
+ " point_coords_batch=point_prompts," + System.lineSeparator()
+ " point_labels_batch=point_labels," + System.lineSeparator()
+ "mask, _, _ = predictor.predict(" + System.lineSeparator()
+ " point_coords=point_prompts," + System.lineSeparator()
+ " point_labels=point_labels," + System.lineSeparator()
+ " multimask_output=False," + System.lineSeparator()
+ " box_batch=None,)" + System.lineSeparator()
+ " box=None,)" + System.lineSeparator()
+ "contours_x = []" + System.lineSeparator()
+ "contours_y = []" + System.lineSeparator()
+ "rle_masks = []" + System.lineSeparator()
+ "" + System.lineSeparator()
+ "" + System.lineSeparator()
+ "" + System.lineSeparator()
+ "for b in range(num_features):"
+ "for b in range(num_features):" + System.lineSeparator()
+ " mm = mask[b]" + System.lineSeparator()
+ " c_x, c_y, r_m = get_polygons_from_binary_mask(mask[0], only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator()
+ (this.isIJROIManager ? " mm += mm[:-1, :-1]" : "") + System.lineSeparator()
+ " c_x, c_y, r_m = get_polygons_from_binary_mask(mm, only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator()
+ " contours_x += c_x" + System.lineSeparator()
+ " contours_y += c_Y" + System.lineSeparator()
+ " rle_masks += r_m" + System.lineSeparator()
+ "" + System.lineSeparator()
+ "" + System.lineSeparator()
// TODO remove + "import matplotlib.pyplot as plt" + System.lineSeparator()
// TODO remove + "plt.imsave('/tmp/aa.jpg', mask[0], cmap='gray')" + System.lineSeparator()
+ (this.isIJROIManager ? "mask[0, 1:, 1:] += mask[0, :-1, :-1]" : "") + System.lineSeparator()
//+ (this.isIJROIManager ? "mask[0, :, 1:] += mask[0, :, :-1]" : "") + System.lineSeparator()
//+ "np.save('/home/carlos/git/aa.npy', mask)" + System.lineSeparator()
+ "contours_x, contours_y, rle_masks = get_polygons_from_binary_mask(mask[0], only_biggest=" + (!returnAll ? "True" : "False") + ")" + System.lineSeparator()
Expand Down
12 changes: 12 additions & 0 deletions src/main/java/ai/nets/samj/ui/ConsumerInterface.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,15 @@
package ai.nets.samj.ui;

import java.awt.Polygon;
import java.awt.Rectangle;
import java.util.List;

import ai.nets.samj.annotation.Mask;
import ai.nets.samj.communication.model.SAMModel;
import ai.nets.samj.gui.components.ComboBoxItem;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;

/**
* Interface to be implemented by the imaging software that wants to use the default SAMJ UI.
Expand Down Expand Up @@ -65,6 +69,14 @@ public abstract class ConsumerInterface {
* numbered from 1.
*/
public abstract void exportImageLabeling();

public abstract Object getFocusedImage();

public abstract < T extends RealType< T > & NativeType< T > > RandomAccessibleInterval<T> getFocusedImageAsRai();

public abstract List<long[]> getPointRoisOnFocusImage();

public abstract List<Rectangle> getRectRoisOnFocusImage();

public abstract void addPolygonsFromGUI(List<Mask> masks);

Expand Down

0 comments on commit 5dca5cb

Please sign in to comment.