Skip to content

Commit

Permalink
remove errors
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Aug 7, 2024
1 parent 4aebcea commit 06ddabb
Showing 1 changed file with 24 additions and 25 deletions.
49 changes: 24 additions & 25 deletions src/main/java/ai/nets/samj/models/Sam2.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,10 @@
package ai.nets.samj.models;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.stream.Collectors;

import ai.nets.samj.install.SamEnvManager;
import ai.nets.samj.install.Sam2EnvManager;
import ai.nets.samj.install.SamEnvManagerAbstract;

import java.io.File;
import java.io.IOException;
Expand Down Expand Up @@ -61,13 +60,12 @@ public class Sam2 extends AbstractSamJ {
/**
* Map that associates the key for each of the existing EfficientViTSAM models to its complete name
*/
private static final HashMap<String, String> MODELS_DICT = new HashMap<String, String>();
private static final List<String> MODELS_LIST = new ArrayList<String>();
static {
MODELS_DICT.put("l0", "efficientvit_sam_l0");
MODELS_DICT.put("l1", "efficientvit_sam_l1");
MODELS_DICT.put("l2", "efficientvit_sam_l2");
MODELS_DICT.put("xl0", "efficientvit_sam_xl0");
MODELS_DICT.put("xl1", "efficientvit_sam_xl1");
MODELS_LIST.add("tiny");
MODELS_LIST.add("small");
MODELS_LIST.add("base_plus");
MODELS_LIST.add("large");
}
/**
* All the Python imports and configurations needed to start using EfficientViTSAM.
Expand Down Expand Up @@ -124,7 +122,7 @@ public class Sam2 extends AbstractSamJ {
* @throws RuntimeException if there is any error running the Python code
* @throws InterruptedException if the process is interrupted
*/
private Sam2(SamEnvManager manager, String type) throws IOException, RuntimeException, InterruptedException {
private Sam2(SamEnvManagerAbstract manager, String type) throws IOException, RuntimeException, InterruptedException {
this(manager, type, (t) -> {}, false);
}

Expand All @@ -144,24 +142,25 @@ private Sam2(SamEnvManager manager, String type) throws IOException, RuntimeExce
* @throws InterruptedException if the process is interrupted
*
*/
private Sam2(SamEnvManager manager, String type,
private Sam2(SamEnvManagerAbstract manager, String type,
final DebugTextPrinter debugPrinter,
final boolean printPythonCode) throws IOException, RuntimeException, InterruptedException {

if (!MODELS_DICT.keySet().contains(type))
throw new IllegalArgumentException("The model type should be one of hte following: "
+ MODELS_DICT.keySet().stream().collect(Collectors.toList()));
if (type.equals("base")) type = "base_plus";
if (!MODELS_LIST.contains(type))
throw new IllegalArgumentException("The model type should be one of the following: "
+ MODELS_LIST);
this.debugPrinter = debugPrinter;
this.isDebugging = printPythonCode;

this.env = new Environment() {
@Override public String base() { return manager.getEfficientViTSamEnv(); }
@Override public String base() { return manager.getModelEnv(); }
};
python = env.python();
python.debug(debugPrinter::printText);
IMPORTS_FORMATED = String.format(IMPORTS,
manager.getEfficientViTSamEnv() + File.separator + SamEnvManager.EVITSAM_NAME,
MODELS_DICT.get(type), MODELS_DICT.get(type), manager.getEfficientViTSAMWeightsPath(type));
manager.getModelEnv() + File.separator + Sam2EnvManager.SAM2_ENV_NAME,
type, type, manager.getModelWeigthsName());

printScript(IMPORTS_FORMATED + PythonMethods.TRACE_EDGES, "Edges tracing code");
Task task = python.task(IMPORTS_FORMATED + PythonMethods.TRACE_EDGES);
Expand Down Expand Up @@ -199,7 +198,7 @@ else if (task.status == TaskStatus.CRASHED)
* @throws InterruptedException if the process is interrupted
*/
public static <T extends RealType<T> & NativeType<T>> Sam2
initializeSam(String modelType, SamEnvManager manager,
initializeSam(String modelType, SamEnvManagerAbstract manager,
RandomAccessibleInterval<T> image,
final DebugTextPrinter debugPrinter,
final boolean printPythonCode) throws IOException, RuntimeException, InterruptedException {
Expand Down Expand Up @@ -234,7 +233,7 @@ else if (task.status == TaskStatus.CRASHED)
* @throws InterruptedException if the process is interrupted
*/
public static <T extends RealType<T> & NativeType<T>> Sam2
initializeSam(String modelType, SamEnvManager manager, RandomAccessibleInterval<T> image)
initializeSam(String modelType, SamEnvManagerAbstract manager, RandomAccessibleInterval<T> image)
throws IOException, RuntimeException, InterruptedException {
Sam2 sam = null;
try{
Expand Down Expand Up @@ -272,11 +271,11 @@ else if (task.status == TaskStatus.CRASHED)
* @throws InterruptedException if the process is interrupted
*/
public static <T extends RealType<T> & NativeType<T>> Sam2
initializeSam(SamEnvManager manager,
initializeSam(SamEnvManagerAbstract manager,
RandomAccessibleInterval<T> image,
final DebugTextPrinter debugPrinter,
final boolean printPythonCode) throws IOException, RuntimeException, InterruptedException {
return initializeSam(SamEnvManager.DEFAULT_EVITSAM, manager, image, debugPrinter, printPythonCode);
return initializeSam(Sam2EnvManager.DEFAULT_SAM2, manager, image, debugPrinter, printPythonCode);
}

/**
Expand All @@ -299,8 +298,8 @@ else if (task.status == TaskStatus.CRASHED)
* @throws InterruptedException if the process is interrupted
*/
public static <T extends RealType<T> & NativeType<T>> Sam2
initializeSam(SamEnvManager manager, RandomAccessibleInterval<T> image) throws IOException, RuntimeException, InterruptedException {
return initializeSam(SamEnvManager.DEFAULT_EVITSAM, manager, image);
initializeSam(SamEnvManagerAbstract manager, RandomAccessibleInterval<T> image) throws IOException, RuntimeException, InterruptedException {
return initializeSam(Sam2EnvManager.DEFAULT_SAM2, manager, image);
}

@Override
Expand Down Expand Up @@ -462,7 +461,7 @@ private <T extends RealType<T> & NativeType<T>> void checkImageIsFine(RandomAcce
* @return the list of EfficientViTSAM models that are supported
*/
public static List<String> getListOfSupportedVariants(){
return MODELS_DICT.keySet().stream().collect(Collectors.toList());
return MODELS_LIST;
}

private <T extends RealType<T> & NativeType<T>>
Expand Down Expand Up @@ -496,7 +495,7 @@ void adaptImageToModel(RandomAccessibleInterval<T> ogImg, RandomAccessibleInterv
public static void main(String[] args) throws IOException, RuntimeException, InterruptedException {
RandomAccessibleInterval<UnsignedByteType> img = ArrayImgs.unsignedBytes(new long[] {50, 50, 3});
img = Views.addDimension(img, 1, 2);
try (Sam2 sam = initializeSam(SamEnvManager.create(), img)) {
try (Sam2 sam = initializeSam(Sam2EnvManager.create(), img)) {
sam.processBox(new int[] {0, 5, 10, 26});
}
}
Expand Down

0 comments on commit 06ddabb

Please sign in to comment.