-
Notifications
You must be signed in to change notification settings - Fork 19.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Visualization utils] Add visualization utils for plotting images(pla…
…in, with bounding boxes and segmentation masks) (#20401) * api gen * add plot image gallery function * add `plot_ bounding_box_gallery` * correct label key * add segmentation mask draw and plot functions * few arg corrections and docstrings * nit * add missing args for plotting segmenation masks use cols for each mask to make aspect ratio of each subplot correct * add missing argument for color
- Loading branch information
Showing
12 changed files
with
777 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
"""DO NOT EDIT. | ||
This file was autogenerated. Do not edit it by hand, | ||
since your modifications would be overwritten. | ||
""" | ||
|
||
from keras.src.visualization.draw_bounding_boxes import draw_bounding_boxes | ||
from keras.src.visualization.draw_segmentation_masks import ( | ||
draw_segmentation_masks, | ||
) | ||
from keras.src.visualization.plot_bounding_box_gallery import ( | ||
plot_bounding_box_gallery, | ||
) | ||
from keras.src.visualization.plot_image_gallery import plot_image_gallery | ||
from keras.src.visualization.plot_segmentation_mask_gallery import ( | ||
plot_segmentation_mask_gallery, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
"""DO NOT EDIT. | ||
This file was autogenerated. Do not edit it by hand, | ||
since your modifications would be overwritten. | ||
""" | ||
|
||
from keras.src.visualization.draw_bounding_boxes import draw_bounding_boxes | ||
from keras.src.visualization.draw_segmentation_masks import ( | ||
draw_segmentation_masks, | ||
) | ||
from keras.src.visualization.plot_bounding_box_gallery import ( | ||
plot_bounding_box_gallery, | ||
) | ||
from keras.src.visualization.plot_image_gallery import plot_image_gallery | ||
from keras.src.visualization.plot_segmentation_mask_gallery import ( | ||
plot_segmentation_mask_gallery, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from keras.src.visualization import draw_bounding_boxes | ||
from keras.src.visualization import plot_image_gallery |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,177 @@ | ||
import numpy as np | ||
|
||
from keras.src import backend | ||
from keras.src import ops | ||
from keras.src.api_export import keras_export | ||
from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501 | ||
convert_format, | ||
) | ||
|
||
try: | ||
import cv2 | ||
except ImportError: | ||
cv2 = None | ||
|
||
|
||
@keras_export("keras.visualization.draw_bounding_boxes") | ||
def draw_bounding_boxes( | ||
images, | ||
bounding_boxes, | ||
bounding_box_format, | ||
class_mapping=None, | ||
color=(128, 128, 128), | ||
line_thickness=2, | ||
text_thickness=1, | ||
font_scale=1.0, | ||
data_format=None, | ||
): | ||
"""Draws bounding boxes on images. | ||
This function draws bounding boxes on a batch of images. It supports | ||
different bounding box formats and can optionally display class labels | ||
and confidences. | ||
Args: | ||
images: A batch of images as a 4D tensor or NumPy array. Shape should be | ||
`(batch_size, height, width, channels)`. | ||
bounding_boxes: A dictionary containing bounding box data. Should have | ||
the following keys: | ||
- `boxes`: A tensor or array of shape `(batch_size, num_boxes, 4)` | ||
containing the bounding box coordinates in the specified format. | ||
- `labels`: A tensor or array of shape `(batch_size, num_boxes)` | ||
containing the class labels for each bounding box. | ||
- `confidences` (Optional): A tensor or array of shape | ||
`(batch_size, num_boxes)` containing the confidence scores for | ||
each bounding box. | ||
bounding_box_format: A string specifying the format of the bounding | ||
boxes. Refer [keras-io](TODO) | ||
class_mapping: A dictionary mapping class IDs (integers) to class labels | ||
(strings). Used to display class labels next to the bounding boxes. | ||
Defaults to None (no labels displayed). | ||
color: A tuple or list representing the RGB color of the bounding boxes. | ||
For example, `(255, 0, 0)` for red. Defaults to `(128, 128, 128)`. | ||
line_thickness: An integer specifying the thickness of the bounding box | ||
lines. Defaults to `2`. | ||
text_thickness: An integer specifying the thickness of the text labels. | ||
Defaults to `1`. | ||
font_scale: A float specifying the scale of the font used for text | ||
labels. Defaults to `1.0`. | ||
data_format: A string, either `"channels_last"` or `"channels_first"`, | ||
specifying the order of dimensions in the input images. Defaults to | ||
the `image_data_format` value found in your Keras config file at | ||
`~/.keras/keras.json`. If you never set it, then it will be | ||
"channels_last". | ||
Returns: | ||
A NumPy array of the annotated images with the bounding boxes drawn. | ||
The array will have the same shape as the input `images`. | ||
Raises: | ||
ValueError: If `images` is not a 4D tensor/array, if `bounding_boxes` is | ||
not a dictionary, or if `bounding_boxes` does not contain `"boxes"` | ||
and `"labels"` keys. | ||
TypeError: If `bounding_boxes` is not a dictionary. | ||
ImportError: If `cv2` (OpenCV) is not installed. | ||
""" | ||
|
||
if cv2 is None: | ||
raise ImportError( | ||
"The `draw_bounding_boxes` function requires the `cv2` package " | ||
" (OpenCV). Please install it with `pip install opencv-python`." | ||
) | ||
|
||
class_mapping = class_mapping or {} | ||
text_thickness = ( | ||
text_thickness or line_thickness | ||
) # Default text_thickness if not provided. | ||
data_format = data_format or backend.image_data_format() | ||
images_shape = ops.shape(images) | ||
if len(images_shape) != 4: | ||
raise ValueError( | ||
"`images` must be batched 4D tensor. " | ||
f"Received: images.shape={images_shape}" | ||
) | ||
if not isinstance(bounding_boxes, dict): | ||
raise TypeError( | ||
"`bounding_boxes` should be a dict. " | ||
f"Received: bounding_boxes={bounding_boxes} of type " | ||
f"{type(bounding_boxes)}" | ||
) | ||
if "boxes" not in bounding_boxes or "labels" not in bounding_boxes: | ||
raise ValueError( | ||
"`bounding_boxes` should be a dict containing 'boxes' and " | ||
f"'labels' keys. Received: bounding_boxes={bounding_boxes}" | ||
) | ||
if data_format == "channels_last": | ||
h_axis = -3 | ||
w_axis = -2 | ||
else: | ||
h_axis = -2 | ||
w_axis = -1 | ||
height = images_shape[h_axis] | ||
width = images_shape[w_axis] | ||
bounding_boxes = bounding_boxes.copy() | ||
bounding_boxes = convert_format( | ||
bounding_boxes, bounding_box_format, "xyxy", height, width | ||
) | ||
|
||
# To numpy array | ||
images = ops.convert_to_numpy(images).astype("uint8") | ||
boxes = ops.convert_to_numpy(bounding_boxes["boxes"]) | ||
labels = ops.convert_to_numpy(bounding_boxes["labels"]) | ||
if "confidences" in bounding_boxes: | ||
confidences = ops.convert_to_numpy(bounding_boxes["confidences"]) | ||
else: | ||
confidences = None | ||
|
||
result = [] | ||
batch_size = images.shape[0] | ||
for i in range(batch_size): | ||
_image = images[i] | ||
_box = boxes[i] | ||
_class = labels[i] | ||
for box_i in range(_box.shape[0]): | ||
x1, y1, x2, y2 = _box[box_i].astype("int32") | ||
c = _class[box_i].astype("int32") | ||
if c == -1: | ||
continue | ||
x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) | ||
c = int(c) | ||
# Draw bounding box | ||
cv2.rectangle(_image, (x1, y1), (x2, y2), color, line_thickness) | ||
|
||
if c in class_mapping: | ||
label = class_mapping[c] | ||
if confidences is not None: | ||
conf = confidences[i][box_i] | ||
label = f"{label} | {conf:.2f}" | ||
|
||
font_x1, font_y1 = _find_text_location( | ||
x1, y1, font_scale, text_thickness | ||
) | ||
cv2.putText( | ||
img=_image, | ||
text=label, | ||
org=(font_x1, font_y1), | ||
fontFace=cv2.FONT_HERSHEY_SIMPLEX, | ||
fontScale=font_scale, | ||
color=color, | ||
thickness=text_thickness, | ||
) | ||
result.append(_image) | ||
return np.stack(result, axis=0) | ||
|
||
|
||
def _find_text_location(x, y, font_scale, thickness): | ||
font_height = int(font_scale * 12) | ||
target_y = y - 8 | ||
if target_y - (2 * font_height) > 0: | ||
return x, y - 8 | ||
|
||
line_offset = thickness | ||
static_offset = 3 | ||
|
||
return ( | ||
x + static_offset, | ||
y + (2 * font_height) + line_offset + static_offset, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import numpy as np | ||
|
||
from keras.src import backend | ||
from keras.src import ops | ||
from keras.src.api_export import keras_export | ||
|
||
|
||
@keras_export("keras.visualization.draw_segmentation_masks") | ||
def draw_segmentation_masks( | ||
images, | ||
segmentation_masks, | ||
num_classes=None, | ||
color_mapping=None, | ||
alpha=0.8, | ||
blend=True, | ||
ignore_index=-1, | ||
data_format=None, | ||
): | ||
"""Draws segmentation masks on images. | ||
The function overlays segmentation masks on the input images. | ||
The masks are blended with the images using the specified alpha value. | ||
Args: | ||
images: A batch of images as a 4D tensor or NumPy array. Shape | ||
should be (batch_size, height, width, channels). | ||
segmentation_masks: A batch of segmentation masks as a 3D or 4D tensor | ||
or NumPy array. Shape should be (batch_size, height, width) or | ||
(batch_size, height, width, 1). The values represent class indices | ||
starting from 1 up to `num_classes`. Class 0 is reserved for | ||
the background and will be ignored if `ignore_index` is not 0. | ||
num_classes: The number of segmentation classes. If `None`, it is | ||
inferred from the maximum value in `segmentation_masks`. | ||
color_mapping: A dictionary mapping class indices to RGB colors. | ||
If `None`, a default color palette is generated. The keys should be | ||
integers starting from 1 up to `num_classes`. | ||
alpha: The opacity of the segmentation masks. Must be in the range | ||
`[0, 1]`. | ||
blend: Whether to blend the masks with the input image using the | ||
`alpha` value. If `False`, the masks are drawn directly on the | ||
images without blending. Defaults to `True`. | ||
ignore_index: The class index to ignore. Mask pixels with this value | ||
will not be drawn. Defaults to -1. | ||
data_format: Image data format, either `"channels_last"` or | ||
`"channels_first"`. Defaults to the `image_data_format` value found | ||
in your Keras config file at `~/.keras/keras.json`. If you never | ||
set it, then it will be `"channels_last"`. | ||
Returns: | ||
A NumPy array of the images with the segmentation masks overlaid. | ||
Raises: | ||
ValueError: If the input `images` is not a 4D tensor or NumPy array. | ||
TypeError: If the input `segmentation_masks` is not an integer type. | ||
""" | ||
data_format = data_format or backend.image_data_format() | ||
images_shape = ops.shape(images) | ||
if len(images_shape) != 4: | ||
raise ValueError( | ||
"`images` must be batched 4D tensor. " | ||
f"Received: images.shape={images_shape}" | ||
) | ||
if data_format == "channels_first": | ||
images = ops.transpose(images, (0, 2, 3, 1)) | ||
segmentation_masks = ops.transpose(segmentation_masks, (0, 2, 3, 1)) | ||
images = ops.convert_to_tensor(images, dtype="float32") | ||
segmentation_masks = ops.convert_to_tensor(segmentation_masks) | ||
|
||
if not backend.is_int_dtype(segmentation_masks.dtype): | ||
dtype = backend.standardize_dtype(segmentation_masks.dtype) | ||
raise TypeError( | ||
"`segmentation_masks` must be in integer dtype. " | ||
f"Received: segmentation_masks.dtype={dtype}" | ||
) | ||
|
||
# Infer num_classes | ||
if num_classes is None: | ||
num_classes = int(ops.convert_to_numpy(ops.max(segmentation_masks))) | ||
if color_mapping is None: | ||
colors = _generate_color_palette(num_classes) | ||
else: | ||
colors = [color_mapping[i] for i in range(num_classes)] | ||
valid_masks = ops.not_equal(segmentation_masks, ignore_index) | ||
valid_masks = ops.squeeze(valid_masks, axis=-1) | ||
segmentation_masks = ops.one_hot(segmentation_masks, num_classes) | ||
segmentation_masks = segmentation_masks[..., 0, :] | ||
segmentation_masks = ops.convert_to_numpy(segmentation_masks) | ||
|
||
# Replace class with color | ||
masks = segmentation_masks | ||
masks = np.transpose(masks, axes=(3, 0, 1, 2)).astype("bool") | ||
images_to_draw = ops.convert_to_numpy(images).copy() | ||
for mask, color in zip(masks, colors): | ||
color = np.array(color, dtype=images_to_draw.dtype) | ||
images_to_draw[mask, ...] = color[None, :] | ||
images_to_draw = ops.convert_to_tensor(images_to_draw) | ||
outputs = ops.cast(images_to_draw, dtype="float32") | ||
|
||
if blend: | ||
outputs = images * (1 - alpha) + outputs * alpha | ||
outputs = ops.where(valid_masks[..., None], outputs, images) | ||
outputs = ops.cast(outputs, dtype="uint8") | ||
outputs = ops.convert_to_numpy(outputs) | ||
return outputs | ||
|
||
|
||
def _generate_color_palette(num_classes: int): | ||
palette = np.array([2**25 - 1, 2**15 - 1, 2**21 - 1]) | ||
return [((i * palette) % 255).tolist() for i in range(num_classes)] |
Oops, something went wrong.