Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: replace scanpy.neighbors._compute_connectivities_umap with `sc… #129

Merged
4 commits merged into from
Apr 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 11 additions & 14 deletions muon/_core/preproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from anndata import AnnData
from scanpy import logging
from scanpy.tools._utils import _choose_representation
from scanpy.neighbors import _compute_connectivities_umap
from scanpy.neighbors._connectivity import umap
from umap.distances import euclidean
from umap.sparse import sparse_euclidean, sparse_jaccard
from umap.umap_ import nearest_neighbors
Expand Down Expand Up @@ -162,7 +162,7 @@ def _make_slice_intervals(idx, maxsize=10000):
def _l2norm(
adata: AnnData, rep: Optional[Union[Iterable[str], str]] = None, n_pcs: Optional[int] = 0
):
X = _choose_representation(adata, rep, n_pcs)
X = _choose_representation(adata=adata, use_rep=rep, n_pcs=n_pcs)
sparse_X = issparse(X)
if sparse_X:
X_norm = linalg.norm(X, ord=2, axis=1)
Expand Down Expand Up @@ -211,7 +211,7 @@ def l2norm(
rep = next(it)
try:
next(it)
except StopIteration as e:
except StopIteration:
pass
else:
raise RuntimeError("If 'rep' is an Iterable, it must have length 1")
Expand All @@ -220,7 +220,7 @@ def l2norm(
n_pcs = next(it)
try:
next(it)
except StopIteration as e:
except StopIteration:
pass
else:
raise RuntimeError("If 'n_pcs' is an Iterable, it must have length 1")
Expand Down Expand Up @@ -358,7 +358,7 @@ def neighbors(
mod_neighbors[i] = nparams["params"].get("n_neighbors", 0)

neighbors_params[mod] = nparams
reps[mod] = _choose_representation(mdata.mod[mod], use_rep, n_pcs)
reps[mod] = _choose_representation(adata=mdata.mod[mod], use_rep=use_rep, n_pcs=n_pcs)
mod_reps[mod] = (
use_rep if use_rep is not None else -1
) # otherwise this is not saved to h5mu
Expand Down Expand Up @@ -585,7 +585,7 @@ def neighdist(cell, nz):
neighbordistances = _sparse_csr_fast_knn(neighbordistances, n_neighbors + 1)

logging.info("Calculating connectivities...")
_, connectivities = _compute_connectivities_umap(
connectivities = umap(
knn_indices=neighbordistances.indices.reshape(
(neighbordistances.shape[0], n_neighbors + 1)
),
Expand All @@ -599,8 +599,8 @@ def neighdist(cell, nz):
conns_key = "connectivities"
dists_key = "distances"
else:
conns_key = key_added + "_connectivities"
dists_key = key_added + "_distances"
conns_key = f"{key_added}_connectivities"
dists_key = f"{key_added}_distances"
neighbors_dict = {"connectivities_key": conns_key, "distances_key": dists_key}
neighbors_dict["params"] = {
"n_neighbors": n_neighbors,
Expand Down Expand Up @@ -711,7 +711,7 @@ def func(x):
else:
obs_subset = data.obs_names.isin(var)
else:
raise ValueError(f"When providing obs_names directly, func has to be None.")
raise ValueError("When providing obs_names directly, func has to be None.")

# Subset .obs
data._obs = data.obs[obs_subset]
Expand Down Expand Up @@ -819,12 +819,9 @@ def func(x):
)
else:
if func is None:
if np.array(var).dtype == bool:
var_subset = var
else:
var_subset = data.var_names.isin(var)
var_subset = var if np.array(var).dtype == bool else data.var_names.isin(var)
else:
raise ValueError(f"When providing var_names directly, func has to be None.")
raise ValueError("When providing var_names directly, func has to be None.")

# Subset .var
data._var = data.var[var_subset]
Expand Down
22 changes: 10 additions & 12 deletions muon/_core/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

from scanpy import logging
from scanpy.tools._utils import _choose_representation
from scanpy.neighbors import _compute_connectivities_umap

# from scanpy.neighbors import _compute_connectivities_umap

from typing import Union, Optional, List, Iterable, Mapping, Sequence, Type, Any, Dict, Literal
from types import MappingProxyType
Expand Down Expand Up @@ -435,11 +436,9 @@ def mofa(
if outfile is None:
outfile = os.path.join("/tmp", "mofa_{}.hdf5".format(strftime("%Y%m%d-%H%M%S")))

if use_var:
if use_var not in data.var.columns:
warn(f"There is no column {use_var} in the provided object")
use_var = None

if use_var and use_var not in data.var.columns:
warn(f"There is no column {use_var} in the provided object")
use_var = None
if isinstance(data, MuData):
common_obs = reduce(np.intersect1d, [v.obs_names.values for k, v in mdata.mod.items()])
if len(common_obs) != mdata.n_obs:
Expand All @@ -457,9 +456,8 @@ def mofa(
ent = entry_point()

lik = likelihoods
if lik is not None:
if isinstance(lik, str) and isinstance(lik, Iterable):
lik = [lik for _ in range(len(mdata.mod))]
if lik is not None and (isinstance(lik, str) and isinstance(lik, Iterable)):
lik = [lik for _ in range(len(mdata.mod))]

ent.set_data_options(
scale_views=scale_views,
Expand Down Expand Up @@ -787,7 +785,7 @@ def snf(
mod_neighbors[i] = nparams["params"].get("n_neighbors", 0)

neighbors_params[mod] = nparams
reps[mod] = _choose_representation(mdata.mod[mod], use_rep, n_pcs)
reps[mod] = _choose_representation(adata=mdata.mod[mod], use_rep=use_rep, n_pcs=n_pcs)
mod_reps[mod] = (
use_rep if use_rep is not None else -1
) # otherwise this is not saved to h5mu
Expand Down Expand Up @@ -855,7 +853,7 @@ def _normalize(x):
def _dominateset(x, k=20):
def _zero(arr):
if k >= len(arr):
raise ValueError(f"'n_neighbors' seems to be too high.")
raise ValueError("'n_neighbors' seems to be too high.")
arr = arr.copy()
arr[np.argsort(arr)[: (len(arr) - k)]] = 0
return arr
Expand Down Expand Up @@ -1319,7 +1317,7 @@ def umap(
n_pcs = {k: (v if v != -1 else None) for k, v in nparams["n_pcs"].items()}
observations = mdata.obs.index
for mod, rep in use_rep.items():
rep = _choose_representation(mdata.mod[mod], rep, n_pcs[mod])
rep = _choose_representation(adata=mdata.mod[mod], use_rep=rep, n_pcs=n_pcs[mod])
nfeatures += rep.shape[1]
reps[mod] = rep
rep = np.empty((len(observations), nfeatures), np.float32)
Expand Down
Loading