diff --git a/muon/_core/plot.py b/muon/_core/plot.py index 752518c..52be000 100644 --- a/muon/_core/plot.py +++ b/muon/_core/plot.py @@ -1,4 +1,4 @@ -from typing import Union, List, Optional, Iterable, Sequence, Dict +from typing import Dict, Iterable, List, Optional, Sequence, Union import warnings from matplotlib.axes import Axes @@ -22,7 +22,7 @@ def scatter( data: Union[AnnData, MuData], x: Optional[str] = None, y: Optional[str] = None, - color: Optional[str] = None, + color: Optional[Union[str, Sequence[str]]] = None, use_raw: Optional[bool] = None, layers: Optional[Union[str, Sequence[str]]] = None, **kwargs, @@ -42,8 +42,8 @@ def scatter( x coordinate y : Optional[str] y coordinate - color : Optional[str], optional (default: None) - Key for variables or annotations of observations (.obs columns), + color : Optional[Union[str, Sequence[str]]], optional (default: None) + Keys or a single key for variables or annotations of observations (.obs columns), or a hex colour specification. use_raw : Optional[bool], optional (default: None) Use `.raw` attribute of the modality where a feature (from `color`) is derived from. @@ -71,7 +71,7 @@ def scatter( color_obs = _get_values(data, color, use_raw=use_raw, layer=layers[2]) color_obs = pd.DataFrame({color: color_obs}) else: - raise TypeError("Expected color to be a string.") + color_obs = _get_values(data, color, use_raw=use_raw, layer=layers[2]) color_obs.index = data.obs_names obs = pd.concat([obs, color_obs], axis=1, ignore_index=False)