Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

polygon_query() support for images #358

Merged
merged 3 commits into from
Sep 24, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 46 additions & 7 deletions src/spatialdata/_core/query/spatial_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,14 @@ def _(


def _polygon_query(
sdata: SpatialData, polygon: Polygon, target_coordinate_system: str, filter_table: bool, shapes: bool, points: bool
sdata: SpatialData,
polygon: Polygon,
target_coordinate_system: str,
filter_table: bool,
shapes: bool,
points: bool,
images: bool,
labels: bool,
) -> SpatialData:
from spatialdata._core.query._utils import circles_to_polygons
from spatialdata._core.query.relational_query import _filter_table_by_elements
Expand Down Expand Up @@ -669,11 +676,32 @@ def _polygon_query(
set_transformation(ddf, transformation, target_coordinate_system)
new_points[points_name] = ddf

if filter_table:
new_images = {}
if images:
for images_name, im in sdata.images.items():
min_x, min_y, max_x, max_y = polygon.bounds
cropped = bounding_box_query(
im,
min_coordinate=[min_x, min_y],
max_coordinate=[max_x, max_y],
axes=("x", "y"),
target_coordinate_system=target_coordinate_system,
)
new_images[images_name] = cropped
if labels:
for labels_name, l in sdata.labels.items():
_ = labels_name
_ = l
raise NotImplementedError(
"labels=True is not implemented yet. If you encounter this error please open an "
"issue and we will prioritize the implementation."
)

if filter_table and sdata.table is not None:
table = _filter_table_by_elements(sdata.table, {"shapes": new_shapes, "points": new_points})
else:
table = sdata.table
return SpatialData(shapes=new_shapes, points=new_points, table=table)
return SpatialData(shapes=new_shapes, points=new_points, images=new_images, table=table)


# this function is currently excluded from the API documentation. TODO: add it after the refactoring
Expand All @@ -684,6 +712,8 @@ def polygon_query(
filter_table: bool = True,
shapes: bool = True,
points: bool = True,
images: bool = True,
labels: bool = True,
) -> SpatialData:
"""
Query a spatial data object by a polygon, filtering shapes and points.
Expand Down Expand Up @@ -725,14 +755,21 @@ def polygon_query(
filter_table=filter_table,
shapes=shapes,
points=points,
images=images,
labels=labels,
)
# TODO: the performance for this case can be greatly improved by using the geopandas queries only once, and not
# in a loop as done preliminarily here
if points:
raise NotImplementedError(
"points=True is not implemented when querying by multiple polygons. If you encounter this error, please"
" open an issue on GitHub and we will prioritize the implementation."
if points or images or labels:
logger.warning(
"Spatial querying of images, points and labels is not implemented when querying by multiple polygons "
'simultaneously. You can silence this warning by setting "points=False, images=False, labels=False". If '
"you need this implementation please open an issue on GitHub and we will prioritize the implementation."
LucaMarconato marked this conversation as resolved.
Show resolved Hide resolved
)
points = False
images = False
labels = False

sdatas = []
for polygon in tqdm(polygons):
try:
Expand All @@ -744,6 +781,8 @@ def polygon_query(
filter_table=False,
shapes=shapes,
points=points,
images=images,
labels=labels,
)
sdatas.append(queried_sdata)
except ValueError as e:
Expand Down
40 changes: 34 additions & 6 deletions tests/core/query/test_spatial_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from anndata import AnnData
from multiscale_spatial_image import MultiscaleSpatialImage
from shapely import Polygon
from spatial_image import SpatialImage
from spatialdata import SpatialData
from spatialdata._core.query.spatial_query import (
Expand Down Expand Up @@ -379,7 +380,11 @@ def test_polygon_query_shapes(sdata_query_aggregation):
circle_pol = circle.buffer(sdata["by_circles"].radius.iloc[0])

queried = polygon_query(
values_sdata, polygons=polygon, target_coordinate_system="global", shapes=True, points=False
values_sdata,
polygons=polygon,
target_coordinate_system="global",
shapes=True,
points=False,
)
assert len(queried["values_polygons"]) == 4
assert len(queried["values_circles"]) == 4
Expand Down Expand Up @@ -432,11 +437,34 @@ def test_polygon_query_spatial_data(sdata_query_aggregation):
assert len(queried.table) == 8


@pytest.mark.skip
def test_polygon_query_image2d():
# single image case
# multiscale case
pass
@pytest.mark.parametrize("n_channels", [1, 2, 3])
def test_polygon_query_image2d(n_channels: int):
original_image = np.zeros((n_channels, 10, 10))
# y: [5, 9], x: [0, 4] has value 1
original_image[:, 5::, 0:5] = 1
image_element = Image2DModel.parse(original_image)
image_element_multiscale = Image2DModel.parse(original_image, scale_factors=[2, 2])

polygon = Polygon([(3, 3), (3, 7), (5, 3)])
for image in [image_element, image_element_multiscale]:
# bounding box: y: [5, 10[, x: [0, 5[
image_result = polygon_query(
SpatialData(images={"my_image": image}),
polygons=polygon,
target_coordinate_system="global",
)["my_image"]
expected_image = original_image[:, 3:7, 3:5] # c dimension is preserved
if isinstance(image, SpatialImage):
assert isinstance(image, SpatialImage)
np.testing.assert_allclose(image_result, expected_image)
elif isinstance(image, MultiscaleSpatialImage):
assert isinstance(image_result, MultiscaleSpatialImage)
v = image_result["scale0"].values()
assert len(v) == 1
xdata = v.__iter__().__next__()
np.testing.assert_allclose(xdata, expected_image)
else:
raise ValueError("Unexpected type")


@pytest.mark.skip
Expand Down
Loading