diff --git a/src/main/java/ai/nets/samj/gui/MainGUI.java b/src/main/java/ai/nets/samj/gui/MainGUI.java index 637fa80..c12eaf0 100644 --- a/src/main/java/ai/nets/samj/gui/MainGUI.java +++ b/src/main/java/ai/nets/samj/gui/MainGUI.java @@ -547,8 +547,14 @@ public void modelActionsOnImageChanged() { public void imageActionsOnImageChanged() { consumer.deactivateListeners(); consumer.deselectImage(); - setTwoThirdsEnabled(false); - go.setEnabled(cmbImages.getSelectedObject() != null); + setTwoThirdsEnabled(false); + if (go.isEnabled()) + return; + go.showAnimation(true); + new Thread(() -> { + go.setEnabled(cmbModels.getSelectedModel().isInstalled()); + go.showAnimation(false); + }).start(); } }; modelListener = new ModelSelectionListener() { @@ -564,6 +570,12 @@ public void changeDrawerPanel() { public void changeGUI() { setTwoThirdsEnabled(false); go.setEnabled(cmbImages.getSelectedObject() != null); + go.setEnabled(false); + go.showAnimation(true); + new Thread(() -> { + go.setEnabled(cmbModels.getSelectedModel().isInstalled()); + go.showAnimation(false); + }).start(); } }; modelDrawerListener = new ModelDrawerPanelListener() { diff --git a/src/main/java/ai/nets/samj/install/Sam2EnvManager.java b/src/main/java/ai/nets/samj/install/Sam2EnvManager.java index 526f1b8..1679a2a 100644 --- a/src/main/java/ai/nets/samj/install/Sam2EnvManager.java +++ b/src/main/java/ai/nets/samj/install/Sam2EnvManager.java @@ -57,18 +57,6 @@ public class Sam2EnvManager extends SamEnvManagerAbstract { * Default version for the family of SAM2 models */ final public static String DEFAULT_SAM2 = "tiny"; - /** - * Name of the file that contains the weights of SAM Huge - */ - final public static String SAM2_TINY_WEIGHTS_NAME = "sam2_hiera_tiny.pth"; - /** - * Name of the file that contains the weights of SAM Huge - */ - final public static String SAM2_SMALL_WEIGHTS_NAME = "sam2_hiera_small.pth"; - /** - * Name of the file that contains the weights of SAM Huge - */ - final public static String SAM2_LARGE_WEIGHTS_NAME = "sam2_hiera_large.pth"; /** * Dependencies to be checked to make sure that the environment is able to load a SAM based model. @@ -99,7 +87,7 @@ else if (!PlatformDetection.getArch().equals(PlatformDetection.ARCH_ARM64) && !P INSTALL_PIP_DEPS = Arrays.asList(new String[] {"samv2==0.0.4", "pytest"}); } /** - * Byte sizes of all the EfficientViTSAM options + * Byte sizes of all the SAM2 options */ final public static HashMap SAM2_BYTE_SIZES_MAP; static { @@ -109,6 +97,17 @@ else if (!PlatformDetection.getArch().equals(PlatformDetection.ARCH_ARM64) && !P SAM2_BYTE_SIZES_MAP.put("base_plus", (long) -1); SAM2_BYTE_SIZES_MAP.put("large", (long) 897952466); } + /** + * Byte sizes of all the SAM2.1 options + */ + final public static HashMap SAM2_1_BYTE_SIZES_MAP; + static { + SAM2_1_BYTE_SIZES_MAP = new HashMap(); + SAM2_1_BYTE_SIZES_MAP.put("tiny", (long) 156_008_466); + SAM2_1_BYTE_SIZES_MAP.put("small", (long) 184_416_285); + SAM2_1_BYTE_SIZES_MAP.put("base_plus", (long) 323_606_802); + SAM2_1_BYTE_SIZES_MAP.put("large", (long) 898_083_611); + } /** * Name of the environment that contains the code and weigths to run EfficientSAM models */ @@ -118,9 +117,13 @@ else if (!PlatformDetection.getArch().equals(PlatformDetection.ARCH_ARM64) && !P */ final static public String SAM2_NAME = "sam2"; /** - * URL to download the EfficientSAM model + * URL to download the SAM2 model + */ + final static private String SAM2_URL = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_%s.pt"; + /** + * URL to download the SAM2 model */ - final static public String SAM2_URL = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_%s.pt"; + final static private String SAM2_FNAME = "sam2.1_hiera_%s.pt"; private Sam2EnvManager(String modelType) { List modelTypes = SAM2_BYTE_SIZES_MAP.keySet().stream().collect(Collectors.toList()); @@ -381,7 +384,11 @@ public String getModelEnv() { * @return the official name of the EfficientSAM Small weights */ public String getModelWeigthsName() { - return "sam2_hiera_" + modelType + ".pt"; + try { + return DownloadModel.getFileNameFromURLString(String.format(SAM2_URL, modelType)); + } catch (MalformedURLException e) { + return String.format(SAM2_FNAME, modelType); + } } @Override @@ -390,7 +397,7 @@ public String getModelWeigthPath() { try { file = Paths.get(path, "envs", SAM2_ENV_NAME, SAM2_NAME, "weights", DownloadModel.getFileNameFromURLString(String.format(SAM2_URL, modelType))).toFile(); } catch (MalformedURLException e) { - file = Paths.get(path, "envs", SAM2_ENV_NAME, SAM2_NAME, "weights", String.format("sam2_hiera_%s.pt", modelType)).toFile(); + file = Paths.get(path, "envs", SAM2_ENV_NAME, SAM2_NAME, "weights", String.format(SAM2_FNAME, modelType)).toFile(); } return file.getAbsolutePath();