Is your feature request related to a problem? Please describe.
It is not currently straightforward to pass external dataloaders to train a model. In particular, loading torch.Tensor data and directly feeding it to a model as input doesn't seem possible because scvi.data._utils._check_nonnegative_integers does not handle torch.Tensor.
It would be very useful to be able to feed a custom dataloader, dictionary or AnnData as direct input to model.train() without having to copy torch.Tensor back to numpy or pandas. Maybe this could be implemented as model.train(data_module=data_module)?
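To illustrate the kind of check that is missing, here is a minimal sketch of a non-negative integer test that works directly on a torch.Tensor. The helper name `is_nonnegative_integer_tensor` is hypothetical and this is not how scvi.data._utils._check_nonnegative_integers is implemented; it only assumes real-valued (non-complex) tensors.

```python
import torch


def is_nonnegative_integer_tensor(x: torch.Tensor) -> bool:
    """Hypothetical helper: True if every entry of x is a non-negative whole number."""
    if x.numel() == 0:
        # An empty tensor has no entries that could violate the condition.
        return True
    if bool((x < 0).any()):
        return False
    if not x.is_floating_point():
        # Integer and bool dtypes can only hold whole numbers.
        return True
    # Floating-point tensors pass only if rounding leaves every entry unchanged.
    return bool(torch.equal(x, torch.round(x)))


int_counts_ok = is_nonnegative_integer_tensor(torch.randint(0, 10, (500, 10)))
float_counts_ok = is_nonnegative_integer_tensor(torch.tensor([0.0, 3.0, 7.0]))
negative_ok = is_nonnegative_integer_tensor(torch.tensor([-1, 2, 3]))
fractional_ok = is_nonnegative_integer_tensor(torch.tensor([0.5, 1.0]))
```

The comparison runs on whatever device the tensor lives on, so a check along these lines would not need a copy back to numpy.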
Describe the solution you'd like
import scipy.sparse
import torch
import scanpy as sc
import scvi
counts = torch.randint(0, 10, (500, 10))
# AnnData does not allow torch.Tensor in the .X field, so .X is an empty sparse placeholder
adata = sc.AnnData(scipy.sparse.csr_matrix(counts.shape), layers={'counts': counts})
scvi.model.SCVI.setup_anndata(adata, layer="counts")
model = scvi.model.SCVI(adata)
model.train()
The recent version of scVI-tools covers the enhancement of using a custom dataloader.
However, it is not yet clear which minimal checks (integers, gene names) we still want to perform.
About your example: @Intron7: Is this idea of having AnnData in torch recommended? What analysis capabilities are possible in this scenario? I thought this is meant to be done in rapids_singlecell. Does rapids copy back and forth between CPU and GPU or is the full data kept between processing steps on GPU?
We are still discussing how this would work. At the moment, however, whenever I use rsc I have to transfer the data back to the CPU and then use scvi. Rapids-singlecell really wants .X and .layers on the GPU, so everything has to be in memory. I would really like it if we used DLPack for this. DLPack allows zero-copy conversion from cupy and jax to torch.
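A rough sketch of the zero-copy handoff described above, using torch.from_dlpack. As an assumption made so the example runs on a CPU-only machine, a NumPy array stands in for the cupy or jax array; any producer that implements the DLPack protocol (`__dlpack__`) should be consumable the same way.

```python
import numpy as np
import torch

# Assumption: NumPy stands in for the producer (cupy or jax in the comment above).
counts_np = np.random.default_rng(0).integers(0, 10, size=(500, 10))

# torch.from_dlpack wraps the producer's buffer instead of copying it.
counts_torch = torch.from_dlpack(counts_np)

# Both objects point at the same memory address.
shares_memory = counts_torch.data_ptr() == counts_np.ctypes.data

# A write through the producer is therefore visible through the torch view.
counts_np[0, 0] = 42
value_seen_by_torch = counts_torch[0, 0].item()
```

Because the buffer is shared, the producer array has to stay alive and unmodified for as long as the tensor is in use.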
Hi @j-bac, thanks for the suggestion. We will be releasing a tutorial with our next release (v1.2) that covers a basic use case with a custom dataloader. I'll note that we don't support inference methods yet (e.g. get_latent_representation), but it's something we're working on.
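For readers wondering what a custom dataloader over torch.Tensor counts might look like, here is a minimal plain-PyTorch sketch. It is not taken from the upcoming tutorial: the class name `CountsDataset` and the batch key `"X"` are illustrative assumptions, and the keys a given model actually expects should be checked against the scvi-tools documentation.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class CountsDataset(Dataset):
    """Hypothetical dataset that serves rows of a count matrix as dictionaries."""

    def __init__(self, counts: torch.Tensor):
        self.counts = counts

    def __len__(self) -> int:
        return self.counts.shape[0]

    def __getitem__(self, idx: int) -> dict:
        # The "X" key is an assumption made for illustration only.
        return {"X": self.counts[idx]}


counts = torch.randint(0, 10, (500, 10))
loader = DataLoader(CountsDataset(counts), batch_size=128, shuffle=False)

# The default collate function stacks the per-row dictionaries into one batch dictionary.
first_batch = next(iter(loader))
total_rows = sum(batch["X"].shape[0] for batch in loader)
```

The data stays a torch.Tensor from construction to batching, with no round trip through numpy or pandas.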