Skip to content

Commit

Permalink
change channel names (#786)
Browse files Browse the repository at this point in the history
* change channel names

* write channel meta

* add tests

* add to docstring

* adjust to multiscale_spatial_image

* refactor

* adjust test

* correct refactor and test

* add get_channel_names

* local import get_model

* import set_channel_names

* update log and api.md

* local import
  • Loading branch information
melonora authored Nov 22, 2024
1 parent 62e4699 commit 94f0a31
Show file tree
Hide file tree
Showing 15 changed files with 260 additions and 26 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,18 @@ and this project adheres to [Semantic Versioning][].

## [0.2.6] - TBD

### Added

- Added `set_channel_names` method to `SpatialData` to change the channel names of an
image element in `SpatialData`
- Added `write_channel_names` method to `SpatialData` to overwrite channel metadata on disk
without overwriting the image array itself.

### Changed

- `get_channels` is marked for deprecation in `SpatialData` v0.3.0. Function is replaced
by `get_channel_names`

### Fixed

- Updated deprecated default stages of `pre-commit` #771
Expand Down
2 changes: 2 additions & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ The elements (building-blocks) that constitute `SpatialData`.
points_geopandas_to_dask_dataframe
points_dask_dataframe_to_geopandas
get_channels
get_channel_names
set_channel_names
force_2d
```

Expand Down
4 changes: 2 additions & 2 deletions src/spatialdata/_core/operations/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dask.array.overlap import coerce_depth
from xarray import DataArray, DataTree

from spatialdata.models._utils import get_axes_names, get_channels, get_raster_model_from_data_dims
from spatialdata.models._utils import get_axes_names, get_channel_names, get_raster_model_from_data_dims
from spatialdata.transformations import get_transformation

__all__ = ["map_raster"]
Expand Down Expand Up @@ -121,7 +121,7 @@ def map_raster(

if "c" in dims:
if c_coords is None:
c_coords = range(arr.shape[0]) if arr.shape[0] != len(get_channels(data)) else get_channels(data)
c_coords = range(arr.shape[0]) if arr.shape[0] != len(get_channel_names(data)) else get_channel_names(data)
else:
c_coords = None
if transformations is None:
Expand Down
4 changes: 2 additions & 2 deletions src/spatialdata/_core/operations/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from spatialdata._core.spatialdata import SpatialData
from spatialdata._types import ArrayLike
from spatialdata.models import SpatialElement, get_axes_names, get_model
from spatialdata.models._utils import DEFAULT_COORDINATE_SYSTEM, get_channels
from spatialdata.models._utils import DEFAULT_COORDINATE_SYSTEM, get_channel_names
from spatialdata.transformations._utils import _get_scale, compute_coordinates, scale_radii

if TYPE_CHECKING:
Expand Down Expand Up @@ -367,7 +367,7 @@ def _(
channel_names = None
elif schema in (Image2DModel, Image3DModel):
kwargs = {}
channel_names = get_channels(data)
channel_names = get_channel_names(data)
else:
raise ValueError(f"DataTree with schema {schema} not supported")

Expand Down
74 changes: 71 additions & 3 deletions src/spatialdata/_core/spatialdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
from spatialdata._core._elements import Images, Labels, Points, Shapes, Tables
from spatialdata._logging import logger
from spatialdata._types import ArrayLike, Raster_T
from spatialdata._utils import _deprecation_alias, _error_message_add_element
from spatialdata._utils import (
_deprecation_alias,
_error_message_add_element,
)
from spatialdata.models import (
Image2DModel,
Image3DModel,
Expand All @@ -36,7 +39,12 @@
get_model,
get_table_keys,
)
from spatialdata.models._utils import SpatialElement, convert_region_column_to_categorical, get_axes_names
from spatialdata.models._utils import (
SpatialElement,
convert_region_column_to_categorical,
get_axes_names,
set_channel_names,
)

if TYPE_CHECKING:
from spatialdata._core.query.spatial_query import BaseSpatialRequest
Expand Down Expand Up @@ -315,6 +323,26 @@ def get_instance_key_column(table: AnnData) -> pd.Series:
return table.obs[instance_key]
raise KeyError(f"{instance_key} is set as instance key column. However the column is not found in table.obs.")

def set_channel_names(self, element_name: str, channel_names: str | list[str], write: bool = False) -> None:
"""Set the channel names for a image `SpatialElement` in the `SpatialData` object.
This method assumes that the `SpatialData` object and the element are already stored on disk as it will
also overwrite the channel names metadata on disk. In case either the `SpatialData` object or the
element are not stored on disk, please use `SpatialData.set_image_channel_names` instead.
Parameters
----------
element_name
Name of the image `SpatialElement`.
channel_names
The channel names to be assigned to the c dimension of the image `SpatialElement`.
write
Whether to overwrite the channel metadata on disk.
"""
self.images[element_name] = set_channel_names(self.images[element_name], channel_names)
if write:
self.write_channel_names(element_name)

@staticmethod
def _set_table_annotation_target(
table: AnnData,
Expand Down Expand Up @@ -1441,6 +1469,45 @@ def _validate_can_write_metadata_on_element(self, element_name: str) -> tuple[st
)
return element_type, element

def write_channel_names(self, element_name: str | None = None) -> None:
"""
Write channel names to disk for a single image element, or for all image elements, without rewriting the data.
Parameters
----------
element_name
The name of the element to write the channel names of. If None, write the channel names of all image
elements.
"""
from spatialdata._core._elements import Elements

if element_name is not None:
Elements._check_valid_name(element_name)

# recursively write the transformation for all the SpatialElement
if element_name is None:
for element_name in list(self.images.keys()):
self.write_channel_names(element_name)
return

validation_result = self._validate_can_write_metadata_on_element(element_name)
if validation_result is None:
return

element_type, element = validation_result

# Mypy does not understand that path is not None so we have the check in the conditional
if element_type == "images" and self.path is not None:
_, _, element_group = self._get_groups_for_element(
zarr_path=Path(self.path), element_type=element_type, element_name=element_name
)

from spatialdata._io._utils import overwrite_channel_names

overwrite_channel_names(element_group, element)
else:
raise ValueError(f"Can't set channel names for element of type '{element_type}'.")

def write_transformations(self, element_name: str | None = None) -> None:
"""
Write transformations to disk for a single element, or for all elements, without rewriting the data.
Expand Down Expand Up @@ -1471,6 +1538,7 @@ def write_transformations(self, element_name: str | None = None) -> None:
transformations = get_transformation(element, get_all=True)
assert isinstance(transformations, dict)

# Mypy does not understand that path is not None so we have a conditional
assert self.path is not None
_, _, element_group = self._get_groups_for_element(
zarr_path=Path(self.path), element_type=element_type, element_name=element_name
Expand Down Expand Up @@ -1546,9 +1614,9 @@ def write_metadata(self, element_name: str | None = None, consolidate_metadata:
Elements._check_valid_name(element_name)

self.write_transformations(element_name)
self.write_channel_names(element_name)
# TODO: write .uns['spatialdata_attrs'] metadata for AnnData.
# TODO: write .attrs['spatialdata_attrs'] metadata for DaskDataFrame.
# TODO: write omero metadata for the channel name of images.

if consolidate_metadata is None and self.has_consolidated_metadata():
consolidate_metadata = True
Expand Down
21 changes: 21 additions & 0 deletions src/spatialdata/_io/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,27 @@ def overwrite_coordinate_transformations_raster(
group.attrs["multiscales"] = multiscales


def overwrite_channel_names(group: zarr.Group, element: DataArray | DataTree) -> None:
"""Write channel metadata to a group."""
if isinstance(element, DataArray):
channel_names = element.coords["c"].data.tolist()
else:
channel_names = element["scale0"]["image"].coords["c"].data.tolist()

channel_metadata = [{"label": name} for name in channel_names]
omero_meta = group.attrs["omero"]
omero_meta["channels"] = channel_metadata
group.attrs["omero"] = omero_meta
multiscales_meta = group.attrs["multiscales"]
if len(multiscales_meta) != 1:
raise ValueError(
f"Multiscale metadata must be of length one but got length {len(multiscales_meta)}. Data might"
f"be corrupted."
)
multiscales_meta[0]["metadata"]["omero"]["channels"] = channel_metadata
group.attrs["multiscales"] = multiscales_meta


def _write_metadata(
group: zarr.Group,
group_type: str,
Expand Down
4 changes: 2 additions & 2 deletions src/spatialdata/_io/io_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
_parse_version,
)
from spatialdata._utils import get_pyramid_levels
from spatialdata.models._utils import get_channels
from spatialdata.models._utils import get_channel_names
from spatialdata.models.models import ATTRS_KEY
from spatialdata.transformations._utils import (
_get_transformations,
Expand Down Expand Up @@ -151,7 +151,7 @@ def _get_group_for_writing_transformations() -> zarr.Group:
# convert channel names to channel metadata in omero
if raster_type == "image":
metadata["metadata"] = {"omero": {"channels": []}}
channels = get_channels(raster_data)
channels = get_channel_names(raster_data)
for c in channels:
metadata["metadata"]["omero"]["channels"].append({"label": c}) # type: ignore[union-attr, index, call-overload]

Expand Down
35 changes: 35 additions & 0 deletions src/spatialdata/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pandas as pd
from anndata import AnnData
from dask import array as da
from dask.array import Array as DaskArray
from xarray import DataArray, Dataset, DataTree

from spatialdata._types import ArrayLike
Expand Down Expand Up @@ -311,3 +312,37 @@ def _error_message_add_element() -> None:
"write_labels(), write_points(), write_shapes() and write_table(). We are going to make these calls more "
"ergonomic in a follow up PR."
)


def _check_match_length_channels_c_dim(
data: DaskArray | DataArray | DataTree, c_coords: str | list[str], dims: tuple[str]
) -> list[str]:
"""
Check whether channel names `c_coords` are of equal length to the `c` dimension of the data.
Parameters
----------
data
The image array
c_coords
The channel names
dims
The axes names in the order that is the same as the `ImageModel` from which it is derived.
Returns
-------
c_coords
The channel names as list
"""
c_index = dims.index("c")
c_length = (
data.shape[c_index] if isinstance(data, DataArray | DaskArray) else data["scale0"]["image"].shape[c_index]
)
if isinstance(c_coords, str):
c_coords = [c_coords]
if c_coords is not None and len(c_coords) != c_length:
raise ValueError(
f"The number of channel names `{len(c_coords)}` does not match the length of dimension 'c'"
f" with length {c_length}."
)
return c_coords
4 changes: 4 additions & 0 deletions src/spatialdata/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
Z,
force_2d,
get_axes_names,
get_channel_names,
get_channels,
get_spatial_axes,
points_dask_dataframe_to_geopandas,
points_geopandas_to_dask_dataframe,
set_channel_names,
validate_axes,
validate_axis_name,
)
Expand Down Expand Up @@ -49,6 +51,8 @@
"check_target_region_column_symmetry",
"get_table_keys",
"get_channels",
"get_channel_names",
"set_channel_names",
"force_2d",
"RasterSchema",
]
68 changes: 65 additions & 3 deletions src/spatialdata/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from xarray import DataArray, DataTree

from spatialdata._logging import logger
from spatialdata._utils import _check_match_length_channels_c_dim
from spatialdata.transformations.transformations import BaseTransformation

SpatialElement: TypeAlias = DataArray | DataTree | GeoDataFrame | DaskDataFrame
Expand Down Expand Up @@ -268,7 +269,7 @@ def points_geopandas_to_dask_dataframe(gdf: GeoDataFrame, suppress_z_warning: bo


@singledispatch
def get_channels(data: Any) -> list[Any]:
def get_channel_names(data: Any) -> list[Any]:
"""Get channels from data for an image element (both single and multiscale).
Parameters
Expand All @@ -287,12 +288,40 @@ def get_channels(data: Any) -> list[Any]:
raise ValueError(f"Cannot get channels from {type(data)}")


@get_channels.register
def get_channels(data: Any) -> list[Any]:
"""Get channels from data for an image element (both single and multiscale).
[Deprecation] This function will be deprecated in version 0.3.0. Please use
`get_channel_names`.
Parameters
----------
data
data to get channels from
Returns
-------
List of channels
Notes
-----
For multiscale images, the channels are validated to be consistent across scales.
"""
warnings.warn(
"The function 'get_channels' is deprecated and will be removed in version 0.3.0. "
"Please use 'get_channel_names' instead.",
DeprecationWarning,
stacklevel=2, # Adjust the stack level to point to the caller
)
return get_channel_names(data)


@get_channel_names.register
def _(data: DataArray) -> list[Any]:
return data.coords["c"].values.tolist() # type: ignore[no-any-return]


@get_channels.register
@get_channel_names.register
def _(data: DataTree) -> list[Any]:
name = list({list(data[i].data_vars.keys())[0] for i in data})[0]
channels = {tuple(data[i][name].coords["c"].values) for i in data}
Expand Down Expand Up @@ -374,3 +403,36 @@ def convert_region_column_to_categorical(table: AnnData) -> AnnData:
)
table.obs[region_key] = pd.Categorical(table.obs[region_key])
return table


def set_channel_names(element: DataArray | DataTree, channel_names: str | list[str]) -> DataArray | DataTree:
"""Set the channel names for a image `SpatialElement` in the `SpatialData` object.
Parameters
----------
element
The image `SpatialElement` or parsed `ImageModel`.
channel_names
The channel names to be assigned to the c dimension of the image `SpatialElement`.
Returns
-------
element
The image `SpatialElement` or parsed `ImageModel` with the channel names set to the `c` dimension.
"""
from spatialdata.models import Image2DModel, Image3DModel, get_model

channel_names = channel_names if isinstance(channel_names, list) else [channel_names]
model = get_model(element)

# get_model cannot be used due to circular import so get_axes_names is used instead
if model in [Image2DModel, Image3DModel]:
channel_names = _check_match_length_channels_c_dim(element, channel_names, model.dims.dims) # type: ignore[union-attr]
if isinstance(element, DataArray):
element = element.assign_coords(c=channel_names)
else:
element = element.msi.assign_coords({"c": channel_names})
else:
raise TypeError("Element model does not support setting channel names, no `c` dimension found.")

return element
Loading

0 comments on commit 94f0a31

Please sign in to comment.