cellarium-ai / cellarium-ml

Distributed single-cell data analysis.
BSD 3-Clause "New" or "Revised" License

scvi: include the option to limit to a specific set of genes #204

Open sjfleming opened 3 weeks ago

sjfleming commented 3 weeks ago

We want a way to subset to a specific list of genes on the fly.

We do not want the dataloader to load the entire set of genes, because this would be a big waste of memory / CUDA memory.

Should this be a transform? It seems like a transform happens too late... I think we'd really like this to be part of the dataloader itself. Any thoughts, @ordabayevy?

What if we just wrote a new convert_fn for X to be used here in the config file?

```yaml
data:
  dadc:
    class_path: cellarium.ml.data.DistributedAnnDataCollection
    init_args:
      ...
  batch_keys:
    x_ng:
      attr: X
      convert_fn: cellarium.ml.utilities.data.subset_genes_and_densify
```

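As a rough illustration of that idea, here is a minimal sketch of a convert_fn that subsets columns before densifying. The function name matches the config above, but its signature (in particular, receiving `var_names` and the gene list as extra arguments) is an assumption; how the gene list reaches the function is exactly the open question in this thread.

```python
import numpy as np


def subset_genes_and_densify(x, var_names, keep_genes):
    """Hypothetical convert_fn: keep only the columns of ``x`` listed in ``keep_genes``.

    ``x`` is a cells-by-genes matrix (dense ndarray or scipy.sparse matrix),
    ``var_names`` gives the gene name of each column, and ``keep_genes`` is the
    list of genes to keep (output columns follow its order).
    """
    position = {name: i for i, name in enumerate(var_names)}
    missing = [g for g in keep_genes if g not in position]
    if missing:
        raise KeyError(f"genes not found in var_names: {missing[:5]}")
    col_idx = np.array([position[g] for g in keep_genes], dtype=np.int64)
    # Subset columns while the data is still sparse (if it is),
    # so only the smaller matrix is ever densified.
    x_sub = x[:, col_idx]
    if hasattr(x_sub, "toarray"):  # scipy.sparse matrices
        x_sub = x_sub.toarray()
    return np.asarray(x_sub)
```

Subsetting first means only the kept columns are densified and later copied to the GPU, which is the memory saving this issue is after.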
ordabayevy commented 3 weeks ago

There is a Filter transform, but as you say it might come too late if memory/speed is a concern. The other place to filter before the data is moved to the GPU would be convert_fn. For that, we would need to figure out how to provide the list of genes to keep.

sjfleming commented 3 weeks ago

@ordabayevy These two commits provide a rough (but working) idea, but I don't know if the changes to the core data.py are clean enough. They could be improved.

https://github.com/cellarium-ai/cellarium-ml/commit/e21e7d9ddb2abe1055199c43acc0848168b144c7
https://github.com/cellarium-ai/cellarium-ml/commit/a7ed6c78cec520f03ed0bc7a0f600109ea3ecda9

In particular, the config file in https://github.com/cellarium-ai/cellarium-ml/commit/a7ed6c78cec520f03ed0bc7a0f600109ea3ecda9 is ugly because of the repetition that would (seemingly?) be needed for both X and var_names_g. Hopefully we can find a more elegant way to handle it. The referenced CSV file is just a list of gene names separated by \n, with no header.
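A minimal sketch of reading such a file and computing the column indices once, in plain Python. `read_gene_list` and `gene_indices` are hypothetical helpers, not cellarium-ml functions. Applying one index list to both X and var_names_g is one possible way to avoid specifying the gene list twice.

```python
from pathlib import Path
from typing import List, Sequence


def read_gene_list(path) -> List[str]:
    """Read a headerless file with one gene name (or Ensembl ID) per line."""
    lines = Path(path).read_text().splitlines()
    # Strip whitespace and skip blank lines (e.g. from a trailing newline).
    return [line.strip() for line in lines if line.strip()]


def gene_indices(var_names: Sequence[str], keep_genes: Sequence[str]) -> List[int]:
    """Column indices of ``keep_genes`` within ``var_names``, in ``keep_genes`` order."""
    position = {name: i for i, name in enumerate(var_names)}
    return [position[g] for g in keep_genes]  # raises KeyError if a gene is absent


# Hypothetical usage: the same indices subset var_names_g here and X's columns elsewhere.
var_names_g = ["ENSG01", "ENSG02", "ENSG03"]
idx = gene_indices(var_names_g, ["ENSG03", "ENSG01"])
subset_var_names_g = [var_names_g[i] for i in idx]
```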

ordabayevy commented 3 weeks ago

I think adding convert_fn_kwargs is a nice extension! But I agree that the additional logic in data.py is not very elegant. It almost feels like there might be a need for additional transforms that run before the data is moved to the device 🤔. Here are a few ideas:

  1. How often will the list of subset genes change? If it doesn't change much, would it be easier to create new extract files containing only the filtered genes?
  2. How much CUDA memory overhead is there if we load everything onto the GPU and then apply the Filter transform? It should be a constant overhead, and if the batch size is not too big it might be unnoticeable.
  3. There is an on_before_batch_transfer hook that we could add to CellariumModule and use somehow. For example (just an idea), maybe we could mark some transforms (Filter in this case) to run inside the on_before_batch_transfer hook on the CPU.
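To put a rough number on idea 2: a dense float32 batch costs n_cells × n_genes × 4 bytes on the GPU. The batch size of 1,024 cells and the 2,000-gene subset below are illustrative assumptions; the ~36k gene count is the figure mentioned later in this thread. This counts only the dense input tensor, not any intermediate copies.

```python
def dense_batch_bytes(n_cells: int, n_genes: int, bytes_per_value: int = 4) -> int:
    """Size in bytes of a dense cells-by-genes batch (float32 by default)."""
    return n_cells * n_genes * bytes_per_value


full = dense_batch_bytes(1024, 36_000)   # all genes transferred to the GPU
subset = dense_batch_bytes(1024, 2_000)  # hypothetical 2k-gene subset
print(f"full: {full / 1e6:.1f} MB, subset: {subset / 1e6:.1f} MB, "
      f"overhead: {(full - subset) / 1e6:.1f} MB")
# full: 147.5 MB, subset: 8.2 MB, overhead: 139.3 MB
```

The overhead is per batch and does not grow over training, which matches the "constant overhead" intuition; whether ~140 MB matters depends on the GPU and on how much memory the model itself uses.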

sjfleming commented 3 weeks ago
  1. The gene list could change with each run, and we would want to be quite flexible here.
  2. I am not sure. It might not really be that bad to just use a Filter transform. I haven't tried.
  3. Ooooh interesting idea, I like the sound of it. (It might even be the case that on-CPU stuff like computing embeddings in a GPT model -- if you wanted to -- could fit into this framework...)

sjfleming commented 2 weeks ago

@ImXman todo:

  1. Try the Filter transform approach (https://cellarium-ai.github.io/cellarium-ml/transforms.html#cellarium.ml.transforms.Filter) with a list of gene names (or Ensembl IDs) read from a CSV file, similar to https://github.com/cellarium-ai/cellarium-ml/blob/a7ed6c78cec520f03ed0bc7a0f600109ea3ecda9/example_configs/scvi_lake_kidney_config.yaml#L154-L158
    • main thing: check whether this is slow because all ~36k genes are loaded into CUDA memory and only subset afterwards... it might not actually be slow
  2. Try option (3) from Yerdos above. This would require some serious thinking to design it the right way. We would want to add an on_before_batch_transfer() method here (https://github.com/cellarium-ai/cellarium-ml/blob/main/cellarium/ml/core/module.py), and we would somehow need to be able to tell (based on the config file, and also when instantiating CellariumPipeline in general; see https://cellarium-ai.github.io/cellarium-ml/core.html#cellarium.ml.core.CellariumPipeline) which transforms run before the transfer to the GPU and which run after. This will require some careful thought and some additional code.
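One possible shape for item 2, sketched under assumptions: transforms are split into two explicit lists, and the CPU list runs in on_before_batch_transfer. Lightning's `LightningModule` does provide `on_before_batch_transfer(batch, dataloader_idx)` and `on_after_batch_transfer(batch, dataloader_idx)` hooks; here the class is plain Python so the sketch runs without Lightning. `GeneFilter`, `SplitPipelineModule`, `cpu_transforms` and `device_transforms` are hypothetical names, and cellarium-ml's real Filter transform may have a different interface.

```python
from typing import Callable, Dict, List, Sequence

import numpy as np

Batch = Dict[str, np.ndarray]
Transform = Callable[[Batch], Batch]


class GeneFilter:
    """Hypothetical CPU-side transform: keep only the listed genes."""

    def __init__(self, keep_genes: Sequence[str]) -> None:
        self.keep_genes = list(keep_genes)

    def __call__(self, batch: Batch) -> Batch:
        var_names_g = list(batch["var_names_g"])
        position = {name: i for i, name in enumerate(var_names_g)}
        idx = np.array([position[g] for g in self.keep_genes], dtype=np.int64)
        # Subset X and var_names_g together so they stay consistent.
        return {**batch,
                "x_ng": batch["x_ng"][:, idx],
                "var_names_g": np.asarray(var_names_g)[idx]}


class SplitPipelineModule:
    """Sketch of a module whose transforms are split around the device transfer.

    In a real CellariumModule this would subclass LightningModule, which calls
    on_before_batch_transfer before moving the batch to the device and
    on_after_batch_transfer afterwards.
    """

    def __init__(self, cpu_transforms: List[Transform],
                 device_transforms: List[Transform]) -> None:
        self.cpu_transforms = cpu_transforms
        self.device_transforms = device_transforms

    def on_before_batch_transfer(self, batch: Batch, dataloader_idx: int) -> Batch:
        for transform in self.cpu_transforms:  # runs on the CPU
            batch = transform(batch)
        return batch

    def on_after_batch_transfer(self, batch: Batch, dataloader_idx: int) -> Batch:
        for transform in self.device_transforms:  # runs on the device
            batch = transform(batch)
        return batch


# Usage: filter to two genes before the (simulated) transfer.
module = SplitPipelineModule(cpu_transforms=[GeneFilter(["g3", "g1"])], device_transforms=[])
batch = {"x_ng": np.arange(8).reshape(2, 4), "var_names_g": np.array(["g1", "g2", "g3", "g4"])}
small = module.on_before_batch_transfer(batch, dataloader_idx=0)
```

Two explicit lists keep the before/after ordering visible in a config file; an alternative would be a per-transform flag that CellariumPipeline inspects, at the cost of making the split less obvious.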