From 9936be39aecfeac6960035b02ac4c0aaea2e2141 Mon Sep 17 00:00:00 2001 From: carlosuc3m <100329787@alumnos.uc3m.es> Date: Mon, 2 Dec 2024 14:55:05 +0100 Subject: [PATCH] keep iterating --- src/main/java/ai/nets/samj/gui/MainGUI.java | 36 ++++- .../ai/nets/samj/gui/ResizableButton.java | 132 ++++++++++++++++++ .../gui/components/ProgressBarAndButton.java | 132 ------------------ .../ai/nets/samj/models/AbstractSamJ.java | 57 +++++++- src/main/java/ai/nets/samj/models/Sam2.java | 3 + 5 files changed, 220 insertions(+), 140 deletions(-) create mode 100644 src/main/java/ai/nets/samj/gui/ResizableButton.java delete mode 100644 src/main/java/ai/nets/samj/gui/components/ProgressBarAndButton.java diff --git a/src/main/java/ai/nets/samj/gui/MainGUI.java b/src/main/java/ai/nets/samj/gui/MainGUI.java index afd65d6..ccafb95 100644 --- a/src/main/java/ai/nets/samj/gui/MainGUI.java +++ b/src/main/java/ai/nets/samj/gui/MainGUI.java @@ -10,7 +10,6 @@ import ai.nets.samj.gui.ModelSelection.ModelSelectionListener; import ai.nets.samj.gui.components.ModelDrawerPanel; import ai.nets.samj.gui.components.ModelDrawerPanel.ModelDrawerPanelListener; -import ai.nets.samj.gui.components.ProgressBarAndButton; import ai.nets.samj.ui.ConsumerInterface; import ai.nets.samj.utils.Constants; import net.imglib2.RandomAccessibleInterval; @@ -50,7 +49,8 @@ public class MainGUI extends JFrame { private JButton export = new JButton("Export..."); private JRadioButton radioButton1; private JRadioButton radioButton2; - private ProgressBarAndButton batchProgress = new ProgressBarAndButton("Stop"); + private JProgressBar batchProgress = new JProgressBar(); + private ResizableButton stopProgressBtn = new ResizableButton("■", 10, 2, 2); private final ModelSelection cmbModels; private final ImageSelection cmbImages; private ModelDrawerPanel drawerPanel; @@ -108,6 +108,9 @@ public MainGUI(List modelList, ConsumerInterface consumer) { e1.printStackTrace(); } }); + stopProgressBtn.addActionListener(e -> { + // TODO stopProgress(); + }); close.addActionListener(e -> dispose()); help.addActionListener(e -> consumer.exportImageLabeling()); @@ -181,6 +184,9 @@ private void setTwoThirdsEnabled(boolean enabled) { this.export.setEnabled(enabled); this.radioButton1.setEnabled(enabled); this.radioButton2.setEnabled(enabled); + this.batchProgress.setEnabled(enabled); + if (!enabled) + this.stopProgressBtn.setEnabled(enabled); } private void loadModel() { @@ -346,7 +352,7 @@ private JPanel createSecondComponent() { gbc0.gridy = 0; gbc0.anchor = GridBagConstraints.NORTH; gbc0.fill = GridBagConstraints.NONE; - gbc0.weighty = 0.2; + gbc0.weighty = 0.1; gbc0.insets = new Insets(0, 2, 5, 2); gbc0.weightx = 1; card2.add(new JLabel(ROIM_STR), gbc0); @@ -354,13 +360,27 @@ private JPanel createSecondComponent() { gbc0.gridy = 1; gbc0.anchor = GridBagConstraints.CENTER; gbc0.fill = GridBagConstraints.BOTH; - gbc0.weighty = 0.6; + gbc0.weighty = 0.8; card2.add(btnBatchSAMize, gbc0); gbc0.gridy = 2; - gbc0.weighty = 0.2; - gbc0.fill = GridBagConstraints.HORIZONTAL; - card2.add(batchProgress, gbc0); + gbc0.weighty = 0.1; + gbc0.anchor = GridBagConstraints.CENTER; + gbc0.fill = GridBagConstraints.BOTH; + JPanel wrapper = new JPanel(new GridBagLayout()); + GridBagConstraints gbc1 = new GridBagConstraints(); + gbc1.insets = new Insets(0, 0, 0, 0); + gbc1.gridy = 0; + gbc1.gridx = 0; + gbc1.anchor = GridBagConstraints.CENTER; + gbc1.fill = GridBagConstraints.BOTH; + gbc1.weighty = 1; + gbc1.weightx = 0.9; + wrapper.add(this.batchProgress, gbc1); + gbc1.gridx = 1; + gbc1.weightx = 0.1; + wrapper.add(stopProgressBtn, gbc1); + card2.add(wrapper, gbc0); cardPanel.add(card1, MANUAL_STR); cardPanel.add(card2, PRESET_STR); @@ -451,7 +471,9 @@ private < T extends RealType< T > & NativeType< T > > void batchSAMize() throws // TODO add label that is displayed when there are no prompts selected return; } + this.stopProgressBtn.setEnabled(true); this.consumer.addPolygonsFromGUI(this.cmbModels.getSelectedModel().processBatchOfPrompts(pointPrompts, rectPrompts, rai)); + this.stopProgressBtn.setEnabled(false); pointPrompts.stream().forEach(pp -> consumer.deletePointRoi(pp)); rectPrompts.stream().forEach(pp -> consumer.deleteRectRoi(pp)); } diff --git a/src/main/java/ai/nets/samj/gui/ResizableButton.java b/src/main/java/ai/nets/samj/gui/ResizableButton.java new file mode 100644 index 0000000..d02988f --- /dev/null +++ b/src/main/java/ai/nets/samj/gui/ResizableButton.java @@ -0,0 +1,132 @@ +/*- + * #%L + * Library to call models of the family of SAM (Segment Anything Model) from Java + * %% + * Copyright (C) 2024 SAMJ developers. + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +package ai.nets.samj.gui; + +import java.awt.Font; +import java.awt.FontMetrics; +import java.awt.Insets; +import java.awt.event.ComponentAdapter; +import java.awt.event.ComponentEvent; + +import javax.swing.JButton; +import javax.swing.JLabel; +import javax.swing.SwingConstants; +/** + * TODO make it resizable + * This class is a normal button whose font is resized when the component is resized + * @author Carlos Garcia + */ +public class ResizableButton extends JButton { + /** + * Serial version unique identifier + */ + private static final long serialVersionUID = -958367053852506146L; + /** + * Label containing the text that the button displays when it is not pressed + */ + private JLabel textLabel; + /** + * Label containing the loading animation to be displayed while the button is pressed + */ + private int fontSize; + + private int horizontalInset; + private int verticalInset; + + /** + * Constructor. Creates a button that has an icon inside. The icon changes when pressed. + * @param text + * the text inside the button + * @param filePath + * the path to the file that contains the image that is going to be used + * @param filename + * the name of the file that is going to be used + * @param animationSize + * size of the side of the squared animation inside the button + */ + public ResizableButton(String text, int fontSize, int horizontalInset, int verticalInset) { + super(text); + this.horizontalInset = horizontalInset; + this.verticalInset = verticalInset; + setFont(getFont().deriveFont((float) fontSize)); + setMargin(new Insets(verticalInset, horizontalInset, verticalInset, horizontalInset)); + // Add a ComponentListener to the button to adjust font size + addComponentListener(new ComponentAdapter() { + @Override + public void componentResized(ComponentEvent e) { + //adjustButtonFont(); + } + }); + //adjustButtonFont(); + } + + public void setFontSize(int fontSize) { + setFont(getFont().deriveFont((float) fontSize)); + } + + @Override + public Insets getInsets() { + return new Insets(verticalInset, horizontalInset, verticalInset, horizontalInset); + } + + @Override + public Insets getInsets(Insets insets) { + return new Insets(verticalInset, horizontalInset, verticalInset, horizontalInset); + } + + // Method to adjust the font size based on button size + private void adjustButtonFont() { + int btnHeight = this.getHeight(); + int btnWidth = this.getWidth(); + + if (btnHeight <= 0 || btnWidth <= 0) { + return; // Cannot calculate font size with non-positive dimensions + } + + // Get the button's insets + Insets insets = this.getInsets(); + int availableWidth = btnWidth - insets.left - insets.right; + + // Start with a font size based on button height + int fontSize = btnHeight - insets.top - insets.bottom; + + // Get the current font + Font originalFont = this.getFont(); + Font font = originalFont.deriveFont((float) fontSize); + + FontMetrics fm = this.getFontMetrics(font); + int textWidth = fm.stringWidth(this.getText()); + + // Reduce font size until text fits + while (textWidth > availableWidth && fontSize > 0) { + fontSize--; + font = originalFont.deriveFont((float) fontSize); + fm = this.getFontMetrics(font); + textWidth = fm.stringWidth(this.getText()); + } + + // Apply the new font + this.setFont(font); + + // Center the text + this.setHorizontalAlignment(JButton.CENTER); + this.setVerticalAlignment(JButton.CENTER); + } +} diff --git a/src/main/java/ai/nets/samj/gui/components/ProgressBarAndButton.java b/src/main/java/ai/nets/samj/gui/components/ProgressBarAndButton.java deleted file mode 100644 index ac2953f..0000000 --- a/src/main/java/ai/nets/samj/gui/components/ProgressBarAndButton.java +++ /dev/null @@ -1,132 +0,0 @@ -package ai.nets.samj.gui.components; - -import java.awt.Font; -import java.awt.FontMetrics; -import java.awt.GridBagConstraints; -import java.awt.GridBagLayout; -import java.awt.Insets; -import java.awt.event.ComponentAdapter; -import java.awt.event.ComponentEvent; -import javax.swing.JButton; -import javax.swing.JPanel; -import javax.swing.JProgressBar; - - -public class ProgressBarAndButton extends JPanel { - - private static final long serialVersionUID = 2478618937640492286L; - - protected final JProgressBar progressBar = new JProgressBar(); - protected final JButton btn; - private static final double RATIO_CBX_BTN = 2.5; - - public ProgressBarAndButton(String btLabel) { - this.btn = new JButton(btLabel); - btn.setMargin(new Insets(2, 3, 2, 2)); - - // Use GridBagLayout instead of null layout - setLayout(new GridBagLayout()); - - GridBagConstraints gbc = new GridBagConstraints(); - gbc.insets = new Insets(0, 0, 0, 0); // Adjust insets as needed - gbc.fill = GridBagConstraints.BOTH; - gbc.gridy = 0; - - // Add the JComboBox with weightx corresponding to RATIO_CBX_BTN - gbc.gridx = 0; - gbc.weightx = RATIO_CBX_BTN; - gbc.weighty = 1; - add(progressBar, gbc); - - // Add the JButton with weightx of 1 - gbc.gridx = 1; - gbc.weightx = 1.0; - add(btn, gbc); - - // Add a ComponentListener to the button to adjust font size - btn.addComponentListener(new ComponentAdapter() { - @Override - public void componentResized(ComponentEvent e) { - adjustButtonFont(); - } - }); - } - - public void setProgressMax(int maxProgressVal) { - this.progressBar.setMaximum(maxProgressVal); - } - - public void setProgress(int progress) { - this.progressBar.setValue(progress); - } - - @Override - public void doLayout() { - int inset = 2; // Separation between components and edges - int totalInsets = inset * 3; // Left, middle, and right insets - - int width = getWidth(); - int height = getHeight(); - - int availableWidth = width - totalInsets; - double ratioSum = RATIO_CBX_BTN + 1; - - // Calculate widths based on the ratio - int comboWidth = (int) Math.round(availableWidth * RATIO_CBX_BTN / ratioSum); - int btnWidth = availableWidth - comboWidth; - - int x = inset; - int y = 0; - int componentHeight = height; // Account for top and bottom insets - - // Set bounds for the JComboBox - progressBar.setBounds(x, y, comboWidth, componentHeight); - - x += comboWidth + inset; // Move x position for the JButton - - // Set bounds for the JButton - btn.setBounds(x, y, btnWidth, componentHeight); - - // Adjust font size after layout - adjustButtonFont(); - } - - // Method to adjust the font size based on button size - private void adjustButtonFont() { - int btnHeight = btn.getHeight(); - int btnWidth = btn.getWidth(); - - if (btnHeight <= 0 || btnWidth <= 0) { - return; // Cannot calculate font size with non-positive dimensions - } - - // Get the button's insets - Insets insets = btn.getInsets(); - int availableWidth = btnWidth - insets.left - insets.right; - - // Start with a font size based on button height - int fontSize = btnHeight - insets.top - insets.bottom;// - 4; // Subtract padding - - // Get the current font - Font originalFont = btn.getFont(); - Font font = originalFont.deriveFont((float) fontSize); - - FontMetrics fm = btn.getFontMetrics(font); - int textWidth = fm.stringWidth(btn.getText()); - - // Reduce font size until text fits - while (textWidth > availableWidth && fontSize > 0) { - fontSize--; - font = originalFont.deriveFont((float) fontSize); - fm = btn.getFontMetrics(font); - textWidth = fm.stringWidth(btn.getText()); - } - - // Apply the new font - btn.setFont(font); - - // Center the text - btn.setHorizontalAlignment(JButton.CENTER); - btn.setVerticalAlignment(JButton.CENTER); - } -} diff --git a/src/main/java/ai/nets/samj/models/AbstractSamJ.java b/src/main/java/ai/nets/samj/models/AbstractSamJ.java index 71c4749..ab1755f 100644 --- a/src/main/java/ai/nets/samj/models/AbstractSamJ.java +++ b/src/main/java/ai/nets/samj/models/AbstractSamJ.java @@ -361,6 +361,61 @@ else if (cropSize.length == 3 && cropSize[2] != 3) } + private List processAndRetrieveContours(HashMap inputs, Object callback) + throws IOException, RuntimeException, InterruptedException { + Map results = null; + try { + Task task = python.task(script, inputs); + task.listen(event -> { + switch (event.responseType) { + case UPDATE: + if (!task.message.equals("new input")) + break; + Object numer = task.outputs; + break; + default: + break; + } + }); + task.waitFor(); + if (task.status == TaskStatus.CANCELED) + throw new RuntimeException(); + else if (task.status == TaskStatus.FAILED) + throw new RuntimeException(); + else if (task.status == TaskStatus.CRASHED) + throw new RuntimeException(); + else if (task.status != TaskStatus.COMPLETE) + throw new RuntimeException(); + else if (task.outputs.get("contours_x") == null) + throw new RuntimeException(); + else if (task.outputs.get("contours_y") == null) + throw new RuntimeException(); + else if (task.outputs.get("rle") == null) + throw new RuntimeException(); + results = task.outputs; + } catch (IOException | InterruptedException | RuntimeException e) { + try { + this.shma.close(); + } catch (IOException e1) { + throw new IOException(e.toString() + System.lineSeparator() + e1.toString()); + } + throw e; + } + + final List> contours_x_container = (List>)results.get("contours_x"); + final Iterator> contours_x = contours_x_container.iterator(); + final Iterator> contours_y = ((List>)results.get("contours_y")).iterator(); + final Iterator> rles = ((List>)results.get("rle")).iterator(); + final List masks = new ArrayList(contours_x_container.size()); + while (contours_x.hasNext()) { + int[] xArr = contours_x.next().stream().mapToInt(Number::intValue).toArray(); + int[] yArr = contours_y.next().stream().mapToInt(Number::intValue).toArray(); + long[] rle = rles.next().stream().mapToLong(Number::longValue).toArray(); + masks.add(Mask.build(new Polygon(xArr, yArr, xArr.length), rle)); + } + return masks; + } + @SuppressWarnings("unchecked") private List processAndRetrieveContours(HashMap inputs) throws IOException, RuntimeException, InterruptedException { @@ -437,7 +492,7 @@ List processBatchOfPrompts(List pointsList, List rects, inputs.put("rect_prompts", rectPrompts); processPromptsBatchWithSAM(maskShma, returnAll); printScript(script, "Batch of prompts inference"); - List polys = processAndRetrieveContours(inputs); + List polys = processAndRetrieveContours(inputs, null); recalculatePolys(polys, encodeCoords); return polys; } catch (IOException | RuntimeException | InterruptedException ex) { diff --git a/src/main/java/ai/nets/samj/models/Sam2.java b/src/main/java/ai/nets/samj/models/Sam2.java index dc4c2b6..da08dee 100644 --- a/src/main/java/ai/nets/samj/models/Sam2.java +++ b/src/main/java/ai/nets/samj/models/Sam2.java @@ -461,6 +461,9 @@ protected void processPromptsBatchWithSAM(SharedMemoryArray shmArr, boolean retu + "contours_x = []" + System.lineSeparator() + "contours_y = []" + System.lineSeparator() + "rle_masks = []" + System.lineSeparator() + + "task.update('new input')" + System.lineSeparator() + + "args = {\"outputs\": {'test': contours_x}}" + System.lineSeparator() + + "task._respond(ResponseType.UPDATE, args)" + System.lineSeparator() // TODO right now is geetting the mask after each prompt // TODO test processing first every prompt and then getting the masks + "" + System.lineSeparator()