Skip to content

Commit

Permalink
keep integrating SAM2
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Aug 7, 2024
1 parent 7ea13a3 commit 4aebcea
Show file tree
Hide file tree
Showing 2 changed files with 245 additions and 243 deletions.
108 changes: 19 additions & 89 deletions src/main/java/ai/nets/samj/install/Sam2EnvManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

import org.apache.commons.compress.archivers.ArchiveException;

import ai.nets.samj.models.EfficientViTSamJ;
import ai.nets.samj.models.Sam2;
import io.bioimage.modelrunner.apposed.appose.Mamba;
import io.bioimage.modelrunner.apposed.appose.MambaInstallException;
import io.bioimage.modelrunner.apposed.appose.MambaInstallerUtils;
Expand Down Expand Up @@ -204,27 +204,15 @@ public boolean checkSAMDepsInstalled() {
return uninstalled.size() == 0;
}

/**
* Check whether the Python package to run EfficientSAM has been installed. The package will be in the folder
* {@value #ESAM_ENV_NAME}. The Python executable and other dependencies will be at {@value #COMMON_ENV_NAME}
* @return whether the Python package to run EfficientSAM has been installed.
*/
public boolean checkEfficientViTSAMPackageInstalled() {
if (!checkMambaInstalled()) return false;
File pythonEnv = Paths.get(this.path, "envs", SAM2_ENV_NAME, SAM2_NAME).toFile();
if (!pythonEnv.exists() || pythonEnv.list().length <= 1) return false;
return true;
}

