diff --git a/muon/_core/preproc.py b/muon/_core/preproc.py
index 04c0735..a24b788 100644
--- a/muon/_core/preproc.py
+++ b/muon/_core/preproc.py
@@ -739,6 +739,7 @@ def func(x):
         # will fail due to _validate_value()
         attrm = dict(attrm)
         attrp = dict(attrp)
+        layers = dict(data.layers)
 
         # Subset .obs/.var
         setattr(data, f"_{attr}", df[subset])
@@ -773,11 +774,12 @@ def func(x):
             data.filename = None
 
         # Subset layers
-        for layer in data.layers:
+        for layer in layers:
             if attr == "obs":
-                data.layers[layer] = data.layers[layer][subset, :]
+                layers[layer] = layers[layer][subset, :]
             else:
-                data.layers[layer] = data.layers[layer][:, subset]
+                layers[layer] = layers[layer][:, subset]
+        data.layers = layers
 
         # Subset raw - only when subsetting obs
         if attr == "obs" and data.raw is not None:
diff --git a/tests/test_muon_preproc.py b/tests/test_muon_preproc.py
index 8a4ed99..0133be6 100644
--- a/tests/test_muon_preproc.py
+++ b/tests/test_muon_preproc.py
@@ -97,6 +97,50 @@ def test_filter_obs_with_obsm_obsp(self, pbmc3k_processed):
         assert_equal(mdata["A"], A_subset)
         assert_equal(mdata["B"], B_subset)
 
+    def test_filter_obs_with_obsm_obsp_explicit(self, mdata):
+        """Filter a MuData by a boolean mask with .obsm/.obsp set on it and on both modalities."""
+        mdata = mdata.copy()
+
+        # obsm
+        np.random.seed(42)
+        mdata["mod1"].obsm["X_normal"] = np.random.normal(size=(mdata["mod1"].n_obs, 10))
+        mdata["mod2"].obsm["X_normal"] = np.random.normal(size=(mdata["mod2"].n_obs, 10))
+        mdata.obsm["X_normal"] = np.random.normal(size=(mdata.n_obs, 10))
+        selection = mdata.obsm["X_normal"].sum(axis=1) > 0
+
+        # obsp
+        mdata["mod1"].obsp["connectivities"] = np.random.normal(
+            size=(mdata["mod1"].n_obs, mdata["mod1"].n_obs)
+        )
+        mdata["mod2"].obsp["connectivities"] = np.random.normal(
+            size=(mdata["mod2"].n_obs, mdata["mod2"].n_obs)
+        )
+        mdata.obsp["connectivities"] = np.random.normal(size=(mdata.n_obs, mdata.n_obs))
+
+        mu.pp.filter_obs(mdata, selection)
+        assert mdata.n_obs == selection.sum()
+
+    def test_filter_obs_anndata(self, mdata):
+        """Filter an AnnData carrying a layer, .obsm and .obsp by a boolean mask."""
+        adata = mdata["mod1"].copy()
+
+        # layers
+        adata.layers["X2"] = adata.X**2
+
+        # obsm
+        np.random.seed(42)
+        adata.obsm["X_normal"] = np.random.normal(size=(adata.n_obs, 10))
+        selection = adata.obsm["X_normal"].sum(axis=1) > 0
+
+        # obsp
+        adata.obsp["connectivities"] = np.random.normal(size=(adata.n_obs, adata.n_obs))
+
+        mu.pp.filter_obs(adata, selection)
+        assert adata.n_obs == selection.sum()
+        # The layer is the reason this test exists: it must survive and be subset along obs
+        assert "X2" in adata.layers
+        assert adata.layers["X2"].shape == (adata.n_obs, adata.n_vars)
+
     # Variables
 
     def test_filter_var_adata(self, mdata, filepath_h5mu):