From 5fdb3fb49572e3d2ef260d158ac9e045e3ceae49 Mon Sep 17 00:00:00 2001 From: Onuralp SEZER Date: Fri, 8 Nov 2024 09:40:30 +0300 Subject: [PATCH 1/2] =?UTF-8?q?feat(detections):=20=E2=9C=A8=20paligemma?= =?UTF-8?q?=20segmentation=20support=20added?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Onuralp SEZER --- supervision/detection/core.py | 5 ++- supervision/detection/lmm.py | 77 ++++++++++++++++++++++++++++------- 2 files changed, 65 insertions(+), 17 deletions(-) diff --git a/supervision/detection/core.py b/supervision/detection/core.py index 32753a30a..512fd5fbf 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -840,9 +840,10 @@ def from_lmm( if lmm == LMM.PALIGEMMA: assert isinstance(result, str) - xyxy, class_id, class_name = from_paligemma(result, **kwargs) + xyxy, class_id, class_name, mask = from_paligemma(result, **kwargs) data = {CLASS_NAME_DATA_FIELD: class_name} - return cls(xyxy=xyxy, class_id=class_id, data=data) + mask = mask if mask is not None else None + return cls(xyxy=xyxy, class_id=class_id, mask=mask, data=data) if lmm == LMM.FLORENCE_2: assert isinstance(result, dict) diff --git a/supervision/detection/lmm.py b/supervision/detection/lmm.py index 7879902f3..38ccaf16a 100644 --- a/supervision/detection/lmm.py +++ b/supervision/detection/lmm.py @@ -69,25 +69,72 @@ def validate_lmm_parameters( def from_paligemma( result: str, resolution_wh: Tuple[int, int], classes: Optional[List[str]] = None -) -> Tuple[np.ndarray, Optional[np.ndarray], np.ndarray]: +) -> Tuple[np.ndarray, Optional[np.ndarray], np.ndarray, Optional[np.ndarray]]: + """ + Parse results from Paligemma model which can contain object detection and segmentation. + + Args: + result (str): Model output string containing loc and optional seg tokens + resolution_wh (Tuple[int, int]): Target resolution (width, height) + classes (Optional[List[str]]): List of class names to filter results + + Returns: + xyxy (np.ndarray): Bounding box coordinates + class_id (Optional[np.ndarray]): Class IDs if classes provided + class_name (np.ndarray): Class names + mask (Optional[np.ndarray]): Segmentation masks if available + """ # noqa: E501 w, h = resolution_wh - pattern = re.compile( - r"(?) ([\w\s\-]+)" - ) - matches = pattern.findall(result) - matches = np.array(matches) if matches else np.empty((0, 5)) - xyxy, class_name = matches[:, [1, 0, 3, 2]], matches[:, 4] - xyxy = xyxy.astype(int) / 1024 * np.array([w, h, w, h]) - class_name = np.char.strip(class_name.astype(str)) - class_id = None + segmentation_pattern = re.compile( + r"\s*" + + "".join(r"" for _ in range(16)) + + r"\s+([\w\s\-]+)" + ) - if classes is not None: - mask = np.array([name in classes for name in class_name]).astype(bool) - xyxy, class_name = xyxy[mask], class_name[mask] - class_id = np.array([classes.index(name) for name in class_name]) + detection_pattern = re.compile( + r"(?) ([\w\s\-]+)" + ) - return xyxy, class_id, class_name + segmentation_matches = segmentation_pattern.findall(result) + if segmentation_matches: + matches = np.array(segmentation_matches) + xyxy = matches[:, [1, 0, 3, 2]].astype(int) / 1024 * np.array([w, h, w, h]) + class_name = np.char.strip(matches[:, -1].astype(str)) + seg_tokens = matches[:, 4:-1].astype(int) + masks = [np.zeros((h, w), dtype=bool) for tokens in seg_tokens] + masks = np.array(masks) + + class_id = None + if classes is not None: + mask = np.array([name in classes for name in class_name]).astype(bool) + xyxy = xyxy[mask] + class_name = class_name[mask] + masks = masks[mask] + class_id = np.array([classes.index(name) for name in class_name]) + + return xyxy, class_id, class_name, masks + + detection_matches = detection_pattern.findall(result) + if detection_matches: + matches = np.array(detection_matches) + xyxy = matches[:, [1, 0, 3, 2]].astype(int) / 1024 * np.array([w, h, w, h]) + class_name = np.char.strip(matches[:, 4].astype(str)) + + class_id = None + if classes is not None: + mask = np.array([name in classes for name in class_name]).astype(bool) + xyxy, class_name = xyxy[mask], class_name[mask] + class_id = np.array([classes.index(name) for name in class_name]) + + return xyxy, class_id, class_name, None + + return ( + np.empty((0, 4), dtype=float), + None, + np.array([], dtype=str), + None + ) def from_florence_2( From 210c3e6417b580ee47fbb54f22e3cb8eae75c9bf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 8 Nov 2024 06:59:36 +0000 Subject: [PATCH 2/2] =?UTF-8?q?fix(pre=5Fcommit):=20=F0=9F=8E=A8=20auto=20?= =?UTF-8?q?format=20pre-commit=20hooks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- supervision/detection/lmm.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/supervision/detection/lmm.py b/supervision/detection/lmm.py index 38ccaf16a..18f0c6ec2 100644 --- a/supervision/detection/lmm.py +++ b/supervision/detection/lmm.py @@ -129,12 +129,7 @@ def from_paligemma( return xyxy, class_id, class_name, None - return ( - np.empty((0, 4), dtype=float), - None, - np.array([], dtype=str), - None - ) + return (np.empty((0, 4), dtype=float), None, np.array([], dtype=str), None) def from_florence_2(