Skip to content

Commit

Permalink
[Visualization utils] Add visualization utils for plotting images(pla…
Browse files Browse the repository at this point in the history
…in, with bounding boxes and segmentation masks) (#20401)

* api gen

* add plot image gallery function

* add `plot_bounding_box_gallery`

* correct label key

* add segmentation mask draw and plot functions

* few arg corrections and docstrings

* nit

* add missing args for plotting segmentation masks; use cols for each mask to make aspect ratio of each subplot correct

* add missing argument for color
  • Loading branch information
sineeli authored Oct 24, 2024
1 parent d4bb8e3 commit 56eaab3
Show file tree
Hide file tree
Showing 12 changed files with 777 additions and 0 deletions.
1 change: 1 addition & 0 deletions keras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from keras.api import tree
from keras.api import utils
from keras.api import version
from keras.api import visualization

# END DO NOT EDIT.

Expand Down
1 change: 1 addition & 0 deletions keras/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from keras.api import saving
from keras.api import tree
from keras.api import utils
from keras.api import visualization
from keras.src.backend import Variable
from keras.src.backend import device
from keras.src.backend import name_scope
Expand Down
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from keras.api import regularizers
from keras.api import tree
from keras.api import utils
from keras.api import visualization
from keras.api._tf_keras.keras import backend
from keras.api._tf_keras.keras import layers
from keras.api._tf_keras.keras import losses
Expand Down
17 changes: 17 additions & 0 deletions keras/api/_tf_keras/keras/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""DO NOT EDIT.
This file was autogenerated. Do not edit it by hand,
since your modifications would be overwritten.
"""

from keras.src.visualization.draw_bounding_boxes import draw_bounding_boxes
from keras.src.visualization.draw_segmentation_masks import (
draw_segmentation_masks,
)
from keras.src.visualization.plot_bounding_box_gallery import (
plot_bounding_box_gallery,
)
from keras.src.visualization.plot_image_gallery import plot_image_gallery
from keras.src.visualization.plot_segmentation_mask_gallery import (
plot_segmentation_mask_gallery,
)
17 changes: 17 additions & 0 deletions keras/api/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""DO NOT EDIT.
This file was autogenerated. Do not edit it by hand,
since your modifications would be overwritten.
"""

from keras.src.visualization.draw_bounding_boxes import draw_bounding_boxes
from keras.src.visualization.draw_segmentation_masks import (
draw_segmentation_masks,
)
from keras.src.visualization.plot_bounding_box_gallery import (
plot_bounding_box_gallery,
)
from keras.src.visualization.plot_image_gallery import plot_image_gallery
from keras.src.visualization.plot_segmentation_mask_gallery import (
plot_segmentation_mask_gallery,
)
1 change: 1 addition & 0 deletions keras/src/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from keras.src import optimizers
from keras.src import regularizers
from keras.src import utils
from keras.src import visualization
from keras.src.backend import KerasTensor
from keras.src.layers import Input
from keras.src.layers import Layer
Expand Down
2 changes: 2 additions & 0 deletions keras/src/visualization/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from keras.src.visualization import draw_bounding_boxes
from keras.src.visualization import plot_image_gallery
177 changes: 177 additions & 0 deletions keras/src/visualization/draw_bounding_boxes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
import numpy as np

from keras.src import backend
from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import ( # noqa: E501
convert_format,
)

try:
import cv2
except ImportError:
cv2 = None


@keras_export("keras.visualization.draw_bounding_boxes")
def draw_bounding_boxes(
    images,
    bounding_boxes,
    bounding_box_format,
    class_mapping=None,
    color=(128, 128, 128),
    line_thickness=2,
    text_thickness=1,
    font_scale=1.0,
    data_format=None,
):
    """Draws bounding boxes on images.

    This function draws bounding boxes on a batch of images. It supports
    different bounding box formats and can optionally display class labels
    and confidences. Drawing happens on a `uint8` copy of `images`; the input
    itself is not modified.

    Args:
        images: A batch of images as a 4D tensor or NumPy array. Shape should
            be `(batch_size, height, width, channels)` for `"channels_last"`
            or `(batch_size, channels, height, width)` for
            `"channels_first"`.
        bounding_boxes: A dictionary containing bounding box data. Should have
            the following keys:
            - `boxes`: A tensor or array of shape `(batch_size, num_boxes, 4)`
                containing the bounding box coordinates in the specified
                format.
            - `labels`: A tensor or array of shape `(batch_size, num_boxes)`
                containing the class labels for each bounding box. Boxes
                whose label is `-1` are skipped.
            - `confidences` (Optional): A tensor or array of shape
                `(batch_size, num_boxes)` containing the confidence scores for
                each bounding box.
        bounding_box_format: A string specifying the format of the bounding
            boxes. The boxes are converted to `"xyxy"` before drawing.
        class_mapping: A dictionary mapping class IDs (integers) to class
            labels (strings). A text label is drawn only for boxes whose
            class ID is a key of this mapping. Defaults to `None` (no labels
            displayed).
        color: A tuple or list representing the RGB color of the bounding
            boxes and of the label text. For example, `(255, 0, 0)` for red.
            Defaults to `(128, 128, 128)`.
        line_thickness: An integer specifying the thickness of the bounding
            box lines. Defaults to `2`.
        text_thickness: An integer specifying the thickness of the text
            labels. A falsy value (`None` or `0`) falls back to
            `line_thickness`. Defaults to `1`.
        font_scale: A float specifying the scale of the font used for text
            labels. Defaults to `1.0`.
        data_format: A string, either `"channels_last"` or `"channels_first"`,
            specifying the order of dimensions in the input images. Defaults
            to the `image_data_format` value found in your Keras config file
            at `~/.keras/keras.json`. If you never set it, then it will be
            "channels_last".

    Returns:
        A `uint8` NumPy array of the annotated images with the bounding boxes
        drawn. The array has the same shape and layout as the input `images`.

    Raises:
        ValueError: If `images` is not a 4D tensor/array, or if
            `bounding_boxes` does not contain `"boxes"` and `"labels"` keys.
        TypeError: If `bounding_boxes` is not a dictionary.
        ImportError: If `cv2` (OpenCV) is not installed.
    """

    if cv2 is None:
        raise ImportError(
            "The `draw_bounding_boxes` function requires the `cv2` package "
            " (OpenCV). Please install it with `pip install opencv-python`."
        )

    class_mapping = class_mapping or {}
    # A falsy `text_thickness` falls back to the box line thickness.
    text_thickness = text_thickness or line_thickness
    data_format = data_format or backend.image_data_format()
    images_shape = ops.shape(images)
    if len(images_shape) != 4:
        raise ValueError(
            "`images` must be batched 4D tensor. "
            f"Received: images.shape={images_shape}"
        )
    if not isinstance(bounding_boxes, dict):
        raise TypeError(
            "`bounding_boxes` should be a dict. "
            f"Received: bounding_boxes={bounding_boxes} of type "
            f"{type(bounding_boxes)}"
        )
    if "boxes" not in bounding_boxes or "labels" not in bounding_boxes:
        raise ValueError(
            "`bounding_boxes` should be a dict containing 'boxes' and "
            f"'labels' keys. Received: bounding_boxes={bounding_boxes}"
        )

    # Anything other than "channels_last" is treated as channels-first.
    channels_first = data_format != "channels_last"
    if channels_first:
        h_axis, w_axis = -2, -1
    else:
        h_axis, w_axis = -3, -2
    height = images_shape[h_axis]
    width = images_shape[w_axis]

    # Shallow-copy so the caller's dict is not rebound by the conversion.
    bounding_boxes = bounding_boxes.copy()
    bounding_boxes = convert_format(
        bounding_boxes, bounding_box_format, "xyxy", height, width
    )

    # To numpy array. `astype` returns a fresh array, so the caller's images
    # are never drawn on.
    images = ops.convert_to_numpy(images).astype("uint8")
    if channels_first:
        # OpenCV expects each image as (height, width, channels).
        images = np.transpose(images, (0, 2, 3, 1))
    # OpenCV draws in place only on C-contiguous buffers; this guarantees
    # that every `images[i]` view below is one.
    images = np.ascontiguousarray(images)
    boxes = ops.convert_to_numpy(bounding_boxes["boxes"])
    labels = ops.convert_to_numpy(bounding_boxes["labels"])
    if "confidences" in bounding_boxes:
        confidences = ops.convert_to_numpy(bounding_boxes["confidences"])
    else:
        confidences = None

    for i in range(images.shape[0]):
        # View into `images`: everything drawn here lands in the batch array.
        _image = images[i]
        _box = boxes[i]
        _class = labels[i]
        for box_i in range(_box.shape[0]):
            x1, y1, x2, y2 = _box[box_i].astype("int32")
            c = _class[box_i].astype("int32")
            # Label -1 marks a box that must not be drawn.
            if c == -1:
                continue
            # OpenCV wants plain Python ints, not NumPy scalars.
            x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
            c = int(c)
            # Draw bounding box
            cv2.rectangle(_image, (x1, y1), (x2, y2), color, line_thickness)

            if c in class_mapping:
                label = class_mapping[c]
                if confidences is not None:
                    conf = confidences[i][box_i]
                    label = f"{label} | {conf:.2f}"

                font_x1, font_y1 = _find_text_location(
                    x1, y1, font_scale, text_thickness
                )
                cv2.putText(
                    img=_image,
                    text=label,
                    org=(font_x1, font_y1),
                    fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                    fontScale=font_scale,
                    color=color,
                    thickness=text_thickness,
                )

    if channels_first:
        # Restore the caller's layout so the output shape matches the input.
        images = np.ascontiguousarray(np.transpose(images, (0, 3, 1, 2)))
    # Returning the batch array directly also handles an empty batch, which
    # `np.stack` on an empty list would reject.
    return images


def _find_text_location(x, y, font_scale, thickness):
font_height = int(font_scale * 12)
target_y = y - 8
if target_y - (2 * font_height) > 0:
return x, y - 8

line_offset = thickness
static_offset = 3

return (
x + static_offset,
y + (2 * font_height) + line_offset + static_offset,
)
109 changes: 109 additions & 0 deletions keras/src/visualization/draw_segmentation_masks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import numpy as np

from keras.src import backend
from keras.src import ops
from keras.src.api_export import keras_export


@keras_export("keras.visualization.draw_segmentation_masks")
def draw_segmentation_masks(
    images,
    segmentation_masks,
    num_classes=None,
    color_mapping=None,
    alpha=0.8,
    blend=True,
    ignore_index=-1,
    data_format=None,
):
    """Draws segmentation masks on images.

    The function overlays segmentation masks on the input images.
    The masks are blended with the images using the specified alpha value.

    Args:
        images: A batch of images as a 4D tensor or NumPy array. Shape
            should be `(batch_size, height, width, channels)` for
            `"channels_last"` or `(batch_size, channels, height, width)` for
            `"channels_first"`.
        segmentation_masks: A batch of integer segmentation masks as a 3D or
            4D tensor or NumPy array. Shape should be
            `(batch_size, height, width)`, or
            `(batch_size, height, width, 1)` for `"channels_last"` /
            `(batch_size, 1, height, width)` for `"channels_first"`. Each
            value is a class index; indices `0` to `num_classes - 1` are
            painted with their class color, any other value is left
            unpainted.
        num_classes: The number of segmentation classes. If `None`, it is
            inferred as the maximum value in `segmentation_masks` plus one,
            so that the highest class present is painted as well.
        color_mapping: A mapping (or sequence) from class index to RGB color.
            It is indexed with every integer in `range(num_classes)`, so it
            must provide a color for each of them. If `None`, a default color
            palette is generated.
        alpha: The opacity of the segmentation masks. Must be in the range
            `[0, 1]`.
        blend: Whether to blend the masks with the input image using the
            `alpha` value. If `False`, the masks are drawn directly on the
            images without blending. Defaults to `True`.
        ignore_index: The class index to ignore. Mask pixels with this value
            keep the original image value. Defaults to -1.
        data_format: Image data format, either `"channels_last"` or
            `"channels_first"`. Defaults to the `image_data_format` value
            found in your Keras config file at `~/.keras/keras.json`. If you
            never set it, then it will be `"channels_last"`.

    Returns:
        A `uint8` NumPy array of the images with the segmentation masks
        overlaid. The result is always laid out as
        `(batch_size, height, width, channels)`, including when the inputs
        were `"channels_first"`.

    Raises:
        ValueError: If the input `images` is not a 4D tensor or NumPy array,
            or if `segmentation_masks` is neither 3D nor 4D.
        TypeError: If the input `segmentation_masks` is not an integer type.
    """
    data_format = data_format or backend.image_data_format()
    images_shape = ops.shape(images)
    if len(images_shape) != 4:
        raise ValueError(
            "`images` must be batched 4D tensor. "
            f"Received: images.shape={images_shape}"
        )

    segmentation_masks = ops.convert_to_tensor(segmentation_masks)
    if not backend.is_int_dtype(segmentation_masks.dtype):
        dtype = backend.standardize_dtype(segmentation_masks.dtype)
        raise TypeError(
            "`segmentation_masks` must be in integer dtype. "
            f"Received: segmentation_masks.dtype={dtype}"
        )

    # Normalize the masks to (batch, height, width, 1).
    masks_shape = ops.shape(segmentation_masks)
    if len(masks_shape) == 3:
        # No channel axis to reorder; just append one.
        segmentation_masks = ops.expand_dims(segmentation_masks, axis=-1)
    elif len(masks_shape) != 4:
        raise ValueError(
            "`segmentation_masks` must be a batched 3D or 4D tensor. "
            f"Received: segmentation_masks.shape={masks_shape}"
        )
    elif data_format == "channels_first":
        segmentation_masks = ops.transpose(segmentation_masks, (0, 2, 3, 1))

    if data_format == "channels_first":
        images = ops.transpose(images, (0, 2, 3, 1))
    images = ops.convert_to_tensor(images, dtype="float32")

    # Infer num_classes. Classes are indexed from 0 below, so the count is
    # the largest index plus one.
    if num_classes is None:
        num_classes = (
            int(ops.convert_to_numpy(ops.max(segmentation_masks))) + 1
        )
    if color_mapping is None:
        colors = _generate_color_palette(num_classes)
    else:
        colors = [color_mapping[i] for i in range(num_classes)]

    # True wherever the overlay may replace the original pixel; kept with a
    # trailing axis of size 1 so it broadcasts over the image channels.
    valid_masks = ops.not_equal(segmentation_masks, ignore_index)

    # Replace class with color
    class_ids = ops.convert_to_numpy(segmentation_masks)[..., 0]
    # Copy so the painting never writes into memory shared with `images`.
    images_to_draw = ops.convert_to_numpy(images).copy()
    for class_index, color in enumerate(colors):
        color = np.array(color, dtype=images_to_draw.dtype)
        images_to_draw[class_ids == class_index] = color
    images_to_draw = ops.convert_to_tensor(images_to_draw)
    outputs = ops.cast(images_to_draw, dtype="float32")

    if blend:
        outputs = images * (1 - alpha) + outputs * alpha
    # Ignored pixels always fall back to the untouched image.
    outputs = ops.where(valid_masks, outputs, images)
    outputs = ops.cast(outputs, dtype="uint8")
    outputs = ops.convert_to_numpy(outputs)
    return outputs


def _generate_color_palette(num_classes: int):
palette = np.array([2**25 - 1, 2**15 - 1, 2**21 - 1])
return [((i * palette) % 255).tolist() for i in range(num_classes)]
Loading

0 comments on commit 56eaab3

Please sign in to comment.