I think I figured out the speed problem: I originally passed my count data in as a sparse CSR matrix, and scPoli converts sparse matrices in its dataloader like this:
if self._is_sparse:
    x = torch.tensor(np.squeeze(self.data[index].toarray()), dtype=torch.float32)
else:
    x = self.data[index]
So I tried to pass the raw count matrix in as a dense np.ndarray instead of a CSR matrix, like this:
adata_gse_.X = adata_gse_.layers['raw_counts'].A
and this sped up the training with raw counts.
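For illustration, here is a minimal stand-alone benchmark (not part of scPoli; the matrix size and the two fetch functions are made up here) of why the per-row CSR-to-dense conversion above is so much slower than plain dense row indexing:

import time

import numpy as np
import torch
from scipy import sparse

# toy count matrix, roughly the shape of the 10,000-cell subset discussed below
dense = np.random.poisson(0.1, size=(10_000, 5_300)).astype(np.float32)
csr = sparse.csr_matrix(dense)

def fetch_sparse(i):
    # mirrors the sparse branch of the dataloader snippet above
    return torch.tensor(np.squeeze(csr[i].toarray()), dtype=torch.float32)

def fetch_dense(i):
    # dense branch: plain row indexing, no per-row conversion
    return torch.tensor(dense[i], dtype=torch.float32)

for fetch in (fetch_sparse, fetch_dense):
    t0 = time.perf_counter()
    for i in range(2_000):
        fetch(i)
    print(f"{fetch.__name__}: {time.perf_counter() - t0:.2f} s for 2000 rows")

The trade-off of the dense workaround is memory: a 40,000 x 5,300 float32 matrix takes roughly 0.8 GB once densified.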
I understand the reason why the training is slow. But I am curious: does the way the raw counts are added (sparse vs. dense .X) affect the training, and thereby the query predictions in the end?
Dear scarches team,
I have noticed that when I run scPoli.train() on a reference dataset of ~40,000 cells × 5,300 genes, with the count layer .X set to the raw counts, the training becomes progressively slower as the number of cells increases.
Part of my code running the training on the reference dataset:
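(The exact snippet is not reproduced here; the following is only a rough sketch of what such a scPoli reference-training call typically looks like. 'study' and 'cell_type' are placeholder keys, adata_reference is a placeholder AnnData, and argument names may differ slightly between scarches versions.)

from scarches.models.scpoli import scPoli

scpoli_model = scPoli(
    adata=adata_reference,        # .X holds the raw counts
    condition_keys='study',       # placeholder batch/condition key
    cell_type_keys='cell_type',   # placeholder cell-type key
    recon_loss='zinb',
)
scpoli_model.train(
    n_epochs=50,
    pretraining_epochs=40,
    eta=5,
)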
Debugging steps
I looked into the source code of the scPoli trainer.py and, to understand where the training is being slowed down, I printed (see the timing sketch after this list):
Number of training iterations per epoch
Duration of each training iteration (one batch)
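A minimal, self-contained sketch of this kind of instrumentation (not the actual trainer.py code; the data and the "training step" are dummies):

import time

import torch
from torch.utils.data import DataLoader, TensorDataset

data = torch.randn(10_000, 5_300)                      # stand-in for the count matrix
loader = DataLoader(TensorDataset(data), batch_size=128, shuffle=True)

print("iterations per epoch:", len(loader))
for iteration, (batch,) in enumerate(loader):
    t0 = time.perf_counter()
    _ = batch.sum()                                    # stand-in for the training step
    print(f"iteration {iteration}: {time.perf_counter() - t0:.4f} s")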
Full data with 'zinb' as recon loss + raw counts as .X
For the dataset of ~40,000 × 5,300, I got 284 iterations per epoch at a speed of 4.5-5 sec/iteration. This would mean that one epoch lasts around 22-24 minutes. (see screenshot)
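(Quick arithmetic on those numbers: 284 iterations × 4.5-5 s/iteration ≈ 1280-1420 s, i.e. roughly 21-24 minutes per epoch, which matches the estimate above.)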
Dropping the number of cells to 10,000 with 'zinb' as recon loss + raw counts as .X
When only considering 10,000 cells and feeding a 10,000 × 5,300 AnnData to the model, there are fewer iterations per epoch and the iterations are faster. I guess this is because there are 4 times fewer cells, so the iterations are 4 times faster (1.1 sec vs. 4.5 sec) and the number of iterations per epoch is also 4 times lower (71 vs. 284). (see screenshot)
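(In per-epoch terms that is roughly 71 × 1.1 s ≈ 78 s for 10,000 cells versus 21-24 minutes for ~40,000 cells, i.e. about a 16-18× difference for 4× the cells, so the per-epoch time grows much faster than linearly here.)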
Full data with 'zinb' as recon loss + acosh normalised counts as .X
When feeding the original dataset (~40,000 × 5,300) with acosh normalised counts as the .X layer and 'zinb' as recon loss, the training is much faster.
Checking memory usage
I also tried with a node of 160 GB of CPU RAM, but the training was still extremely slow.
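(For anyone reproducing this, one simple way to watch memory during a run is something like the sketch below; psutil is assumed to be installed, and this is not necessarily how memory was checked here.)

import psutil

process = psutil.Process()
print(f"resident memory of this process: {process.memory_info().rss / 1e9:.1f} GB")
print(f"memory used on the node: {psutil.virtual_memory().percent:.0f} %")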
Question
Is there a way to make the training faster or is scPoli generally that slow when it comes to larger reference datasets?
Just as a reference: for the merged reference + query dataset (~52,000 cells × 5,300 genes), scVI needs around 12 minutes to train for 400 epochs, while scANVI needs around 36 minutes for 400 epochs.
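(For context, a sketch of what such scVI/scANVI trainings typically look like in scvi-tools; this is not the exact code used for the numbers above, and 'raw_counts', 'study', 'cell_type', 'Unknown' and adata_merged are placeholder names.)

import scvi

scvi.model.SCVI.setup_anndata(adata_merged, layer='raw_counts', batch_key='study')
vae = scvi.model.SCVI(adata_merged)
vae.train(max_epochs=400)

scanvi = scvi.model.SCANVI.from_scvi_model(
    vae, labels_key='cell_type', unlabeled_category='Unknown'
)
scanvi.train(max_epochs=400)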