Skip to content

Commit

Permalink
add missing method
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Aug 7, 2024
1 parent fc7258a commit d6fc080
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 54 deletions.
16 changes: 6 additions & 10 deletions src/main/java/ai/nets/samj/install/EfficientSamEnvManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -294,21 +294,12 @@ public void installSAMDeps(boolean force) throws IOException, InterruptedExcepti
}
ArrayList<String> pipInstall = new ArrayList<String>();
for (String ss : new String[] {"-m", "pip", "install"}) pipInstall.add(ss);
// TODO until appose new release for (String ss : INSTALL_PIP_DEPS) pipInstall.add(ss);
/*
try {
Mamba.runPythonIn(Paths.get(path, "envs", COMMON_ENV_NAME).toFile(), pipInstall.stream().toArray( String[]::new ));
} catch (IOException | InterruptedException e) {
thread.interrupt();
passToConsumer(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- FAILED PYTHON ENVIRONMENT CREATION WHEN INSTALLING PIP DEPENDENCIES");
throw e;
}
*/
}
thread.interrupt();
passToConsumer(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- PYTHON ENVIRONMENT CREATED");
// TODO remove
installApposePackage(ESAM_ENV_NAME);
installEfficientSAMPackage(force);
}

/**
Expand Down Expand Up @@ -428,4 +419,9 @@ public String getModelEnv() {
public String getModelWeigthsName() {
return ESAM_SMALL_WEIGHTS_NAME;
}

@Override
public String getModelWeigthPath() {
return Paths.get(this.path, "envs", ESAM_ENV_NAME, ESAM_NAME, "weights", ESAM_SMALL_WEIGHTS_NAME).toAbsolutePath().toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ public void installSAMDeps(boolean force) throws IOException, InterruptedExcepti
passToConsumer(LocalDateTime.now().format(DATE_FORMAT).toString() + " -- EFFICIENTVITSAM PYTHON ENVIRONMENT CREATED");
// TODO remove
installApposePackage(EVITSAM_ENV_NAME);
installEfficientViTSAMPackage(force);
}

private void installOnnxsim(File envFile) throws IOException, InterruptedException {
Expand Down Expand Up @@ -414,7 +415,7 @@ public void installEfficientViTSAMPackage(boolean force) throws IOException, Int
String zipResourcePath = "efficientvit.zip";
String outputDirectory = mamba.getEnvsDir() + File.separator + EVITSAM_ENV_NAME;
try (
InputStream zipInputStream = SamEnvManager.class.getClassLoader().getResourceAsStream(zipResourcePath);
InputStream zipInputStream = EfficientViTSamEnvManager.class.getClassLoader().getResourceAsStream(zipResourcePath);
ZipInputStream zipInput = new ZipInputStream(zipInputStream);
) {
ZipEntry entry;
Expand Down Expand Up @@ -492,4 +493,9 @@ public String getModelEnv() {
public String getModelWeigthsName() {
return modelType + ".pt";
}

@Override
public String getModelWeigthPath() {
return Paths.get(this.path, "envs", EVITSAM_ENV_NAME, EVITSAM_NAME, "weights", modelType + ".pt").toAbsolutePath().toString();
}
}
15 changes: 10 additions & 5 deletions src/main/java/ai/nets/samj/install/Sam2EnvManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.FileOutputStream;
import java.net.MalformedURLException;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Paths;
Expand All @@ -32,8 +30,6 @@
import java.util.HashMap;
import java.util.List;
import java.util.function.Consumer;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;

import io.bioimage.modelrunner.system.PlatformDetection;

Expand Down Expand Up @@ -379,4 +375,13 @@ public String getModelEnv() {
public String getModelWeigthsName() {
return modelType + ".pt";
}

@Override
public String getModelWeigthPath() {
try {
return Paths.get(path, "envs", SAM2_ENV_NAME, SAM2_NAME, "weights", DownloadModel.getFileNameFromURLString(String.format(SAM2_URL, modelType))).toAbsolutePath().toString();
} catch (MalformedURLException e) {
return Paths.get(path, "envs", SAM2_ENV_NAME, SAM2_NAME, "weights", String.format("sam2_hiera_%s.pt", modelType)).toAbsolutePath().toString();
}
}
}
2 changes: 2 additions & 0 deletions src/main/java/ai/nets/samj/install/SamEnvManagerAbstract.java
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ public abstract class SamEnvManagerAbstract {

public abstract String getModelWeigthsName();

public abstract String getModelWeigthPath();

public abstract String getModelEnv();


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
*
* @author Carlos Javier Garcia Lopez de Haro
*/
public class SamEnvManager {
public class SamEnvManager_old {
/**
* Name of the file that contains the weights of SAM Huge
*/
Expand Down Expand Up @@ -221,7 +221,7 @@ public class SamEnvManager {
* the path where the corresponding micromamba shuold be installed
* @return an instance of {@link SamEnvManager}
*/
public static SamEnvManager create(String path) {
public static SamEnvManager_old create(String path) {
return create(path, (ss) -> {});
}

Expand All @@ -235,8 +235,8 @@ public static SamEnvManager create(String path) {
* an specific consumer where info about the installation is going to be communicated
* @return an instance of {@link SamEnvManager}
*/
public static SamEnvManager create(String path, Consumer<String> consumer) {
SamEnvManager installer = new SamEnvManager();
public static SamEnvManager_old create(String path, Consumer<String> consumer) {
SamEnvManager_old installer = new SamEnvManager_old();
installer.path = path;
installer.consumer = consumer;
installer.mamba = new Mamba(path);
Expand All @@ -249,7 +249,7 @@ public static SamEnvManager create(String path, Consumer<String> consumer) {
* Micromamba does not need to be installed as the code will install it automatically.
* @return an instance of {@link SamEnvManager}
*/
public static SamEnvManager create() {
public static SamEnvManager_old create() {
return create(DEFAULT_DIR);
}

Expand All @@ -261,7 +261,7 @@ public static SamEnvManager create() {
* an specific consumer where info about the installation is going to be communicated
* @return an instance of {@link SamEnvManager}
*/
public static SamEnvManager create(Consumer<String> consumer) {
public static SamEnvManager_old create(Consumer<String> consumer) {
return create(DEFAULT_DIR, consumer);
}

Expand Down Expand Up @@ -442,7 +442,7 @@ public void downloadESAMSmallWeights(boolean force) throws IOException, Interrup
String zipResourcePath = "efficient_sam_vits.pt.zip";
String outputDirectory = Paths.get(path, "envs", ESAM_ENV_NAME, ESAM_NAME, "weights").toFile().getAbsolutePath();
try (
InputStream zipInputStream = SamEnvManager.class.getClassLoader().getResourceAsStream(zipResourcePath);
InputStream zipInputStream = SamEnvManager_old.class.getClassLoader().getResourceAsStream(zipResourcePath);
ZipInputStream zipInput = new ZipInputStream(zipInputStream);
) {
ZipEntry entry;
Expand Down Expand Up @@ -743,7 +743,7 @@ private void installApposePackage(String envName, boolean force) throws IOExcept
String zipResourcePath = "appose-python.zip";
String outputDirectory = mamba.getEnvsDir() + File.separator + envName;
try (
InputStream zipInputStream = SamEnvManager.class.getClassLoader().getResourceAsStream(zipResourcePath);
InputStream zipInputStream = SamEnvManager_old.class.getClassLoader().getResourceAsStream(zipResourcePath);
ZipInputStream zipInput = new ZipInputStream(zipInputStream);
) {
ZipEntry entry;
Expand Down Expand Up @@ -810,7 +810,7 @@ public void installSAMPackage(boolean force) throws IOException, InterruptedExce
String zipResourcePath = "SAM.zip";
String outputDirectory = mamba.getEnvsDir() + File.separator + SAM_ENV_NAME + File.separator + SAM_NAME;
try (
InputStream zipInputStream = SamEnvManager.class.getResourceAsStream(zipResourcePath);
InputStream zipInputStream = SamEnvManager_old.class.getResourceAsStream(zipResourcePath);
ZipInputStream zipInput = new ZipInputStream(zipInputStream);
) {
ZipEntry entry;
Expand Down Expand Up @@ -873,7 +873,7 @@ public void installEfficientSAMPackage(boolean force) throws IOException, Interr
String zipResourcePath = "EfficientSAM.zip";
String outputDirectory = mamba.getEnvsDir() + File.separator + ESAM_ENV_NAME;
try (
InputStream zipInputStream = SamEnvManager.class.getClassLoader().getResourceAsStream(zipResourcePath);
InputStream zipInputStream = SamEnvManager_old.class.getClassLoader().getResourceAsStream(zipResourcePath);
ZipInputStream zipInput = new ZipInputStream(zipInputStream);
) {
ZipEntry entry;
Expand Down Expand Up @@ -931,7 +931,7 @@ public void installEfficientViTSAMPackage(boolean force) throws IOException, Int
String zipResourcePath = "efficientvit.zip";
String outputDirectory = mamba.getEnvsDir() + File.separator + EVITSAM_ENV_NAME;
try (
InputStream zipInputStream = SamEnvManager.class.getClassLoader().getResourceAsStream(zipResourcePath);
InputStream zipInputStream = SamEnvManager_old.class.getClassLoader().getResourceAsStream(zipResourcePath);
ZipInputStream zipInput = new ZipInputStream(zipInputStream);
) {
ZipEntry entry;
Expand Down
20 changes: 10 additions & 10 deletions src/main/java/ai/nets/samj/models/EfficientSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
*/
package ai.nets.samj.models;

import ai.nets.samj.install.SamEnvManager;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import ai.nets.samj.install.EfficientSamEnvManager;
import ai.nets.samj.install.SamEnvManagerAbstract;
import io.bioimage.modelrunner.apposed.appose.Environment;
import io.bioimage.modelrunner.apposed.appose.Service.Task;
import io.bioimage.modelrunner.apposed.appose.Service.TaskStatus;
Expand Down Expand Up @@ -90,7 +90,7 @@ public class EfficientSamJ extends AbstractSamJ {
* @throws RuntimeException if there is any error running the Python code
* @throws InterruptedException if the process is interrupted
*/
private EfficientSamJ(SamEnvManager manager) throws IOException, RuntimeException, InterruptedException {
private EfficientSamJ(SamEnvManagerAbstract manager) throws IOException, RuntimeException, InterruptedException {
this(manager, (t) -> {}, false);
}

Expand All @@ -108,21 +108,21 @@ private EfficientSamJ(SamEnvManager manager) throws IOException, RuntimeExceptio
* @throws InterruptedException if the process is interrupted
*
*/
private EfficientSamJ(SamEnvManager manager,
private EfficientSamJ(SamEnvManagerAbstract manager,
final DebugTextPrinter debugPrinter,
final boolean printPythonCode) throws IOException, RuntimeException, InterruptedException {

this.debugPrinter = debugPrinter;
this.isDebugging = printPythonCode;

this.env = new Environment() {
@Override public String base() { return manager.getEfficientSAMPythonEnv(); }
@Override public String base() { return manager.getModelEnv(); }
};
python = env.python();
python.debug(debugPrinter::printText);
String IMPORTS_FORMATED = String.format(IMPORTS,
manager.getEfficientSamEnv() + File.separator + SamEnvManager.ESAM_NAME,
manager.getEfficientSAMSmallWeightsPath());
manager.getModelEnv() + File.separator + EfficientSamEnvManager.ESAM_NAME,
manager.getModelWeigthPath());
printScript(IMPORTS_FORMATED + PythonMethods.TRACE_EDGES, "Edges tracing code");
Task task = python.task(IMPORTS_FORMATED + PythonMethods.TRACE_EDGES);
System.out.println(IMPORTS_FORMATED + PythonMethods.TRACE_EDGES);
Expand Down Expand Up @@ -156,7 +156,7 @@ else if (task.status == TaskStatus.CRASHED)
* @throws InterruptedException if the process is interrupted
*/
public static <T extends RealType<T> & NativeType<T>> EfficientSamJ
initializeSam(SamEnvManager manager,
initializeSam(SamEnvManagerAbstract manager,
RandomAccessibleInterval<T> image,
final DebugTextPrinter debugPrinter,
final boolean printPythonCode) throws IOException, RuntimeException, InterruptedException {
Expand Down Expand Up @@ -191,7 +191,7 @@ else if (task.status == TaskStatus.CRASHED)
* @throws InterruptedException if the process is interrupted
*/
public static <T extends RealType<T> & NativeType<T>> EfficientSamJ
initializeSam(SamEnvManager manager, RandomAccessibleInterval<T> image) throws IOException, RuntimeException, InterruptedException {
initializeSam(SamEnvManagerAbstract manager, RandomAccessibleInterval<T> image) throws IOException, RuntimeException, InterruptedException {
EfficientSamJ sam = null;
try{
sam = new EfficientSamJ(manager);
Expand Down Expand Up @@ -399,7 +399,7 @@ void adaptImageToModel(final RandomAccessibleInterval<T> ogImg, RandomAccessible
public static void main(String[] args) throws IOException, RuntimeException, InterruptedException {
RandomAccessibleInterval<UnsignedByteType> img = ArrayImgs.unsignedBytes(new long[] {50, 50, 3});
img = Views.addDimension(img, 1, 2);
try (EfficientSamJ sam = initializeSam(SamEnvManager.create(), img)) {
try (EfficientSamJ sam = initializeSam(EfficientSamEnvManager.create(), img)) {
sam.processBox(new int[] {0, 5, 10, 26});
}
}
Expand Down
27 changes: 14 additions & 13 deletions src/main/java/ai/nets/samj/models/EfficientViTSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
import java.util.List;
import java.util.stream.Collectors;

import ai.nets.samj.install.SamEnvManager;
import ai.nets.samj.install.EfficientViTSamEnvManager;
import ai.nets.samj.install.SamEnvManagerAbstract;

import java.io.File;
import java.io.IOException;
Expand Down Expand Up @@ -124,7 +125,7 @@ public class EfficientViTSamJ extends AbstractSamJ {
* @throws RuntimeException if there is any error running the Python code
* @throws InterruptedException if the process is interrupted
*/
private EfficientViTSamJ(SamEnvManager manager, String type) throws IOException, RuntimeException, InterruptedException {
private EfficientViTSamJ(SamEnvManagerAbstract manager, String type) throws IOException, RuntimeException, InterruptedException {
this(manager, type, (t) -> {}, false);
}

Expand All @@ -144,7 +145,7 @@ private EfficientViTSamJ(SamEnvManager manager, String type) throws IOException,
* @throws InterruptedException if the process is interrupted
*
*/
private EfficientViTSamJ(SamEnvManager manager, String type,
private EfficientViTSamJ(SamEnvManagerAbstract manager, String type,
final DebugTextPrinter debugPrinter,
final boolean printPythonCode) throws IOException, RuntimeException, InterruptedException {

Expand All @@ -155,13 +156,13 @@ private EfficientViTSamJ(SamEnvManager manager, String type,
this.isDebugging = printPythonCode;

this.env = new Environment() {
@Override public String base() { return manager.getEfficientViTSamEnv(); }
@Override public String base() { return manager.getModelEnv(); }
};
python = env.python();
python.debug(debugPrinter::printText);
IMPORTS_FORMATED = String.format(IMPORTS,
manager.getEfficientViTSamEnv() + File.separator + SamEnvManager.EVITSAM_NAME,
MODELS_DICT.get(type), MODELS_DICT.get(type), manager.getEfficientViTSAMWeightsPath(type));
manager.getModelEnv() + File.separator + EfficientViTSamEnvManager.EVITSAM_NAME,
MODELS_DICT.get(type), MODELS_DICT.get(type), manager.getModelWeigthPath());

printScript(IMPORTS_FORMATED + PythonMethods.TRACE_EDGES, "Edges tracing code");
Task task = python.task(IMPORTS_FORMATED + PythonMethods.TRACE_EDGES);
Expand Down Expand Up @@ -199,7 +200,7 @@ else if (task.status == TaskStatus.CRASHED)
* @throws InterruptedException if the process is interrupted
*/
public static <T extends RealType<T> & NativeType<T>> EfficientViTSamJ
initializeSam(String modelType, SamEnvManager manager,
initializeSam(String modelType, SamEnvManagerAbstract manager,
RandomAccessibleInterval<T> image,
final DebugTextPrinter debugPrinter,
final boolean printPythonCode) throws IOException, RuntimeException, InterruptedException {
Expand Down Expand Up @@ -234,7 +235,7 @@ else if (task.status == TaskStatus.CRASHED)
* @throws InterruptedException if the process is interrupted
*/
public static <T extends RealType<T> & NativeType<T>> EfficientViTSamJ
initializeSam(String modelType, SamEnvManager manager, RandomAccessibleInterval<T> image)
initializeSam(String modelType, SamEnvManagerAbstract manager, RandomAccessibleInterval<T> image)
throws IOException, RuntimeException, InterruptedException {
EfficientViTSamJ sam = null;
try{
Expand Down Expand Up @@ -272,11 +273,11 @@ else if (task.status == TaskStatus.CRASHED)
* @throws InterruptedException if the process is interrupted
*/
public static <T extends RealType<T> & NativeType<T>> EfficientViTSamJ
initializeSam(SamEnvManager manager,
initializeSam(SamEnvManagerAbstract manager,
RandomAccessibleInterval<T> image,
final DebugTextPrinter debugPrinter,
final boolean printPythonCode) throws IOException, RuntimeException, InterruptedException {
return initializeSam(SamEnvManager.DEFAULT_EVITSAM, manager, image, debugPrinter, printPythonCode);
return initializeSam(EfficientViTSamEnvManager.DEFAULT_EVITSAM, manager, image, debugPrinter, printPythonCode);
}

/**
Expand All @@ -299,8 +300,8 @@ else if (task.status == TaskStatus.CRASHED)
* @throws InterruptedException if the process is interrupted
*/
public static <T extends RealType<T> & NativeType<T>> EfficientViTSamJ
initializeSam(SamEnvManager manager, RandomAccessibleInterval<T> image) throws IOException, RuntimeException, InterruptedException {
return initializeSam(SamEnvManager.DEFAULT_EVITSAM, manager, image);
initializeSam(SamEnvManagerAbstract manager, RandomAccessibleInterval<T> image) throws IOException, RuntimeException, InterruptedException {
return initializeSam(EfficientViTSamEnvManager.DEFAULT_EVITSAM, manager, image);
}

@Override
Expand Down Expand Up @@ -496,7 +497,7 @@ void adaptImageToModel(RandomAccessibleInterval<T> ogImg, RandomAccessibleInterv
public static void main(String[] args) throws IOException, RuntimeException, InterruptedException {
RandomAccessibleInterval<UnsignedByteType> img = ArrayImgs.unsignedBytes(new long[] {50, 50, 3});
img = Views.addDimension(img, 1, 2);
try (EfficientViTSamJ sam = initializeSam(SamEnvManager.create(), img)) {
try (EfficientViTSamJ sam = initializeSam(EfficientViTSamEnvManager.create(), img)) {
sam.processBox(new int[] {0, 5, 10, 26});
}
}
Expand Down
6 changes: 2 additions & 4 deletions src/main/java/ai/nets/samj/models/Sam2.java
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ public class Sam2 extends AbstractSamJ {
+ "globals()['torch'] = torch" + System.lineSeparator()
+ "globals()['predictor'] = predictor" + System.lineSeparator();
/**
* String containing the Python imports code after it has been formatted with the correct
* String containing the Python imports code after it has been formated with the correct
* paths and names
*/
private String IMPORTS_FORMATED;
Expand Down Expand Up @@ -147,9 +147,7 @@ private Sam2(SamEnvManagerAbstract manager, String type,
};
python = env.python();
python.debug(debugPrinter::printText);
IMPORTS_FORMATED = String.format(IMPORTS, type,
manager.getModelEnv() + File.separator + Sam2EnvManager.SAM2_ENV_NAME
+ File.separator + manager.getModelWeigthsName());
IMPORTS_FORMATED = String.format(IMPORTS, type, manager.getModelWeigthPath());

printScript(IMPORTS_FORMATED + PythonMethods.TRACE_EDGES, "Edges tracing code");
Task task = python.task(IMPORTS_FORMATED + PythonMethods.TRACE_EDGES);
Expand Down

0 comments on commit d6fc080

Please sign in to comment.