Skip to content

Commit

Permalink
make sam2 work
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Aug 7, 2024
1 parent 9115272 commit f1b4bfe
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 32 deletions.
15 changes: 7 additions & 8 deletions src/main/java/ai/nets/samj/communication/model/SAM2Tiny.java
Original file line number Diff line number Diff line change
Expand Up @@ -114,16 +114,15 @@ public boolean isInstalled() {
*/
public void setImage(final RandomAccessibleInterval<?> image, final SAMJLogger useThisLoggerForIt)
throws IOException, InterruptedException, RuntimeException {
this.log = useThisLoggerForIt;
AbstractSamJ.DebugTextPrinter filteringLogger = text -> {
int idx = text.indexOf("contours_x");
if (idx > 0) this.log.info( text.substring(0,idx) );
else this.log.info( text );
};
if (this.efficientSamJ == null)
efficientSamJ = Sam2.initializeSam("tiny", manager);
efficientSamJ = Sam2.initializeSam("tiny", manager, filteringLogger, false);
try {
this.log = useThisLoggerForIt;
AbstractSamJ.DebugTextPrinter filteringLogger = text -> {
int idx = text.indexOf("contours_x");
if (idx > 0) this.log.info( text.substring(0,idx) );
else this.log.info( text );
};
this.efficientSamJ.setDebugPrinter(filteringLogger);
this.efficientSamJ.setImage(Cast.unchecked(image));;
} catch (IOException | InterruptedException | RuntimeException e) {
log.error(FULL_NAME + " experienced an error: " + e.getMessage());
Expand Down
1 change: 1 addition & 0 deletions src/main/java/ai/nets/samj/gui/SAMModelPanel.java
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ public SAMModelPanel(SAMModels models, CallParent updateParent) {

ButtonGroup group = new ButtonGroup();
for(SAMModel model : models) {
model.getInstallationManger().setConsumer((str) -> addHtml(str));
JRadioButton rb = new JRadioButton(model.getName(), false);
rbModels.add(rb);
rb.addActionListener(this);
Expand Down
46 changes: 24 additions & 22 deletions src/main/java/ai/nets/samj/install/Sam2EnvManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.net.URL;
import java.nio.file.Paths;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -69,7 +70,7 @@ public class Sam2EnvManager extends SamEnvManagerAbstract {
* General for every supported model.
*/
final public static List<String> CHECK_DEPS = Arrays.asList(new String[] {"appose", "torch=2.4.0",
"torchvision=0.19.0", "skimage, samv2=0.0.4"});
"torchvision=0.19.0", "skimage", "sam2", "pytest"});
/**
* Dependencies that have to be installed in any SAMJ created environment using Mamba or Conda
*/
Expand All @@ -86,9 +87,9 @@ public class Sam2EnvManager extends SamEnvManagerAbstract {
final public static List<String> INSTALL_PIP_DEPS;
static {
if (!PlatformDetection.getArch().equals(PlatformDetection.ARCH_ARM64) && !PlatformDetection.isUsingRosseta())
INSTALL_PIP_DEPS = Arrays.asList(new String[] {"mkl=2023.2.2", "appose", "pytest", "samv2=0.0.4"});
INSTALL_PIP_DEPS = Arrays.asList(new String[] {"mkl==2024.0.0", "samv2==0.0.4", "pytest"});
else
INSTALL_PIP_DEPS = Arrays.asList(new String[] {"appose", "pytest", "samv2=0.0.4"});
INSTALL_PIP_DEPS = Arrays.asList(new String[] {"samv2==0.0.4", "pytest"});
}
/**
* Byte sizes of all the EfficientViTSAM options
Expand Down Expand Up @@ -210,7 +211,7 @@ public boolean checkModelWeightsInstalled() {
if (!Sam2.getListOfSupportedVariants().contains(modelType))
throw new IllegalArgumentException("The provided model is not one of the supported SAM2 models: "
+ Sam2.getListOfSupportedVariants());
File weightsFile = Paths.get(this.path, "envs", SAM2_ENV_NAME, SAM2_NAME, "weights", modelType + ".pt").toFile();
File weightsFile = Paths.get(this.getModelWeigthPath()).toFile();
if (!weightsFile.isFile()) return false;
if (weightsFile.length() != SAM2_BYTE_SIZES_MAP.get(modelType)) return false;
return true;
Expand Down Expand Up @@ -304,11 +305,10 @@ public void installSAMDeps(boolean force) throws IOException, InterruptedExcepti
throw new IllegalArgumentException("Unable to install Python without first installing Mamba. ");
Thread thread = reportProgress(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- CREATING THE SAM2 PYTHON ENVIRONMENT WITH ITS DEPENDENCIES");
String[] pythonArgs = new String[] {"-c", "conda-forge", "python=3.11", "-c", "pytorch"};
String[] args = new String[pythonArgs.length + INSTALL_CONDA_DEPS.size() + INSTALL_CONDA_DEPS.size()];
String[] args = new String[pythonArgs.length + INSTALL_CONDA_DEPS.size()];
int c = 0;
for (String ss : pythonArgs) args[c ++] = ss;
for (String ss : INSTALL_CONDA_DEPS) args[c ++] = ss;
for (String ss : INSTALL_CONDA_DEPS) args[c ++] = ss;
if (!this.checkSAMDepsInstalled() || force) {
try {
mamba.create(SAM2_ENV_NAME, true, args);
Expand All @@ -321,10 +321,19 @@ public void installSAMDeps(boolean force) throws IOException, InterruptedExcepti
passToConsumer(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- FAILED SAM2 PYTHON ENVIRONMENT CREATION");
throw e;
}
ArrayList<String> pipInstall = new ArrayList<String>();
for (String ss : new String[] {"-m", "pip", "install"}) pipInstall.add(ss);
for (String ss : INSTALL_PIP_DEPS) pipInstall.add(ss);
try {
Mamba.runPythonIn(Paths.get(path, "envs", SAM2_ENV_NAME).toFile(), pipInstall.stream().toArray( String[]::new ));
} catch (IOException | InterruptedException e) {
thread.interrupt();
passToConsumer(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- FAILED PYTHON ENVIRONMENT CREATION WHEN INSTALLING PIP DEPENDENCIES");
throw e;
}
}
thread.interrupt();
passToConsumer(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- SAM2 PYTHON ENVIRONMENT CREATED");
// TODO remove
installApposePackage(SAM2_ENV_NAME);
}

Expand All @@ -349,16 +358,6 @@ public void installEverything() throws IOException, InterruptedException,
if (!this.checkModelWeightsInstalled()) this.installModelWeigths();
}

/**
*
* @return the path to the EfficientSAM Small weights file
*/
public String getModelWeightsPath() {
File file = Paths.get(path, "envs", SAM2_ENV_NAME, SAM2_NAME, "weights", modelType + ".pt").toFile();
if (!file.isFile()) return null;
return file.getAbsolutePath();
}

/**
*
* @return the the path to the Python environment needed to run EfficientSAM
Expand All @@ -374,16 +373,19 @@ public String getModelEnv() {
* @return the official name of the EfficientSAM Small weights
*/
public String getModelWeigthsName() {
return modelType + ".pt";
return "sam2_hiera_" + modelType + ".pt";
}

@Override
public String getModelWeigthPath() {
File file;
try {
return Paths.get(path, "envs", SAM2_ENV_NAME, SAM2_NAME, "weights", DownloadModel.getFileNameFromURLString(String.format(SAM2_URL, modelType))).toAbsolutePath().toString();
file = Paths.get(path, "envs", SAM2_ENV_NAME, SAM2_NAME, "weights", DownloadModel.getFileNameFromURLString(String.format(SAM2_URL, modelType))).toFile();
} catch (MalformedURLException e) {
return Paths.get(path, "envs", SAM2_ENV_NAME, SAM2_NAME, "weights", String.format("sam2_hiera_%s.pt", modelType)).toAbsolutePath().toString();
file = Paths.get(path, "envs", SAM2_ENV_NAME, SAM2_NAME, "weights", String.format("sam2_hiera_%s.pt", modelType)).toFile();
}

return file.getAbsolutePath();
}

@Override
Expand All @@ -399,8 +401,8 @@ public boolean checkEverythingInstalled() {

@Override
public void uninstall() {
if (new File(this.getModelWeightsPath()).getParentFile().list().length != 1)
Files.deleteFolder(new File(this.getModelWeightsPath()));
if (new File(this.getModelWeigthPath()).getParentFile().list().length != 1)
Files.deleteFolder(new File(this.getModelWeigthPath()));
else
Files.deleteFolder(new File(this.getModelEnv()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,9 @@ public abstract class SamEnvManagerAbstract {
public abstract void uninstall();



public void setConsumer(Consumer<String> consumer) {
this.consumer = consumer;
}

/**
* Send information as Strings to the consumer
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/ai/nets/samj/models/Sam2.java
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public class Sam2 extends AbstractSamJ {
+ "from sam2.utils.misc import variant_to_config_mapping" + System.lineSeparator()
+ "task.update('imported')" + System.lineSeparator()
+ "model = build_sam2(variant_to_config_mapping['%s'],'%s')" + System.lineSeparator()
+ "predictor = SAM2ImagePredictor(model)"
+ "predictor = SAM2ImagePredictor(model)" + System.lineSeparator()
+ "task.update('created predictor')" + System.lineSeparator()
+ "encodings_map = {}" + System.lineSeparator()
+ "globals()['encodings_map'] = encodings_map" + System.lineSeparator()
Expand Down

0 comments on commit f1b4bfe

Please sign in to comment.