Skip to content

Commit

Permalink
add RLE compression of masks in python
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosuc3m committed Oct 25, 2024
1 parent 363fe2d commit 80fb4bc
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions src/main/java/ai/nets/samj/models/PythonMethods.java
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,47 @@ public class PythonMethods {
+ "globals()['trace_contour'] = trace_contour" + System.lineSeparator()
+ "globals()['get_polygons_from_binary_mask'] = get_polygons_from_binary_mask" + System.lineSeparator();

/**
* String containing a Python method to encode binary masks into a compressed object using the
* Run-Length Encoding (RLE) algorithm
*/
protected static String RLE_METHOD = ""
+ "def encode_rle(mask: Union[np.ndarray, List[List[int]]]) -> List[int]:" + System.lineSeparator()
+ " \"\"\"" + System.lineSeparator()
+ " Encode a binary mask using Run-Length Encoding (RLE)." + System.lineSeparator()
+ " " + System.lineSeparator()
+ " Args:" + System.lineSeparator()
+ " mask: A 2D binary array (numpy array or list of lists) where 1 represents the object" + System.lineSeparator()
+ " and 0 represents the background" + System.lineSeparator()
+ " " + System.lineSeparator()
+ " Returns:" + System.lineSeparator()
+ " List[int]: RLE encoding in the format [start1, length1, start2, length2, ...]" + System.lineSeparator()
+ " where start positions are 0-based" + System.lineSeparator()
+ " \"\"\"" + System.lineSeparator()
+ " if isinstance(mask, list):" + System.lineSeparator()
+ " mask = np.array(mask)" + System.lineSeparator()
+ " " + System.lineSeparator()
+ " # Flatten the mask in row-major order" + System.lineSeparator()
+ " binary = mask.flatten()" + System.lineSeparator()
+ " " + System.lineSeparator()
+ " # Find positions where values change" + System.lineSeparator()
+ " transitions = np.where(binary[1:] != binary[:-1])[0] + 1" + System.lineSeparator()
+ " transitions = np.concatenate(([0], transitions, [len(binary)]))" + System.lineSeparator()
+ " " + System.lineSeparator()
+ " # Initialize result" + System.lineSeparator()
+ " rle = []" + System.lineSeparator()
+ " " + System.lineSeparator()
+ " # Process each run" + System.lineSeparator()
+ " for i in range(len(transitions) - 1):" + System.lineSeparator()
+ " start = transitions[i]" + System.lineSeparator()
+ " length = transitions[i + 1] - transitions[i]" + System.lineSeparator()
+ " " + System.lineSeparator()
+ " # Only encode runs of 1s" + System.lineSeparator()
+ " if binary[start] == 1:" + System.lineSeparator()
+ " rle.extend([start, length])" + System.lineSeparator()
+ " " + System.lineSeparator()
+ " return rle" + System.lineSeparator();


protected static String SAM_EVERYTHING = ""
+ "def calculate_pairs(masks):\n"
Expand Down

0 comments on commit 80fb4bc

Please sign in to comment.