theislab / scarches

Reference mapping for single-cell genomics
https://docs.scarches.org/en/latest/
BSD 3-Clause "New" or "Revised" License

scPoli training progressively slower with increasing cell number when using raw counts #250

Open · gargerd opened this issue 2 months ago

gargerd commented 2 months ago

Dear scarches team,

I noticed that when I ran scPoli.train() on a reference dataset of ~40,000 cells x 5,300 genes with .X set to the raw counts, training got progressively slower as the number of cells increased.

Here is the part of my code that trains the model on the reference dataset:

import scarches as sca

early_stopping_kwargs = {
    "early_stopping_metric": "val_prototype_loss",
    "mode": "min",
    "threshold": 0,
    "patience": 20,
    "reduce_lr": True,
    "lr_patience": 13,
    "lr_factor": 0.1,
}

## Subset to common genes between reference HVG and query genes
adata_gse_=adata_gse[:,common_genes].copy()

## Set .X to raw counts
adata_gse_.X=adata_gse_.layers['raw_counts']

scpoli_model=sca.models.scPoli(adata=adata_gse_,
                               condition_keys=['patient','condition','method'],
                               cell_type_keys='final_celltypes',
                               embedding_dims=5,
                               latent_dim=10,
                               recon_loss='zinb')

scpoli_model.train(n_epochs=400,
                   pretraining_epochs=340,
                   early_stopping_kwargs=early_stopping_kwargs,
                   eta=5)

Debugging steps

To understand where training slows down, I looked into the source code of scPoli's trainer.py and printed the training progress (iterations per epoch and time per iteration) for three setups:

Full data with 'zinb' as recon loss + raw counts as .X

For the ~40,000 x 5,300 dataset, I got 284 iterations per epoch at 4.5-5 seconds per iteration, which means one epoch takes roughly 21-24 minutes (see screenshot).
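The epoch-time estimate is simply iterations per epoch multiplied by seconds per iteration. A quick check of the reported numbers (plain arithmetic, no scPoli code involved), extended to the 400 epochs requested in the training call:

```python
# Estimate wall-clock time per epoch from the progress-bar readings reported above.
iterations_per_epoch = 284          # reported for the ~40,000-cell dataset
sec_per_iter_range = (4.5, 5.0)     # reported seconds per iteration (low, high)
n_epochs = 400                      # n_epochs passed to scpoli_model.train()

def epoch_minutes(n_iters: int, sec_per_iter: float) -> float:
    """Return the duration of one epoch in minutes."""
    return n_iters * sec_per_iter / 60

low, high = (epoch_minutes(iterations_per_epoch, s) for s in sec_per_iter_range)
print(f"one epoch: {low:.1f}-{high:.1f} min")
print(f"{n_epochs} epochs: {low * n_epochs / 60:.0f}-{high * n_epochs / 60:.0f} h")
```

At these speeds a full 400-epoch run would take on the order of 140 to 160 hours.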

[screenshot of the training progress]

Reducing the number of cells to 10,000 with 'zinb' as recon loss + raw counts as .X

When I use only 10,000 cells and feed a 10,000 x 5,300 AnnData to the model, there are fewer iterations per epoch and each iteration is faster. I assume this is because there are 4 times fewer cells: the number of iterations per epoch is 4 times smaller (71 vs. 284), and each iteration is also about 4 times faster (1.1 s vs. 4.5 s) (see screenshot).
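With a fixed batch size, the number of iterations per epoch should grow linearly with the number of training cells, while the time per iteration would normally be expected to stay roughly constant. The sketch below assumes a batch size of 128, a 90/10 train/validation split, and that the last partial batch is kept; these are assumptions about the trainer defaults, not values taken from this issue, although they reproduce the reported 71 iterations for 10,000 cells and land close to the 284 reported for ~40,000. Because the time per iteration also grew, the epoch time grew about 16-fold for 4 times as many cells.

```python
import math

BATCH_SIZE = 128   # assumption: trainer batch size
TRAIN_FRAC = 0.9   # assumption: fraction of cells used for training

def iterations_per_epoch(n_cells: int) -> int:
    """Mini-batches needed to cover the training split once (last partial batch kept)."""
    return math.ceil(n_cells * TRAIN_FRAC / BATCH_SIZE)

print(iterations_per_epoch(10_000))  # 71
print(iterations_per_epoch(40_000))  # 282

# Reported epoch times: iterations x seconds per iteration
small_epoch_s = 71 * 1.1     # 10,000 cells
large_epoch_s = 284 * 4.5    # ~40,000 cells
print(f"epoch time ratio: {large_epoch_s / small_epoch_s:.1f}x")  # 16.4x
```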

[screenshot of the training progress]

Full data with 'zinb' as recon loss + acosh normalised counts as .X

When I feed the full dataset (~40,000 x 5,300) with acosh-normalised counts as .X and 'zinb' as the reconstruction loss, training is much faster:

[screenshot of the training progress]

Checking memory usage

procs -----------memory---------- ---swap-- -----io---- -system-- ------cpu-----
 r  b   swpd   free   buff  cache   si   so    bi    bo   in   cs us sy id wa st
 1  0      0 161233872   5244 286898176    0    0     1     5    0    0  3  1 96  0  0
 1  0      0 161233360   5244 286898176    0    0     0     0 2785 2087  3  0 97  0  0
 1  0      0 161234736   5244 286898176    0    0     0     0 2351 2096  3  0 97  0  0
 1  0      0 161235744   5244 286898176    0    0     0     0 2385 2236  3  0 97  0  0
 1  0      0 161236240   5244 286898176    0    0     0     0 2578 2205  3  0 97  0  0

I also tried a node with 160 GB of CPU RAM, but training was still extremely slow.

Question

Is there a way to make training faster, or is scPoli generally this slow on larger reference datasets?

For comparison: on the merged reference + query dataset (~52,000 cells x 5,300 genes), scVI takes around 12 minutes to train for 400 epochs, and scANVI around 36 minutes for 400 epochs.

gargerd commented 1 month ago

I think I figured out the speed problem: I originally passed my counts as a sparse CSR matrix, and scPoli's dataloader converts each sparse row to a dense tensor like this:

if self._is_sparse:
    x = torch.tensor(np.squeeze(self.data[index].toarray()), dtype=torch.float32)
else:
    x = self.data[index]

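To gauge the per-cell overhead of the sparse branch, the two branches can be timed on synthetic data. This is a minimal sketch using only NumPy and SciPy: it mimics the snippet above (row indexing followed by toarray()) but omits the torch.tensor call and is not scPoli's actual dataset class. The matrix size and density are arbitrary assumptions, and timings depend on the machine. It illustrates the constant per-row cost of slicing a CSR row; it does not by itself explain why the time per iteration grew with the number of cells.

```python
import time
import numpy as np
import scipy.sparse as sp

rng = np.random.default_rng(0)
n_cells, n_genes = 2_000, 5_300      # arbitrary synthetic size
# Random CSR matrix with ~5% non-zeros (assumed density)
X_sparse = sp.random(n_cells, n_genes, density=0.05, format="csr",
                     dtype=np.float32, random_state=rng)
X_dense = X_sparse.toarray()

def time_rows(get_row, n_rows=1_000):
    """Average seconds per single-cell access over n_rows random cells."""
    idx = rng.integers(0, n_cells, size=n_rows)
    start = time.perf_counter()
    for i in idx:
        get_row(i)
    return (time.perf_counter() - start) / n_rows

# Sparse branch: slice one row, densify it, drop the leading axis
sparse_t = time_rows(lambda i: np.squeeze(X_sparse[i].toarray()))
# Dense branch: plain NumPy row view
dense_t = time_rows(lambda i: X_dense[i])
print(f"sparse: {sparse_t * 1e6:.1f} us/row, dense: {dense_t * 1e6:.1f} us/row")
```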
So I passed the raw count matrix as a dense np.ndarray instead of a CSR matrix:

adata_gse_.X=adata_gse_.layers['raw_counts'].A

This sped up training with raw counts.
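Densifying the whole matrix trades memory for speed. For the reference set in this issue, 40,000 x 5,300 values take about 0.85 GB as float32 and about 1.7 GB as float64, which fits easily in the RAM shown in the vmstat output. The sketch below is a hedged illustration using SciPy only; the helper names dense_nbytes and densify_counts are hypothetical, the AnnData assignment appears only as a comment, and casting to float32 is an assumption chosen because the sparse branch of the loader builds float32 tensors anyway. It calls .toarray() rather than .A, since .A is deprecated in recent SciPy releases and absent on csr_array.

```python
import numpy as np
import scipy.sparse as sp

def dense_nbytes(n_cells: int, n_genes: int, dtype=np.float32) -> int:
    """Bytes needed to hold an n_cells x n_genes dense matrix."""
    return n_cells * n_genes * np.dtype(dtype).itemsize

print(f"float32: {dense_nbytes(40_000, 5_300) / 1e9:.2f} GB")              # 0.85 GB
print(f"float64: {dense_nbytes(40_000, 5_300, np.float64) / 1e9:.2f} GB")  # 1.70 GB

def densify_counts(X) -> np.ndarray:
    """Return a dense float32 copy of a (possibly sparse) count matrix."""
    if sp.issparse(X):
        X = X.toarray()            # works for both csr_matrix and csr_array
    return np.asarray(X, dtype=np.float32)

# Small synthetic check; with AnnData this would be, e.g.,
# adata_gse_.X = densify_counts(adata_gse_.layers['raw_counts'])
counts = sp.csr_matrix(np.array([[0, 3, 0], [1, 0, 2]]))
dense = densify_counts(counts)
print(dense.dtype, dense.shape)   # float32 (2, 3)
```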

yojetsharma commented 1 month ago

I understand why training is slow, but I am curious: does the way the raw counts are added (sparse vs. dense) affect training, and thereby the final query predictions?
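One narrow part of this question can be checked directly: whether the two loader branches hand the model the same values for a given cell. The sketch below (NumPy/SciPy only, synthetic Poisson counts, torch conversion replaced by a float32 cast) compares the sparse branch's densified row with the matching row of a pre-densified matrix. It only shows that the per-cell inputs match when the dense matrix is stored as float32; it does not rule out other sources of run-to-run differences, such as random initialisation or data shuffling.

```python
import numpy as np
import scipy.sparse as sp

rng = np.random.default_rng(42)
# Synthetic raw counts, stored once as CSR and once densified (both float32)
counts = rng.poisson(0.3, size=(100, 50)).astype(np.float32)
X_sparse = sp.csr_matrix(counts)
X_dense = X_sparse.toarray()

def sparse_branch(i):
    # Mirrors the sparse path: slice row, densify, drop leading axis, cast to float32
    return np.squeeze(X_sparse[i].toarray()).astype(np.float32)

def dense_branch(i):
    # Mirrors the dense path: return the stored row as-is
    return X_dense[i]

same = all(np.array_equal(sparse_branch(i), dense_branch(i))
           for i in range(X_sparse.shape[0]))
print("identical per-cell inputs:", same)   # True
```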