/**
*
* @return whether the weights needed to run EfficientSAM Small (the standard EfficientSAM) have been
* downloaded and installed or not
*/
public boolean checkModelWeightsInstalled() {
if (!EfficientViTSamJ.getListOfSupportedEfficientViTSAM().contains(modelType))
throw new IllegalArgumentException("The provided model is not one of the supported EfficientViT models: "
+ EfficientViTSamJ.getListOfSupportedEfficientViTSAM());
if (!Sam2.getListOfSupportedVariants().contains(modelType))
throw new IllegalArgumentException("The provided model is not one of the supported SAM2 models: "
+ Sam2.getListOfSupportedVariants());
File weightsFile = Paths.get(this.path, "envs", SAM2_ENV_NAME, SAM2_NAME, "weights", modelType + ".pt").toFile();
if (!weightsFile.isFile()) return false;
if (weightsFile.length() != SAM2_BYTE_SIZES_MAP.get(modelType)) return false;
Expand All @@ -246,12 +234,12 @@ public void installModelWeigths() throws IOException, InterruptedException {
* @throws InterruptedException if the download of weights is interrupted
*/
public void installModelWeigths(boolean force) throws IOException, InterruptedException {
if (!EfficientViTSamJ.getListOfSupportedEfficientViTSAM().contains(modelType))
throw new IllegalArgumentException("The provided model is not one of the supported EfficientViT models: "
+ EfficientViTSamJ.getListOfSupportedEfficientViTSAM());
if (!Sam2.getListOfSupportedVariants().contains(modelType))
throw new IllegalArgumentException("The provided model is not one of the supported SAM2 models: "
+ Sam2.getListOfSupportedVariants());
if (!force && this.checkModelWeightsInstalled())
return;
Thread thread = reportProgress(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- INSTALLING EFFICIENTVITSAM WEIGHTS (" + modelType + ")");
Thread thread = reportProgress(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- INSTALLING SAM2 WEIGHTS (" + modelType + ")");
try {
File file = Paths.get(path, "envs", SAM2_ENV_NAME, SAM2_NAME, "weights", DownloadModel.getFileNameFromURLString(String.format(SAM2_URL, modelType))).toFile();
file.getParentFile().mkdirs();
Expand All @@ -269,22 +257,22 @@ public void installModelWeigths(boolean force) throws IOException, InterruptedEx
while (downloadThread.isAlive()) {
try {Thread.sleep(280);} catch (InterruptedException e) {break;}
double progress = Math.round( (double) 100 * file.length() / size );
if (progress < 0 || progress > 100) passToConsumer(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- EFFICIENTVITSAM WEIGHTS DOWNLOAD: UNKNOWN%");
else passToConsumer(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- EFFICIENTVITSAM WEIGHTS DOWNLOAD: " + progress + "%");
if (progress < 0 || progress > 100) passToConsumer(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- SAM2 WEIGHTS DOWNLOAD: UNKNOWN%");
else passToConsumer(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- SAM2 WEIGHTS DOWNLOAD: " + progress + "%");
}
if (size != file.length())
throw new IOException("Model EfficientViTSAM-" + modelType + " was not correctly downloaded");
throw new IOException("Model SAM2" + modelType + " was not correctly downloaded");
} catch (IOException ex) {
thread.interrupt();
passToConsumer(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- FAILED EFFICIENTVITSAM WEIGHTS INSTALLATION");
passToConsumer(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- FAILED SAM2 WEIGHTS INSTALLATION");
throw ex;
} catch (URISyntaxException e1) {
passToConsumer(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- FAILED EFFICIENTVITSAM WEIGHTS INSTALLATION");
throw new IOException("Unable to find the download URL for EfficientViTSAM " + modelType + ": " + String.format(SAM2_URL, modelType));
passToConsumer(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- FAILED SAM2 WEIGHTS INSTALLATION");
throw new IOException("Unable to find the download URL for SAM2 " + modelType + ": " + String.format(SAM2_URL, modelType));

}
thread.interrupt();
passToConsumer(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- EFFICIENTVITSAM WEIGHTS INSTALLED");
passToConsumer(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- SAM2 WEIGHTS INSTALLED");
}

/**
Expand Down Expand Up @@ -317,7 +305,7 @@ public void installSAMDeps() throws IOException, InterruptedException, ArchiveEx
public void installSAMDeps(boolean force) throws IOException, InterruptedException, MambaInstallException {
if (!checkMambaInstalled())
throw new IllegalArgumentException("Unable to install Python without first installing Mamba. ");
Thread thread = reportProgress(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- CREATING THE EFFICIENTVITSAM PYTHON ENVIRONMENT WITH ITS DEPENDENCIES");
Thread thread = reportProgress(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- CREATING THE SAM2 PYTHON ENVIRONMENT WITH ITS DEPENDENCIES");
String[] pythonArgs = new String[] {"-c", "conda-forge", "python=3.11", "-c", "pytorch"};
String[] args = new String[pythonArgs.length + INSTALL_CONDA_DEPS.size() + INSTALL_CONDA_DEPS.size()];
int c = 0;
Expand All @@ -329,76 +317,20 @@ public void installSAMDeps(boolean force) throws IOException, InterruptedExcepti
mamba.create(SAM2_ENV_NAME, true, args);
} catch (MambaInstallException e) {
thread.interrupt();
passToConsumer(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- FAILED EFFICIENTVITSAM PYTHON ENVIRONMENT CREATION");
passToConsumer(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- FAILED SAM2 PYTHON ENVIRONMENT CREATION");
throw new MambaInstallException("Unable to install Python without first installing Mamba. ");
} catch (IOException | InterruptedException e) {
thread.interrupt();
passToConsumer(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- FAILED EFFICIENTVITSAM PYTHON ENVIRONMENT CREATION");
passToConsumer(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- FAILED SAM2 PYTHON ENVIRONMENT CREATION");
throw e;
}
}
thread.interrupt();
passToConsumer(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- EFFICIENTVITSAM PYTHON ENVIRONMENT CREATED");
passToConsumer(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- SAM2 PYTHON ENVIRONMENT CREATED");
// TODO remove
installApposePackage(SAM2_ENV_NAME);
}

/**
* Install the Python package to run EfficientSAM.
* Does not overwrite the package if it already exists.
* @throws IOException if there is any file creation related issue
* @throws InterruptedException if the package installation is interrupted
* @throws MambaInstallException if there is any error with the Mamba installation
*/
public void installEfficientViTSAMPackage() throws IOException, InterruptedException, MambaInstallException {
installEfficientViTSAMPackage(false);
}

/**
* Install the Python package to run EfficientSAM
* @param force
* if the package already exists, whether to overwrite it or not
* @throws IOException if there is any file creation related issue
* @throws InterruptedException if the package installation is interrupted
* @throws MambaInstallException if there is any error with the Mamba installation
*/
public void installEfficientViTSAMPackage(boolean force) throws IOException, InterruptedException, MambaInstallException {
if (checkEfficientViTSAMPackageInstalled() && !force)
return;
if (!checkMambaInstalled())
throw new IllegalArgumentException("Unable to EfficientViTSAM without first installing Mamba. ");
Thread thread = reportProgress(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- INSTALLING 'EFFICIENTVITSAM' PYTHON PACKAGE");
String zipResourcePath = "efficientvit.zip";
String outputDirectory = mamba.getEnvsDir() + File.separator + SAM2_ENV_NAME;
try (
InputStream zipInputStream = SamEnvManager.class.getClassLoader().getResourceAsStream(zipResourcePath);
ZipInputStream zipInput = new ZipInputStream(zipInputStream);
) {
ZipEntry entry;
while ((entry = zipInput.getNextEntry()) != null) {
File entryFile = new File(outputDirectory + File.separator + entry.getName());
if (entry.isDirectory()) {
entryFile.mkdirs();
continue;
}
entryFile.getParentFile().mkdirs();
try (OutputStream entryOutput = new FileOutputStream(entryFile)) {
byte[] buffer = new byte[1024];
int bytesRead;
while ((bytesRead = zipInput.read(buffer)) != -1) {
entryOutput.write(buffer, 0, bytesRead);
}
}
}
} catch (IOException e) {
thread.interrupt();
passToConsumer(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- FAILED 'EFFICIENTVITSAM' PYTHON PACKAGE INSTALLATION");
throw e;
}
thread.interrupt();
passToConsumer(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- 'EFFICIENTVITSAM' PYTHON PACKAGE INSATLLED");
}

/**
* Install all the requirements to run EfficientSAM. First, checks if micromamba is installed, if not installs it;
* then checks if the Python environment and packages needed to run EfficientSAM are installed and if not installs it
Expand All @@ -417,8 +349,6 @@ public void installEverything() throws IOException, InterruptedException,

if (!this.checkSAMDepsInstalled()) this.installSAMDeps();

if (!this.checkEfficientViTSAMPackageInstalled()) this.installEfficientViTSAMPackage();

if (!this.checkModelWeightsInstalled()) this.installModelWeigths();
}

Expand Down
Loading

0 comments on commit 4aebcea

Please sign in to comment.