Novartis / scar

scAR (single-cell Ambient Remover) is a deep learning model for removing ambient signals in droplet-based single-cell omics.
https://scar-tutorials.readthedocs.io/en/main/

Run time on GPU seems to be throttled by CPU #81

Closed LucHendriks closed 3 months ago

LucHendriks commented 3 months ago

Intro

I have been doing some testing with the scAR package to denoise my CITE-seq protein counts. So far I have managed to get the package up and running and to compute the ambient profile, but I have been running into long run times with the recommended/default parameters when training a model. I plan to run this per sample on the raw output of cellranger multi.

Dataset

Below are the dimensions of the object I am currently testing with.

View of AnnData object with n_obs × n_vars = 733759 × 129
    var: 'gene_symbol', 'feature_types', 'genome'
    uns: 'feature_reference'

Machines

Since scAR supports training on the GPU, I opted for this to get the quickest run times. However, I have noticed that the GPU is not fully utilized, and I think the CPU might be throttling the GPU at this point. Note that in my situation I only have access to some preconfigured machines on a Kubernetes cluster. Some info on the machines is given below, but feel free to reach out if you need more specific info.

GPU machine: (specs screenshot; 11 CPU cores and a GPU with 24 GB of memory)

CPU machine: (specs screenshot; 30 CPU cores)

On both I am running the same docker image using Ubuntu 22.04.4 LTS.
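
As a sanity check inside the container, a few lines of plain PyTorch (nothing scAR-specific) confirm that the GPU is visible to the framework:

import torch

print(torch.__version__, torch.version.cuda)  # PyTorch build and the CUDA version it was compiled against
print(torch.cuda.is_available())              # True on the GPU machine, False on the CPU machine
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))      # name of the GPU that PyTorch sees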

Tests

The code that I currently run is based on the tutorial in the docs, where raw_adata is the object shown above and scar_ambient_profile_df is a dataframe containing the per-sample ambient profiles that I calculated earlier with the setup_anndata() function.

from scar import model  # scAR's model class

ADT_scar = model(
    raw_count = raw_adata,
    ambient_profile = scar_ambient_profile_df[[sample_id]],  # Providing an ambient profile is recommended for CITE-seq; in other modes, you can leave this argument at its default value -- None
    feature_type = 'ADT', # "ADT" or "ADTs" for denoising protein counts in CITE-seq
    count_model = 'binomial',   # Depending on your data's sparsity, you can choose between 'binomial', 'poisson', and 'zeroinflatedpoisson'
)

ADT_scar.train(
    epochs=30,
    batch_size=64,
    verbose=True
)

# After training, we can infer the true protein signal
ADT_scar.inference()  # by default, batch_size=None; set a batch_size if you run into GPU memory issues

GPU machine

device = 'cuda'

When running this, GPU utilization goes up to only 10-12%, while one CPU core spikes to 100% and RAM usage never exceeds 30-40%.

Training: 100%|██████████| 30/30 [52:07<00:00, 104.26s/it, Loss=2.9530e+01]

device = 'cpu'

All 11 cores are utilized at ~100%, and RAM usage is in the same ~40% range as when running on the GPU on this machine.

Training:   7%|▋         | 2/30 [31:50<7:23:56, 951.31s/it, Loss=3.5956e+01]

CPU machine

device = 'cpu'

All 30 cores are being utilized at ~100% and memory (RAM) usage is the same as with the GPU machine.

Training: 100%|██████████| 30/30 [52:32<00:00, 105.07s/it, Loss=2.9492e+01]

Summary

The GPU does not seem to be fully utilized when it is available. Could you help me optimize the training of the model to reduce the runtime on the devices I have available?

LucHendriks commented 3 months ago

Update on batch_size

I did some extra tests to see how the batch_size parameter affects the run time. Increasing the batch size gives faster run times in both cases, but eventually the CPU starts to outperform the GPU, even though GPU usage still does not rise above 10% and seems to average out around 8% (of the 24 GB). For each batch size I trained the model for 3 epochs and took the average time per iteration; the results are shown below.
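
Roughly, the timing harness looked like this (a paraphrased sketch reusing the model arguments from my snippet above; device selection and logging are omitted):

import time

batch_sizes = [64, 512, 1024, 4096, 65536, 98304]
avg_seconds_per_epoch = {}

for bs in batch_sizes:
    # fresh model per run, same arguments as above
    ADT_scar = model(
        raw_count = raw_adata,
        ambient_profile = scar_ambient_profile_df[[sample_id]],
        feature_type = 'ADT',
        count_model = 'binomial',
    )
    start = time.perf_counter()
    ADT_scar.train(epochs=3, batch_size=bs, verbose=True)
    # one "iteration" in the progress bar corresponds to one epoch
    avg_seconds_per_epoch[bs] = (time.perf_counter() - start) / 3

print(avg_seconds_per_epoch)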

Is this behavior expected? Is there another way I could improve GPU usage? It seems like some other bottleneck is holding the GPU back.

GPU machine

device = 'cuda'

batch_size    avg. time per iteration (s)
64            108.564994
512           46.143404
1024          42.911732
4096          44.127388
65536         42.423678
98304         42.044423

CPU machine

device = 'cpu'

batch_size    avg. time per iteration (s)
64            106.024418
512           36.468826
1024          30.482797
4096          25.882069
65536         23.853371
98304         23.121426

CaibinSh commented 3 months ago

Hi @LucHendriks , thanks for your interest.

scAR was initially designed for single-batch experiments, typically with fewer than 10,000 cells, so GPU usage was not optimized for large-scale datasets.

The good news is that in version 0.7.0, we have incorporated a conditional variational autoencoder to support batch-aware denoising, and GPU usage has been optimized for experiments with a large number of cells. We plan to release this update this week. However, if you're interested, you can already try the latest version on our development branch (we've tested it and are currently adding a tutorial).

git clone https://github.com/Novartis/scar.git
cd scar 
git checkout develop  
pip install .  

In the model, you can set the cache_capacity parameter to 800,000 in your case. You can refer to the documentation here: https://scar-tutorials.readthedocs.io/en/develop/usages/training.html#scar.main._scar.model
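
For example, adapting your snippet from above (other arguments unchanged; this is just a sketch, see the linked documentation for the full signature):

ADT_scar = model(
    raw_count = raw_adata,
    ambient_profile = scar_ambient_profile_df[[sample_id]],
    feature_type = 'ADT',
    count_model = 'binomial',
    cache_capacity = 800000,  # number of cells cached on the GPU; set to 800,000 for your ~730k-droplet object
)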

Additionally, I noticed that you have 733,759 cells. Are they the unfiltered counts from a single experiment, or are they from multiple batches of single-cell experiments?

LucHendriks commented 3 months ago

Hello @CaibinSh, thank you very much for the swift reply!

I see that I phrased it a bit confusingly: I have a large dataset with multiple samples, and I am running scAR on each sample separately. The object shape I shared is one of those samples and is indeed the unfiltered output of cellranger loaded into an AnnData object, so there are more AnnData objects like the one above.

Below is the kneeplot from the 10x web summary for this one specific sample:

[knee plot screenshot]

The samples I have all range from 100k to 800k cells before filtering, and this sample is at the top of that range. Is this something scAR can handle right now, or should I start looking at other ways to correct my CITE-seq data?

If you have any other recommendations for working with a dataset of this size, I would also greatly appreciate them!

Thanks again for the help and development on this package!

CaibinSh commented 3 months ago

Hi @LucHendriks , thank you for the clear explanation.

We only use the raw_adata to estimate the ambient profile with the setup_anndata function; the raw data isn't needed afterward. We run scAR on the filtered_adata, which is also an output of Cell Ranger. With the default settings, this process should take less than 10 minutes per sample with the current version of scAR.

Please let me know if it does not work for you.

ADT_scar = model(
    raw_count = filtered_adata, # the output of cellranger, ...filtered_feature_bc_matrix.h5
    ambient_profile = scar_ambient_profile_df[[sample_id]],  # Providing an ambient profile is recommended for CITE-seq; in other modes, you can leave this argument at its default value -- None
    feature_type = 'ADT', # "ADT" or "ADTs" for denoising protein counts in CITE-seq
    count_model = 'binomial',   # Depending on your data's sparsity, you can choose between 'binomial', 'poisson', and 'zeroinflatedpoisson'
)

ADT_scar.train(
    epochs=200,
    batch_size=64,
    verbose=True
)

# After training, we can infer the true protein signal
ADT_scar.inference()  # by default, batch_size=None; set a batch_size if you run into GPU memory issues

CaibinSh commented 3 months ago

Hi @LucHendriks , we have released a new version 0.7.0, which should have better performance in terms of GPU usage. You can install it through conda install bioconda::scar

LucHendriks commented 3 months ago

That was a major oversight on my side, apologies for the confusion. Thanks again for the quick follow-up over the past few days! I will try out the new update, but I don't expect any further issues.

I might have some questions in the future regarding optimization of the ambient profile and type of count_model used, but it might be best for me to open a new issue for that.