Skip to content

Commit

Permalink
correct code for sam2
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Aug 7, 2024
1 parent 06ddabb commit fc7258a
Showing 1 changed file with 8 additions and 19 deletions.
27 changes: 8 additions & 19 deletions src/main/java/ai/nets/samj/models/Sam2.java
Original file line number Diff line number Diff line change
Expand Up @@ -78,25 +78,14 @@ public class Sam2 extends AbstractSamJ {
+ "import torch" + System.lineSeparator()
+ "import sys" + System.lineSeparator()
+ "import os" + System.lineSeparator()
+ "os.chdir(r'%s')" + System.lineSeparator()
+ "from multiprocessing import shared_memory" + System.lineSeparator()
+ "task.update('import sam')" + System.lineSeparator()
+ "from efficientvit.models.efficientvit import EfficientViTSam, %s" + System.lineSeparator()
+ "from efficientvit.models.efficientvit.sam import EfficientViTSamPredictor" + System.lineSeparator()
+ "from sam2.build_sam import build_sam2" + System.lineSeparator()
+ "from sam2.sam2_image_predictor import SAM2ImagePredictor" + System.lineSeparator()
+ "from sam2.utils.misc import variant_to_config_mapping" + System.lineSeparator()
+ "task.update('imported')" + System.lineSeparator()
+ "" + System.lineSeparator()
+ "model = %s().cpu().eval()" + System.lineSeparator()
+ "eps = 1e-6" + System.lineSeparator()
+ "for m in model.modules():" + System.lineSeparator()
+ " if isinstance(m, (torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.modules.batchnorm._BatchNorm)):" + System.lineSeparator()
+ " if eps is not None:" + System.lineSeparator()
+ " m.eps = eps" + System.lineSeparator()
+ "f_name = os.path.realpath(os.path.expanduser(r'%s'))" + System.lineSeparator()
+ "weight = torch.load(f_name, map_location='cpu')" + System.lineSeparator()
+ "if \"state_dict\" in weight:" + System.lineSeparator()
+ " weight = weight[\"state_dict\"]" + System.lineSeparator()
+ "model.load_state_dict(weight)" + System.lineSeparator()
+ "predictor = EfficientViTSamPredictor(model)" + System.lineSeparator()
+ "model = build_sam2(variant_to_config_mapping['%s'],'%s')" + System.lineSeparator()
+ "predictor = SAM2ImagePredictor(model)"
+ "task.update('created predictor')" + System.lineSeparator()
+ "encodings_map = {}" + System.lineSeparator()
+ "globals()['encodings_map'] = encodings_map" + System.lineSeparator()
Expand Down Expand Up @@ -158,9 +147,9 @@ private Sam2(SamEnvManagerAbstract manager, String type,
};
python = env.python();
python.debug(debugPrinter::printText);
IMPORTS_FORMATED = String.format(IMPORTS,
manager.getModelEnv() + File.separator + Sam2EnvManager.SAM2_ENV_NAME,
type, type, manager.getModelWeigthsName());
IMPORTS_FORMATED = String.format(IMPORTS, type,
manager.getModelEnv() + File.separator + Sam2EnvManager.SAM2_ENV_NAME
+ File.separator + manager.getModelWeigthsName());

printScript(IMPORTS_FORMATED + PythonMethods.TRACE_EDGES, "Edges tracing code");
Task task = python.task(IMPORTS_FORMATED + PythonMethods.TRACE_EDGES);
Expand Down

0 comments on commit fc7258a

Please sign in to comment.