Skip to content

Commit

Permalink
Merge pull request #148 from scverse/ig/fix_filter
Browse files Browse the repository at this point in the history
Fix in-place filtering
  • Loading branch information
gtca authored Oct 16, 2024
2 parents 9fd69c7 + 10060b5 commit c7461aa
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 125 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ __pycache__/
# C extensions
*.so

# cached data
data/

# Distribution / packaging
.Python
build/
Expand Down
237 changes: 114 additions & 123 deletions muon/_core/preproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,23 +657,24 @@ def intersect_obs(mdata: MuData):
return


# Utility functions: filtering observations
# Utility functions: filtering observations or variables


def filter_obs(
data: Union[AnnData, MuData], var: Union[str, Sequence[str]], func: Optional[Callable] = None
def _filter_attr(
data: Union[AnnData, MuData],
attr: Literal["obs", "var"],
key: Union[str, Sequence[str]],
func: Optional[Callable] = None,
) -> None:
"""
Filter observations (samples or cells) in-place
using any column in .obs or in .X.
Filter observations or variables in-place.
Parameters
----------
data: AnnData or MuData
AnnData or MuData object
var: str or Sequence[str]
Column name in .obs or in .X to be used for filtering.
Alternatively, obs_names can be provided directly.
key: str or Sequence[str]
Names or key to filter
func
Function to apply to the variable used for filtering.
If the variable is of type boolean and func is an identity function,
Expand All @@ -694,51 +695,76 @@ def filter_obs(
"MuData object is backed. The requested subset of the .X matrices of its modalities will be read into memory, and the object will not be backed anymore."
)

if isinstance(var, str):
if var in data.obs.columns:
assert attr in ("obs", "var"), "Attribute has to be either 'obs' or 'var'."

df = getattr(data, attr)
names = getattr(data, f"{attr}_names")
other = "obs" if attr == "var" else "var"
other_names = getattr(data, f"{other}_names")
attrm = getattr(data, f"{attr}m")
attrp = getattr(data, f"{attr}p")

if isinstance(key, str):
if key in df.columns:
if func is None:
if data.obs[var].dtypes.name == "bool":
if df[key].dtypes.name == "bool":

def func(x):
return x

else:
raise ValueError(f"Function has to be provided since {var} is not boolean")
obs_subset = func(data.obs[var].values)
elif var in data.var_names:
obs_subset = func(data.X[:, np.where(data.var_names == var)[0]].reshape(-1))
raise ValueError(f"Function has to be provided since {key} is not boolean")
subset = func(df[key].values)
elif key in other_names:
if attr == "obs":
subset = func(data.X[:, np.where(other_names == key)[0]].reshape(-1))
else:
subset = func(data.X[np.where(other_names == key)[0], :].reshape(-1))
else:
raise ValueError(
f"Column name from .obs or one of the var_names was expected but got {var}."
f"Column name from .{attr} or one of the {other}_names was expected but got {key}."
)
else:
if func is None:
if np.array(var).dtype == bool:
obs_subset = np.array(var)
if np.array(key).dtype == bool:
subset = np.array(key)
else:
obs_subset = data.obs_names.isin(var)
subset = names.isin(key)
else:
raise ValueError("When providing obs_names directly, func has to be None.")
raise ValueError(f"When providing {attr}_names directly, func has to be None.")

# Subset .obs
data._obs = data.obs[obs_subset]
data._n_obs = data.obs.shape[0]
if isinstance(data, AnnData):
# Collect elements to subset
# NOTE: accessing them after subsetting .obs/.var
# will fail due to _validate_value()
attrm = dict(attrm)
attrp = dict(attrp)

# Subset .obsm
for k, v in data.obsm.items():
data.obsm[k] = v[obs_subset]
# Subset .obs/.var
setattr(data, f"_{attr}", df[subset])

# Subset .obsp
for k, v in data.obsp.items():
data.obsp[k] = v[obs_subset][:, obs_subset]
# Subset .obsm/.varm
for k, v in attrm.items():
attrm[k] = v[subset]
setattr(data, f"{attr}m", attrm)

# Subset .obsp/.obsp
for k, v in attrp.items():
attrp[k] = v[subset][:, subset]
setattr(data, f"{attr}p", attrp)

if isinstance(data, AnnData):
# Subset .X
if data._X is not None:
try:
data._X = data.X[obs_subset, :]
if attr == "obs":
data._X = data.X[subset, :]
else:
data._X = data.X[:, subset]
except TypeError:
data._X = data.X[np.where(obs_subset)[0], :]
if attr == "obs":
data._X = data.X[np.where(subset)[0], :]
else:
data._X = data.X[:, np.where(subset)[0]]
# For some h5py versions, indexing arrays must have integer dtypes
# https://github.com/h5py/h5py/issues/1847

Expand All @@ -748,29 +774,71 @@ def func(x):

# Subset layers
for layer in data.layers:
data.layers[layer] = data.layers[layer][obs_subset, :]
if attr == "obs":
data.layers[layer] = data.layers[layer][subset, :]
else:
data.layers[layer] = data.layers[layer][:, subset]

# Subset raw
if data.raw is not None:
data.raw._X = data.raw.X[obs_subset, :]
data.raw._n_obs = data.raw.X.shape[0]
# Subset raw - only when subsetting obs
if attr == "obs" and data.raw is not None:
data.raw._X = data.raw.X[subset, :]

else:
# filter_obs() for each modality
attrmap = getattr(data, f"{attr}map")

# Subset .obs/.var
setattr(data, f"_{attr}", df[subset])

# Subset .obsm/.varm
for k, v in attrm.items():
attrm[k] = v[subset]
setattr(data, f"{attr}m", attrm)

# Subset .obsp/.varp
for k, v in attrp.items():
attrp[k] = v[subset][:, subset]
setattr(data, f"{attr}p", attrp)

# _filter_attr() for each modality
for m, mod in data.mod.items():
obsmap = data.obsmap[m][obs_subset]
obsidx = obsmap > 0
filter_obs(mod, mod.obs_names[obsmap[obsidx] - 1])
maporder = np.argsort(obsmap[obsidx])
map_subset = attrmap[m][subset]
attridx = map_subset > 0
orig_attr = getattr(mod, attr).copy()
mod_names = getattr(mod, f"{attr}_names")
_filter_attr(mod, attr, mod_names[map_subset[attridx] - 1])
data.mod[m]._remove_unused_categories(orig_attr, getattr(mod, attr), mod.uns)
maporder = np.argsort(map_subset[attridx])
nobsmap = np.empty(maporder.size)
nobsmap[maporder] = np.arange(1, maporder.size + 1)
obsmap[obsidx] = nobsmap
data.obsmap[m] = obsmap
map_subset[attridx] = nobsmap
getattr(data, f"{attr}map")[m] = map_subset

return


# Utility functions: filtering variables
def filter_obs(
data: Union[AnnData, MuData], var: Union[str, Sequence[str]], func: Optional[Callable] = None
) -> None:
"""
Filter observations (samples or cells) in-place
using any column in .obs or in .X.
Parameters
----------
data: AnnData or MuData
AnnData or MuData object
var: str or Sequence[str]
Column name in .obs or in .X to be used for filtering.
Alternatively, obs_names can be provided directly.
func
Function to apply to the variable used for filtering.
If the variable is of type boolean and func is an identity function,
the func argument can be omitted.
"""

_filter_attr(data, "obs", var, func)

return


def filter_var(
Expand All @@ -793,84 +861,7 @@ def filter_var(
the func argument can be omitted.
"""

if data.is_view:
raise ValueError(
"The provided adata is a view. In-place filtering does not operate on views."
)
if data.isbacked:
if isinstance(data, AnnData):
warnings.warn(
"AnnData object is backed. The requested subset of the matrix .X will be read into memory, and the object will not be backed anymore."
)
else:
warnings.warn(
"MuData object is backed. The requested subset of the .X matrices of its modalities will be read into memory, and the object will not be backed anymore."
)

if isinstance(var, str):
if var in data.var.columns:
if func is None:
if data.var[var].dtypes.name == "bool":

def func(x):
return x

else:
raise ValueError(f"Function has to be provided since {var} is not boolean")
var_subset = func(data.var[var].values)
elif var in data.obs_names:
var_subset = func(data.X[:, np.where(data.obs_names == var)[0]].reshape(-1))
else:
raise ValueError(
f"Column name from .var or one of the obs_names was expected but got {var}."
)
else:
if func is None:
var_subset = var if np.array(var).dtype == bool else data.var_names.isin(var)
else:
raise ValueError("When providing var_names directly, func has to be None.")

# Subset .var
data._var = data.var[var_subset]
data._n_vars = data.var.shape[0]

# Subset .varm
for k, v in data.varm.items():
data.varm[k] = v[var_subset]

# Subset .varp
for k, v in data.varp.items():
data.varp[k] = v[var_subset][:, var_subset]

if isinstance(data, AnnData):
# Subset .X
try:
data._X = data.X[:, var_subset]
except TypeError:
data._X = data.X[:, np.where(var_subset)[0]]
# For some h5py versions, indexing arrays must have integer dtypes
# https://github.com/h5py/h5py/issues/1847
if data.isbacked:
data.file.close()
data.filename = None

# Subset layers
for layer in data.layers:
data.layers[layer] = data.layers[layer][:, var_subset]

# NOTE: .raw is not subsetted

else:
# filter_var() for each modality
for m, mod in data.mod.items():
varmap = data.varmap[m][var_subset]
varidx = varmap > 0
filter_var(mod, mod.var_names[varmap[varidx] - 1])
maporder = np.argsort(varmap[varidx])
nvarmap = np.empty(maporder.size)
nvarmap[maporder] = np.arange(1, maporder.size + 1)
varmap[varidx] = nvarmap
data.varmap[m] = varmap
_filter_attr(data, "var", var, func)

return

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ atac = [
test = [
"pytest",
"flake8",
"pytest",
]

[tool.flit.metadata.urls]
Expand Down
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
import scanpy as sc


@pytest.fixture(scope="module")
Expand All @@ -9,3 +10,8 @@ def filepath_h5mu(tmpdir_factory):
@pytest.fixture(scope="module")
def filepath_hdf5(tmpdir_factory):
yield str(tmpdir_factory.mktemp("tmp_mofa_dir").join("mofa_pytest.hdf5"))


@pytest.fixture(scope="module")
def pbmc3k_processed():
yield sc.datasets.pbmc3k_processed()
35 changes: 33 additions & 2 deletions tests/test_muon_preproc.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import unittest
import pytest

import os
from functools import reduce

import numpy as np
from scipy.sparse import csr_matrix
from anndata import AnnData
from anndata.tests.helpers import assert_equal
from mudata import MuData
import muon as mu

Expand Down Expand Up @@ -83,6 +82,21 @@ def test_filter_obs_adata_view(self, mdata, filepath_h5mu):
sub = np.random.binomial(1, 0.5, view.n_obs).astype(bool)
mu.pp.filter_obs(view, sub)

def test_filter_obs_with_obsm_obsp(self, pbmc3k_processed):
A = pbmc3k_processed[:500,].copy()
B = pbmc3k_processed[500:,].copy()
A_subset = A[A.obs["louvain"] == "B cells"].copy()
B_subset = B[B.obs["louvain"] == "B cells"].copy()
mdata = mu.MuData({"A": A, "B": B}, axis=1)
mdata.pull_obs("louvain")
mu.pp.filter_obs(mdata, "louvain", lambda x: x == "B cells")
assert mdata["B"].n_obs == B_subset.n_obs
assert mdata["A"].obs["louvain"].unique() == "B cells"
assert B.n_obs == B_subset.n_obs
assert A.obs["louvain"].unique() == "B cells"
assert_equal(mdata["A"], A_subset)
assert_equal(mdata["B"], B_subset)

# Variables

def test_filter_var_adata(self, mdata, filepath_h5mu):
Expand Down Expand Up @@ -132,6 +146,23 @@ def test_filter_var_adata_view(self, mdata, filepath_h5mu):
sub = np.random.binomial(1, 0.5, view.n_vars).astype(bool)
mu.pp.filter_var(view, sub)

def test_filter_var_with_varm_varp(self, pbmc3k_processed):
A = pbmc3k_processed[:, :500].copy()
B = pbmc3k_processed[:, 500:].copy()
np.random.seed(42)
A_var_sel = np.random.choice(np.array([0, 1]), size=A.n_vars, replace=True)
B_var_sel = np.random.choice(np.array([0, 1]), size=B.n_vars, replace=True)
A.var["sel"] = A_var_sel
B.var["sel"] = B_var_sel
A_subset = A[:, A_var_sel == 1].copy()
B_subset = B[:, B_var_sel == 1].copy()
mdata = mu.MuData({"A": A, "B": B})
mdata.pull_var("sel")
mu.pp.filter_var(mdata, "sel", lambda y: y == 1)
assert mdata.shape[1] == int(np.sum(A_var_sel) + np.sum(B_var_sel))
assert_equal(mdata["A"], A_subset)
assert_equal(mdata["B"], B_subset)


@pytest.mark.usefixtures("filepath_h5mu")
class TestIntersectObs:
Expand Down

0 comments on commit c7461aa

Please sign in to comment.