laminlabs / lamindb

A data framework for biology.
https://docs.lamin.ai
Apache License 2.0
120 stars, 9 forks

Custom dataloader for scvi-tools #1826

Open · canergen opened this issue 3 weeks ago

canergen commented 3 weeks ago

Hi, as discussed in our recent meeting, scvi-tools can work with custom dataloaders. The dataloader has to be a (PyTorch Lightning) DataModule that yields a dictionary of keyword arguments for each minibatch. We have a working example for CELLXGENE census: https://github.com/chanzuckerberg/cellxgene-census/blob/ebezzi/census-scvi-datamodule/api/python/notebooks/experimental/pytorch_loader_scvi.ipynb (see the imports to find the actual datamodule code). You will find that the notebook requires custom code to load and save models.

We have built on top of this solution and changed the scvi-tools code to work without a registered AnnData object. I'm currently setting up a Colab notebook with our code and will post an update here.

The next step would be feedback on our setup_registry function, which takes a datamodule and populates all fields of an scvi-tools model. We currently only support scvi and would like to support more models. I expect this to be a joint effort, as things might break outside of scvi-tools (things will certainly not work for multimodal models such as totalVI or multiVI).
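To make the contract concrete, here is a minimal, framework-free sketch of a data module whose train_dataloader() yields one dictionary of keyword arguments per minibatch. The class name ToyDataModule and the keys "X", "batch" and "labels" are illustrative assumptions (modeled on the registry keys scvi-tools uses for AnnData fields); a real implementation would subclass Lightning's LightningDataModule and yield torch tensors rather than plain lists.

```python
from typing import Dict, Iterator, List


class ToyDataModule:
    """Hypothetical stand-in for a Lightning DataModule that yields one dict per minibatch."""

    def __init__(self, counts: List[List[int]], batch: List[int], labels: List[int], batch_size: int = 2):
        # every cell must cover the same genes in the same order (a fixed gene set)
        if len({len(row) for row in counts}) > 1:
            raise ValueError("all cells need the same gene set")
        if not len(counts) == len(batch) == len(labels):
            raise ValueError("counts, batch and labels must have one entry per cell")
        self.counts = counts
        self.batch = batch  # integer-encoded batch codes
        self.labels = labels  # integer-encoded cell type codes
        self.batch_size = batch_size

    def train_dataloader(self) -> Iterator[Dict[str, list]]:
        # slice the cells into consecutive minibatches; each one is a dict of keyword arguments
        for start in range(0, len(self.counts), self.batch_size):
            stop = start + self.batch_size
            yield {
                "X": self.counts[start:stop],
                "batch": self.batch[start:stop],
                "labels": self.labels[start:stop],
            }
```

With dm = ToyDataModule(counts=[[0, 3, 1], [2, 0, 0], [5, 1, 0]], batch=[0, 0, 1], labels=[1, 0, 1]), the first item of dm.train_dataloader() is {"X": [[0, 3, 1], [2, 0, 0]], "batch": [0, 0], "labels": [1, 0]}, which a model can unpack by key.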

canergen commented 3 weeks ago

@jkobject Concerning scDataloader: it is currently set up for BERT-like loading using a max_len argument. I assume it loads the 1,000 most expressed genes for every cell. This structure is not allowed in scvi-tools (we expect the same genes to be represented for all cells). Can you help with the how argument?
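A small illustration of why per-cell gene selection conflicts with a fixed gene set: picking the top-k expressed genes independently for each cell generally yields a different set of genes per cell, so the cells can no longer be stacked into one matrix with shared columns. The helper top_k_genes, the toy counts, and k=2 are hypothetical stand-ins for the 1,000-gene setting.

```python
from typing import List, Set


def top_k_genes(counts: List[int], genes: List[str], k: int) -> Set[str]:
    """Return the names of the k most expressed genes in one cell (ties keep gene order)."""
    order = sorted(range(len(counts)), key=lambda i: counts[i], reverse=True)
    return {genes[i] for i in order[:k]}


genes = ["CD3E", "MS4A1", "LYZ", "NKG7"]
cells = [
    [9, 0, 1, 4],  # T-cell-like profile
    [0, 8, 2, 0],  # B-cell-like profile
]

selected = [top_k_genes(c, genes, k=2) for c in cells]
# genes selected for every cell; an empty set means there is no shared column layout
shared_columns = set.intersection(*selected)
```

Here the first cell keeps {"CD3E", "NKG7"} and the second {"MS4A1", "LYZ"}, so no gene is shared between them.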

falexwolf commented 3 weeks ago

Thanks for opening the issue here!

Can you confirm the objective? I understand that by leveraging MappedCollection (or its scdataloader wrapper), you'd like to be more flexible regarding the storage backend, so that you don't have to rely on a single gigantic tiledbsoma store but could also train on a collection of h5ad or parquet files.

If so, can you help clarify what exactly is needed from MappedCollection so that you could run scvi on it? Is the following everything you need from us as an interface for MappedCollection?

> It requires being a torch DataModule that loads dictionaries of keyword arguments for each mini batch.

Koncopd commented 3 weeks ago

@canergen, if we integrate the underlying MappedCollection, that should not be a problem: MappedCollection can return the same genes for all cells (via an outer or inner join across datasets).
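To illustrate what an inner or outer join over genes means here, the following sketch aligns datasets with different gene sets to one common gene order: an inner join keeps only the genes present in every dataset, while an outer join keeps all genes. align_genes is a hypothetical helper written for illustration, not MappedCollection's implementation, and filling genes absent from a dataset with zeros is an assumption.

```python
from typing import Dict, List, Tuple


def align_genes(
    datasets: List[Tuple[List[str], List[List[int]]]], join: str = "inner"
) -> Tuple[List[str], List[List[int]]]:
    """Align (gene_names, count_rows) pairs to one shared gene order via an inner or outer join."""
    gene_sets = [set(genes) for genes, _ in datasets]
    if join == "inner":
        keep = set.intersection(*gene_sets)
    elif join == "outer":
        keep = set.union(*gene_sets)
    else:
        raise ValueError("join must be 'inner' or 'outer'")
    shared = sorted(keep)  # one fixed gene order for every cell
    rows: List[List[int]] = []
    for genes, counts in datasets:
        index: Dict[str, int] = {g: i for i, g in enumerate(genes)}
        for row in counts:
            # genes missing from this dataset are filled with 0 (only happens with join="outer")
            rows.append([row[index[g]] if g in index else 0 for g in shared])
    return shared, rows
```

For two datasets with genes ["A", "B", "C"] and ["B", "C", "D"], the inner join yields the columns ["B", "C"] and the outer join ["A", "B", "C", "D"]; either way, every cell ends up with the same columns.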

jkobject commented 3 weeks ago

Hello all, I will get back to you in a week about this and about potential updates to scDataLoader to bring it fully in line with scvi-tools. I have already opened a PR on the scverse ecosystem packages (https://github.com/scverse/ecosystem-packages/pull/195), so maybe this discussion would be better placed in that PR?

scDataloader is set up to work in multiple ways to accommodate the new geneformer, scGPT, and scPRINT models, which don't always use the same set of genes. However, it also works if you always choose the same set of genes. "how" is a parameter of the Collator and is described in its documentation here: https://github.com/jkobject/scDataLoader/blob/main/scdataloader/collator.py. You can use "some" to use only a predefined set of genes, "most expr" to use only the most expressed genes per cell, or "random expr" to use only a random subset of the expressed genes per cell.
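The three selection modes can be sketched on a single cell as follows. This is an illustrative re-implementation, not scDataLoader's Collator code; the function name select_genes, its n_genes and rng parameters, and the reading of the random mode as "sample among genes with non-zero counts" are assumptions, and the mode strings are assumed to match the Collator's documented options.

```python
import random
from typing import List, Optional, Sequence


def select_genes(
    counts: Sequence[int],
    genes: Sequence[str],
    how: str,
    n_genes: int = 2,
    genelist: Optional[Sequence[str]] = None,
    rng: Optional[random.Random] = None,
) -> List[str]:
    """Pick the genes that represent one cell under a given selection mode."""
    if how == "some":
        # fixed, predefined gene list: identical for every cell (what scvi-tools expects)
        if genelist is None:
            raise ValueError('how="some" requires a genelist')
        return list(genelist)
    if how == "most expr":
        # the n_genes highest counts in this cell; differs from cell to cell
        order = sorted(range(len(counts)), key=lambda i: counts[i], reverse=True)
        return [genes[i] for i in order[:n_genes]]
    if how == "random expr":
        # a random subset of the genes expressed (count > 0) in this cell
        rng = rng or random.Random()
        expressed = [g for g, c in zip(genes, counts) if c > 0]
        return rng.sample(expressed, min(n_genes, len(expressed)))
    raise ValueError(f"unknown how={how!r}")
```

Only "some" returns the same columns for every cell, which is why it is the mode that fits scvi-tools.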

If you know which genes you want to use, you can set how="some" and pass a list of genes, e.g.:

from scdataloader import DataModule
import tqdm

datamodule = DataModule(
    collection_name=NAME,  # placeholder: name of the collection to load
    organisms=["NCBITaxon:9606"],  # organism(s) to work on (here: human)
    how="some",  # gene selection mode of the collator
    genelist=most_variable_genes_in_data,  # placeholder gene list; not recommended, as it will highlight batch effects
    batch_size=64,
    num_workers=1,
    validation_split=0.1,
    test_split=0,
)

# inspect the first minibatch only
for batch in tqdm.tqdm(datamodule.train_dataloader()):
    print(batch)
    break

What I would like to know, @canergen, is whether scdataloader is already acceptable in its current state, or which specific parameters, default values, and functions it is missing to make it suitable for scvi-tools.

canergen commented 2 weeks ago

> Can you confirm the objective? I understand that by leveraging MappedCollection (or its scdataloader wrapper), you'd like to be more flexible regarding the storage backend so that you don't necessarily rely on a gigantic tiledbsoma but could also train on a collection of h5ads or parquets.

The current baseline is a single huge h5ad file, which doesn't scale beyond 100 million cells (and is already more expensive below that). While we support CELLXGENE census, we don't support arbitrary tiledbsoma databases, and CxG census is restrictive when it comes to extending support to the whole scvi-tools library.

I assume both dataloaders are in line with our requirements (yielding a stream of dictionaries). My take is that MappedCollection can provide a dictionary with categorical/integer/float values for every obs column in the underlying AnnData object. However, there are some details and missing pieces.

We won't provide any validation for custom dataloaders (e.g., whether the data really is count data, or whether the list of cell type labels or categories is consistent with the registry). From our perspective this is not strictly required, but if you want this feature, handling it in the underlying data format can be optimized much better than doing it across formats.

We require a mapping from cell type labels/batches to integer categorical codes that is stored with the dataloader. The scvi-tools models only take the integer codes, not the categories (see the setup_datamodule function in https://colab.research.google.com/drive/10sXec_TicMKtLA6hMcgfkado-FgoNKxw#scrollTo=jiksMKJdlGdK; the notebook doesn't install correctly yet, as I have issues with the typing of the CZI code).

Is there support for MuData within MappedCollection? How would you support multimodal data (ATAC + RNA + protein)? Currently, we are restricted in terms of input data to the CELLXGENE census, which obviously has no multimodal collections. I assume large-scale models in the coming year will include multimodal data.
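The label-to-integer mapping stored with the dataloader, plus the kind of validation that scvi-tools will not do for custom dataloaders, could look roughly like this. It is a hypothetical, framework-free sketch: build_encoder, encode, and check_minibatch are illustrative names, not scvi-tools, lamindb, or scDataLoader functions, and "non-negative integers" is only one possible definition of count data.

```python
from typing import Dict, List, Sequence


def build_encoder(categories: Sequence[str]) -> Dict[str, int]:
    """Map each distinct category to an integer code (sorted for a reproducible order)."""
    return {cat: code for code, cat in enumerate(sorted(set(categories)))}


def encode(values: Sequence[str], encoder: Dict[str, int]) -> List[int]:
    """Translate category labels into the integer codes the model consumes."""
    unknown = set(values) - set(encoder)
    if unknown:
        raise ValueError(f"labels not in the registry: {sorted(unknown)}")
    return [encoder[v] for v in values]


def check_minibatch(X: Sequence[Sequence[float]], codes: Sequence[int], n_categories: int) -> None:
    """Optional sanity checks: count-like values and codes within the registered range."""
    for row in X:
        for value in row:
            if value < 0 or value != int(value):
                raise ValueError("X does not look like raw counts")
    if any(not 0 <= c < n_categories for c in codes):
        raise ValueError("categorical code outside the registered range")


cell_types = ["T cell", "B cell", "T cell", "NK cell"]
encoder = build_encoder(cell_types)  # {"B cell": 0, "NK cell": 1, "T cell": 2}
decoder = {code: cat for cat, code in encoder.items()}  # inverse mapping to recover labels
codes = encode(cell_types, encoder)
check_minibatch([[0, 3], [1, 0]], codes, n_categories=len(encoder))
```

Keeping the encoder (or its inverse) next to the datamodule is what lets a trained model's integer codes be mapped back to cell type or batch labels.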

jkobject commented 1 week ago

I am not sure I understand everything, but you seem to say that scdataloader is already in line with scvi-tools, and that we need to ensure the other elements of its output dictionary (e.g., class, batch, ...) are categoricals encoded as integers. This functionality already exists: in my example and current use case, everything is encoded this way, and the encoding is available via the "decoders" property of the datamodule or the "encoder" property of the PyTorch datasets.

Given how MappedCollection is implemented, supporting multimodal data is quite straightforward: it just amounts to accessing another field of the underlying h5ad file. The user would then receive 'X_rna', 'X_atac', ... in the output dictionary of tensors instead of just 'X'.
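Under that design, each cell would be served as a dictionary with one entry per modality, and a collate step would stack those entries key by key into a minibatch. The sketch below assumes the keys 'X_rna' and 'X_atac' from the comment above plus a hypothetical 'batch' key; it uses plain lists instead of tensors and is not MappedCollection's code.

```python
from typing import Any, Dict, List


def collate_by_key(items: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Stack a list of per-cell dicts into one minibatch dict, key by key."""
    if not items:
        return {}
    keys = items[0].keys()
    if any(item.keys() != keys for item in items):
        raise ValueError("all cells must provide the same modalities")
    return {key: [item[key] for item in items] for key in keys}


items = [
    {"X_rna": [1, 0, 2], "X_atac": [0, 1], "batch": 0},
    {"X_rna": [0, 4, 0], "X_atac": [1, 1], "batch": 1},
]
minibatch = collate_by_key(items)
```

PyTorch's default collate function performs the same key-wise stacking (into tensors) when a Dataset returns dictionaries, so a multimodal model would simply read minibatch["X_rna"] and minibatch["X_atac"].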