diff --git a/src/main/java/ai/nets/samj/ij/ui/Consumer.java b/src/main/java/ai/nets/samj/ij/ui/Consumer.java index 9976617..30275f7 100644 --- a/src/main/java/ai/nets/samj/ij/ui/Consumer.java +++ b/src/main/java/ai/nets/samj/ij/ui/Consumer.java @@ -1,8 +1,13 @@ package ai.nets.samj.ij.ui; +import java.awt.Color; import java.awt.Polygon; +import java.awt.Rectangle; +import java.awt.event.MouseEvent; +import java.awt.event.MouseListener; import java.util.ArrayList; import java.util.Arrays; +import java.util.Iterator; import java.util.List; import java.util.Stack; import java.util.stream.Collectors; @@ -11,21 +16,34 @@ import ai.nets.samj.annotation.Mask; import ai.nets.samj.gui.components.ComboBoxItem; +import ai.nets.samj.models.AbstractSamJ; import ai.nets.samj.ui.ConsumerInterface; import ij.IJ; import ij.ImagePlus; import ij.Prefs; +import ij.WindowManager; import ij.gui.ImageCanvas; import ij.gui.ImageWindow; +import ij.gui.Overlay; import ij.gui.PolygonRoi; import ij.gui.Roi; +import ij.gui.Toolbar; +import ij.plugin.OverlayLabels; import ij.plugin.frame.RoiManager; +import io.bioimage.modelrunner.system.PlatformDetection; +import net.imglib2.FinalInterval; +import net.imglib2.Interval; import net.imglib2.Localizable; +import net.imglib2.Point; import net.imglib2.RandomAccessibleInterval; import net.imglib2.img.display.imagej.ImageJFunctions; import net.imglib2.type.numeric.integer.UnsignedShortType; -public class Consumer extends ConsumerInterface { +/** + * + * @author Carlos Garcia Lopez de Haro + */ +public class Consumer extends ConsumerInterface implements MouseListener { /** * The image being processed */ @@ -46,6 +64,10 @@ public class Consumer extends ConsumerInterface { * Whether to add the ROIs created to the ROI manager or not */ private boolean isAddingToRoiManager = true; + /** + * Counter of the ROIs created + */ + private int promptsCreatedCnt = 0; /** * A list to save several ROIs that are being created for the same prompt. * Whenever the prompt is sent to the model, this list is emptied @@ -101,9 +123,15 @@ public class Consumer extends ConsumerInterface { private boolean registered = false; @Override + /** + * {@inheritDoc} + * + * GEt the list of open images in ImageJ + */ public List getListOfOpenImages() { - // TODO Auto-generated method stub - return null; + return Arrays.stream(WindowManager.getImageTitles()) + .map(title -> new IJComboBoxItem(WindowManager.getImage(title).getID(), (Object) WindowManager.getImage(title))) + .collect(Collectors.toList()); } @Override @@ -132,7 +160,6 @@ public void exportImageLabeling() { public void activateListeners() { if (registered) return; SwingUtilities.invokeLater(() -> { - IJ.addEventListener(this); activeCanvas.removeKeyListener(IJ.getInstance()); activeWindow.removeKeyListener(IJ.getInstance()); activeCanvas.addMouseListener(this); @@ -147,7 +174,6 @@ public void activateListeners() { public void deactivateListeners() { if (!registered) return; SwingUtilities.invokeLater(() -> { - IJ.removeEventListener(this); activeCanvas.removeMouseListener(this); activeCanvas.removeKeyListener(this); activeWindow.removeWindowListener(this); @@ -177,4 +203,211 @@ private RoiManager startRoiManager() { return roiManager; } + @Override + public void mouseReleased(MouseEvent e) { + if (activeImage.getRoi() == null) + return; + if (Toolbar.getToolName().equals("rectangle")) { + annotateRect(); + } else if (Toolbar.getToolName().equals("point") || Toolbar.getToolName().equals("multipoint")) { + annotatePoints(e); + } else if (Toolbar.getToolName().equals("freeline")) { + annotateBrush(e); + } + if (!isCollectingPoints) activeImage.deleteRoi(); + } + + private void annotateRect() { + final Roi roi = activeImage.getRoi(); + final Rectangle rectBounds = roi.getBounds(); + final Interval rectInterval = new FinalInterval( + new long[] { rectBounds.x, rectBounds.y }, + new long[] { rectBounds.x+rectBounds.width-1, rectBounds.y+rectBounds.height-1 } ); + submitRectPrompt(rectInterval); + } + + private void submitRectPrompt(Interval rectInterval) { + try { + addToRoiManager(this.selectedModel.fetch2dSegmentation(rectInterval), "rect"); + } catch (Exception ex) { + ex.printStackTrace();; + } + } + + private void annotatePoints(MouseEvent e) { + final Roi roi = activeImage.getRoi(); + // TODO think what to do with negative points + if (e.isControlDown() && e.isAltDown() && false) { + roi.setFillColor(Color.red); + //add point to the list only + isCollectingPoints = true; + Iterator iterator = roi.iterator(); + java.awt.Point p = iterator.next(); + while (iterator.hasNext()) p = iterator.next(); + collecteNegPoints.add( new Point(p.x,p.y) ); //NB: add ImgLib2 Point + //TODO log.info("Image window: collecting points..., already we have: "+collectedPoints.size()); + } else if ((e.isControlDown() && !PlatformDetection.isMacOS()) || (e.isMetaDown() && PlatformDetection.isMacOS())) { + //add point to the list only + isCollectingPoints = true; + Iterator iterator = roi.iterator(); + java.awt.Point p = iterator.next(); + while (iterator.hasNext()) p = iterator.next(); + collectedPoints.add( new Point(p.x,p.y) ); //NB: add ImgLib2 Point + //TODO log.info("Image window: collecting points..., already we have: "+collectedPoints.size()); + } else { + isCollectingPoints = false; + //collect this last one + Iterator iterator = roi.iterator(); + java.awt.Point p = iterator.next(); + while (iterator.hasNext()) p = iterator.next(); + collectedPoints.add( new Point(p.x,p.y) ); + submitAndClearPoints(); + } + } + + /** + * Send the point prompts to SAM and clear the lists collecting them + */ + private void submitAndClearPoints() { + if (this.selectedModel == null) return; + if (collectedPoints.size() == 0) return; + + //TODO log.info("Image window: Processing now points, this count: "+collectedPoints.size()); + isCollectingPoints = false; + activeImage.deleteRoi(); + Rectangle zoomedRectangle = this.activeCanvas.getSrcRect(); + try { + if (activeImage.getWidth() * activeImage.getHeight() > Math.pow(AbstractSamJ.MAX_ENCODED_AREA_RS, 2) + || activeImage.getWidth() > AbstractSamJ.MAX_ENCODED_SIDE || activeImage.getHeight() > AbstractSamJ.MAX_ENCODED_SIDE) + addToRoiManager(selectedModel.fetch2dSegmentation(collectedPoints, collecteNegPoints, zoomedRectangle), + (collectedPoints.size() > 1 ? "points" : "point")); + else + addToRoiManager(selectedModel.fetch2dSegmentation(collectedPoints, collecteNegPoints), + (collectedPoints.size() > 1 ? "points" : "point")); + } catch (Exception ex) { + ex.printStackTrace(); + } + collectedPoints = new ArrayList(); + collecteNegPoints = new ArrayList(); + temporalROIs = new ArrayList(); + temporalNegROIs = new ArrayList(); + } + + private void annotateBrush(MouseEvent e) { + final Roi roi = activeImage.getRoi(); + // TODO this is not a real mask prompt, it is just taking + // TODO all the points in a line and using them, modify it for a true mask + if (e.isControlDown() && e.isAltDown()) { + temporalNegROIs.add(roi); + roi.setStrokeColor(Color.red); + isCollectingPoints = true; + Iterator it = roi.iterator(); + while (it.hasNext()) { + java.awt.Point p = it.next(); + collecteNegPoints.add(new Point(p.x,p.y)); + } + addTemporalRois(); + } else if (e.isControlDown()) { + temporalROIs.add(roi); + isCollectingPoints = true; + Iterator it = roi.iterator(); + while (it.hasNext()) { + java.awt.Point p = it.next(); + collectedPoints.add(new Point(p.x,p.y)); + } + addTemporalRois(); + } else { + isCollectingPoints = false; + Rectangle rect = roi.getBounds(); + if (rect.height == 1) { + for (int i = 0; i < rect.width; i ++) { + collectedPoints.add(new Point(rect.x + i, rect.y)); + } + } else if (rect.width == 1) { + for (int i = 0; i < rect.height; i ++) { + collectedPoints.add(new Point(rect.x, rect.y + i)); + } + } else { + Iterator it = roi.iterator(); + while (it.hasNext()) { + java.awt.Point p = it.next(); + collectedPoints.add(new Point(p.x,p.y)); + } + } + // TODO move this logic to SAMJ into the masks option + if (collectedPoints.size() > 1 && collectedPoints.size() < 6) + collectedPoints = Arrays.asList(new Localizable[] {collectedPoints.get(1)}); + else if (collectedPoints.size() > 1 && collectedPoints.size() < 50) { + List newCollectedPoints = new ArrayList(); + while (newCollectedPoints.size() == 0) { + for (Localizable pp : collectedPoints) { + if (Math.random() < 0.2) newCollectedPoints.add(pp); + } + } + collectedPoints = newCollectedPoints; + } else if (collectedPoints.size() > 50) { + List newCollectedPoints = new ArrayList(); + while (newCollectedPoints.size() < 10) { + for (Localizable pp : collectedPoints) { + if (Math.random() < Math.min(0.1, 50.0 / collectedPoints.size())) newCollectedPoints.add(pp); + } + } + collectedPoints = newCollectedPoints; + } + submitAndClearPoints(); + } + } + + private void addTemporalRois() { + //Overlay overlay = activeCanvas.getOverlay(); + Overlay overlay = OverlayLabels.createOverlay(); + for (Roi rr : this.roiManager.getRoisAsArray()) + overlay.add(rr); + this.temporalROIs.stream().forEach(r -> overlay.add(r)); + this.temporalNegROIs.stream().forEach(r -> overlay.add(r)); + activeCanvas.setShowAllList(overlay); + this.activeImage.draw(); + } + + /** + * Add a single polygon to the ROI manager + * @param pRoi + */ + public void addToRoiManager(final PolygonRoi pRoi ) { + if (isAddingToRoiManager) roiManager.addRoi(pRoi); + } + + /** + * Add the new roi to the ROI manager + * @param polys + * list of polygons that will be converted into polygon ROIs and sent to the ROI manager + * @param promptShape + * String giving information about which prompt was used to generate the ROI + */ + void addToRoiManager(final List polys, final String promptShape) { + this.redoStack.clear(); + this.redoAnnotatedMask.clear(); + promptsCreatedCnt++; + int resNo = 1; + List undoRois = new ArrayList(); + for (Mask m : polys) { + final PolygonRoi pRoi = new PolygonRoi(m.getContour(), PolygonRoi.POLYGON); + pRoi.setName(promptsCreatedCnt + "." + (resNo ++) + "_"+promptShape + "_" + this.selectedModel.getName()); + this.addToRoiManager(pRoi); + undoRois.add(pRoi); + } + this.undoStack.push(undoRois); + this.annotatedMask.push(polys); + } + + // ===== unused events ===== + @Override + public void mouseEntered(MouseEvent e) {} + @Override + public void mouseClicked(MouseEvent e) {} + @Override + public void mousePressed(MouseEvent e) {} + @Override + public void mouseExited(MouseEvent e) {} + }