Skip to content

Commit

Permalink
improve mask creation filling holes
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Oct 29, 2024
1 parent a6db425 commit 6d49168
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/main/java/ai/nets/samj/models/EfficientSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ public class EfficientSamJ extends AbstractSamJ {
+ "from skimage import measure" + System.lineSeparator()
+ "measure.label(np.ones((10, 10)), connectivity=1)" + System.lineSeparator()
+ "import torch" + System.lineSeparator()
+ "from scipy.ndimage import binary_fill_holes" + System.lineSeparator()
+ "import sys" + System.lineSeparator()
+ "sys.path.append(r'%s')" + System.lineSeparator()
+ "from multiprocessing import shared_memory" + System.lineSeparator()
Expand All @@ -79,6 +80,7 @@ public class EfficientSamJ extends AbstractSamJ {
+ "globals()['measure'] = measure" + System.lineSeparator()
+ "globals()['np'] = np" + System.lineSeparator()
+ "globals()['torch'] = torch" + System.lineSeparator()
+ "globals()['binary_fill_holes'] = binary_fill_holes" + System.lineSeparator()
+ "globals()['predictor'] = predictor" + System.lineSeparator();

/**
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/ai/nets/samj/models/EfficientViTSamJ.java
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ public class EfficientViTSamJ extends AbstractSamJ {
+ "from skimage import measure" + System.lineSeparator()
+ "measure.label(np.ones((10, 10)), connectivity=1)" + System.lineSeparator()
+ "import torch" + System.lineSeparator()
+ "from scipy.ndimage import binary_fill_holes" + System.lineSeparator()
+ "import sys" + System.lineSeparator()
+ "import os" + System.lineSeparator()
+ "os.chdir(r'%s')" + System.lineSeparator()
Expand Down Expand Up @@ -107,6 +108,7 @@ public class EfficientViTSamJ extends AbstractSamJ {
+ "globals()['measure'] = measure" + System.lineSeparator()
+ "globals()['np'] = np" + System.lineSeparator()
+ "globals()['torch'] = torch" + System.lineSeparator()
+ "globals()['binary_fill_holes'] = binary_fill_holes" + System.lineSeparator()
+ "globals()['predictor'] = predictor" + System.lineSeparator();
/**
* String containing the Python imports code after it has been formatted with the correct
Expand Down
3 changes: 1 addition & 2 deletions src/main/java/ai/nets/samj/models/PythonMethods.java
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,7 @@ public class PythonMethods {
+ " for obj in labels:" + System.lineSeparator()
+ " if obj.num_pixels >= at_least_of_this_size:" + System.lineSeparator()
+ " x_coords,y_coords = trace_contour(obj.image, obj.num_pixels, obj.bbox[1],obj.bbox[0])" + System.lineSeparator()
+ " rle = encode_rle(obj.image * 1)" + System.lineSeparator()
+ " print(np.array(rle)[1::2])" + System.lineSeparator()
+ " rle = encode_rle(binary_fill_holes(obj.image))" + System.lineSeparator()
+ " bbox_w = obj.bbox[3] - obj.bbox[1]" + System.lineSeparator()
+ " for i in range(0, len(rle), 2):" + System.lineSeparator()
+ " rle[i] = sam_result.shape[1] * (obj.bbox[0] + rle[i] // bbox_w) + obj.bbox[1] + rle[i] % bbox_w" + System.lineSeparator()
Expand Down
2 changes: 2 additions & 0 deletions src/main/java/ai/nets/samj/models/Sam2.java
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ public class Sam2 extends AbstractSamJ {
+ "from skimage import measure" + System.lineSeparator()
+ "measure.label(np.ones((10, 10)), connectivity=1)" + System.lineSeparator()
+ "import torch" + System.lineSeparator()
+ "from scipy.ndimage import binary_fill_holes" + System.lineSeparator()
+ "import sys" + System.lineSeparator()
+ "import os" + System.lineSeparator()
+ "from multiprocessing import shared_memory" + System.lineSeparator()
Expand All @@ -92,6 +93,7 @@ public class Sam2 extends AbstractSamJ {
+ "globals()['measure'] = measure" + System.lineSeparator()
+ "globals()['np'] = np" + System.lineSeparator()
+ "globals()['torch'] = torch" + System.lineSeparator()
+ "globals()['binary_fill_holes'] = binary_fill_holes" + System.lineSeparator()
+ "globals()['predictor'] = predictor" + System.lineSeparator();
/**
* String containing the Python imports code after it has been formated with the correct
Expand Down

0 comments on commit 6d49168

Please sign in to comment.