scverse / scvi-tools

Deep probabilistic analysis of single-cell and spatial omics data
http://scvi-tools.org/
BSD 3-Clause "New" or "Revised" License

How to increase num_workers in pytorch DataLoader? #2933

Closed winglet0996 closed 3 months ago

winglet0996 commented 3 months ago

Hi, I am following the scRNA-seq tutorial.

After running

model = scvi.model.SCVI(adata, n_layers=2, n_latent=30, gene_likelihood="nb")
model.train()

I got

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/winglet/Data/apps/micromamba/envs/scanpy/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument to `num_workers=87` in the `DataLoader` to improve performance.

When I set this at the very beginning

scvi.settings.dl_num_workers = 87

I got tons of

/home/winglet/Data/apps/micromamba/envs/scanpy/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
self.pid = os.fork()

and the training speed was much slower than before (time per iteration increased about 5x).

Versions:

scvi-tools-1.1.5

canergen commented 3 months ago

Hi, indeed we usually see a slowdown when using multiple workers. In the next release (or the current main branch), we will add a persistent-worker flag. We haven't benchmarked in depth how training speed compares with that setting (if you try it out, please report back). As for the JAX warning, you can safely ignore it: JAX is not called during training, so it shouldn't affect your runtime. To sum up, PyTorch Lightning's suggestions for maximizing performance don't necessarily improve performance in practice.
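
For what it's worth, a minimal sketch of how one might experiment with this, assuming the data-loader options can be forwarded through `datasplitter_kwargs` in `model.train()`; the `persistent_workers` keyword in particular is an assumption about the upcoming flag and may differ in the actual release:

import scvi

# A few workers is usually enough; very high counts (e.g. 87) often add
# more inter-process overhead than they save.
scvi.settings.dl_num_workers = 4

model = scvi.model.SCVI(adata, n_layers=2, n_latent=30, gene_likelihood="nb")

# Assumed: extra keyword args are forwarded to the underlying DataLoader,
# so workers can be kept alive between epochs instead of being re-forked.
model.train(datasplitter_kwargs={"persistent_workers": True})

If you benchmark different worker counts this way, comparing the time per iteration against the single-worker default would be the most useful number to report.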