scverse / scvi-tools

Deep probabilistic analysis of single-cell and spatial omics data
http://scvi-tools.org/
BSD 3-Clause "New" or "Revised" License

Multi GPU training #1226

Closed adamgayoso closed 1 year ago

adamgayoso commented 3 years ago
  1. Write a custom DistributedSampler that also takes as input the overall set of indices to pull data from (i.e., the train, test, or validation set indices). Probably just need to add a few lines of code to `__init__`, call the super `__init__`, and then write a custom `__iter__` method (see the sketch after this list).
  2. Use this sampler if multi-GPU training is selected (these are all passed through kwargs of the PyTorch Lightning trainer).
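
A minimal sketch of what step 1 could look like (the class name and the remainder-dropping behavior are illustrative, not existing scvi-tools code; `num_replicas`/`rank` are assumed to come from the initialized process group or be passed via kwargs):

```python
import torch
from torch.utils.data.distributed import DistributedSampler


class IndexedDistributedSampler(DistributedSampler):
    """Hypothetical sampler restricted to a given subset of indices."""

    def __init__(self, dataset, indices, **kwargs):
        super().__init__(dataset, **kwargs)
        # the split (train/val/test indices) this sampler may draw from
        self.indices = list(indices)
        # size each replica's share by the subset, not the full dataset
        # (the remainder is dropped here for simplicity)
        self.num_samples = len(self.indices) // self.num_replicas
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        if self.shuffle:
            order = torch.randperm(len(self.indices), generator=g).tolist()
        else:
            order = list(range(len(self.indices)))
        order = order[: self.total_size]
        # each replica takes a strided slice of the (shuffled) subset
        shard = order[self.rank : self.total_size : self.num_replicas]
        return iter([self.indices[i] for i in shard])
```

Step 2 would then amount to passing `sampler=IndexedDistributedSampler(dataset, train_idx, ...)` to the DataLoader whenever a multi-GPU strategy is selected.
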
vitkl commented 2 years ago

I would also be very interested in multi-GPU training of Pyro models, specifically a full-data training mode where large data is split between different GPUs.

vitkl commented 2 years ago

It is actually fairly straightforward to do data parallelism in pyro using horovod (https://github.com/pyro-ppl/pyro/blob/dev/examples/svi_horovod.py). This would mainly require 1) a new training plan (to use DistributedSampler, a different optimizer) and 2) a different device-backed data loader (to split data across devices).

There are issues with using this for models with both global and local cell-specific parameters (all parameters live on all devices).
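
For context, the linked Pyro example boils down to roughly the following pattern (a condensed sketch; `dataset`, `model`, `guide`, and `num_epochs` are placeholders, and scvi-tools would hide most of this inside a training plan and data loader):

```python
import horovod.torch as hvd
import pyro
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

hvd.init()
torch.cuda.set_device(hvd.local_rank())

# each worker iterates over a different shard of the cells
sampler = DistributedSampler(dataset, num_replicas=hvd.size(), rank=hvd.rank())
loader = DataLoader(dataset, batch_size=128, sampler=sampler)

# wrap a regular Pyro optimizer so gradients are averaged across workers
optim = pyro.optim.HorovodOptimizer(pyro.optim.Adam({"lr": 1e-3}))
svi = pyro.infer.SVI(model, guide, optim, loss=pyro.infer.Trace_ELBO())

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  # reshuffle the shards each epoch
    for batch in loader:
        svi.step(batch.cuda())
```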

adamgayoso commented 2 years ago

> It is actually fairly straightforward to do data parallelism in pyro using horovod (https://github.com/pyro-ppl/pyro/blob/dev/examples/svi_horovod.py). This would mainly require 1) a new training plan (to use DistributedSampler, a different optimizer) and 2) a different device-backed data loader (to split data across devices).

It is likely less straightforward to make this work with PyTorch Lightning, and it might require substantial work to make it work generally for Pyro. In particular, we'd have to look more into how device-backed data loaders would work in this case.

vitkl commented 2 years ago

I see. I also got some feedback from @fritzo that for models with local parameters this might not give as much of a memory increase as one would hope, because the cell-specific parameters are quite large even for just a few 100k cells (but this could still give 4-5x more memory). I don't quite understand PyTorch Lightning, so if you solve this I would be very keen to try it.

vitkl commented 2 years ago

Does NumPyro + JAX support multi-GPU training more natively? If yes, this could be a way to go.

What I am specifically interested in is data and model parameter parallelism where the data and model parameters for different cells (denoted by a plate) are distributed to different GPU devices. Maybe this is also possible with pyro.

Also cc @fehiepsi @fritzo @martinjankowiak

fritzo commented 2 years ago

As mentioned above, Pyro can use Horovod for data parallelism across GPUs and machines in a cluster, but I believe parameters would be replicated on all nodes. NumPyro might be the way to go. @fehiepsi?

fehiepsi commented 2 years ago

Current NumPyro SVI does not support that pattern, but it might be doable using JAX directly. Something like:

```python
import jax
from numpyro import handlers

def loss_fn(batch, params):
    global_params, local_params = params
    # substitute the shared global parameters once
    model_g = handlers.substitute(model, data=global_params)
    guide_g = handlers.substitute(guide, data=global_params)

    def get_loss_local(data, local_params):
        # substitute the per-device local parameters on top of the globals
        model_l = handlers.substitute(model_g, data=local_params)
        guide_l = handlers.substitute(guide_g, data=local_params)
        loss = TraceELBO(model_l, guide_l, ...)  # sketch: per-shard ELBO
        return loss

    # shard the batch and the local parameters across devices
    return jax.pmap(get_loss_local)(batch, local_params)

# then use jaxopt to optimize loss_fn over params:
# https://jaxopt.github.io/stable/stochastic.html#optax-solvers
```

Though it still seems a bit tricky to cover many usage cases (e.g. when there are both global and local variables, we need to apply a reduced sum over the local variables).
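
One way to handle that reduction is a collective sum inside the pmapped function. A toy sketch (the quadratic term stands in for the local ELBO piece and is purely illustrative):

```python
import jax
import jax.numpy as jnp


def get_loss_local(data, local_params):
    # placeholder for the per-device piece of the ELBO
    local_loss = jnp.sum((data - local_params) ** 2)
    # sum the contributions across devices so every replica sees the same
    # total before gradients with respect to the global params are taken
    return jax.lax.psum(local_loss, axis_name="devices")


# the leading axis of batch/local_params is the device axis; pmap must
# declare the axis name that psum refers to
n_dev = jax.local_device_count()
batch = jnp.ones((n_dev, 100, 5))
local_params = jnp.zeros((n_dev, 100, 5))
losses = jax.pmap(get_loss_local, axis_name="devices")(batch, local_params)
```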

vitkl commented 2 years ago

Thanks for your thoughts!

My models always have both local and global variables. Do you see any way to define the device split along a Pyro plate? Maybe that could be provided as an option in NumPyro?

vitkl commented 2 years ago

Does pure data parallelism work with the current NumPyro and scvi-tools? (Loading different minibatches of data onto different devices, e.g. 8 minibatches in parallel, with different cells in each training iteration.)

mxposed commented 2 years ago

@adamgayoso sorry, do you have a recipe for enabling multi-GPU training of the scVI model before it is released in scvi-tools? I haven't done multi-GPU training before, so I'm asking where to start. Can I just apply the patch from #1357?

vitkl commented 1 year ago

@adamgayoso If you ignore device-backed data loaders for now, what is the main roadblock to implementing the Pyro+horovod solution? https://pyro.ai/examples/svi_horovod.html

Does this boil down to implementing an equivalent to torch.utils.data.distributed.DistributedSampler and modifying the training plan to add horovod use? Or is there more to it?

Is the problem in writing a general solution that works for any model OR is the problem that this won't work for any model?

Also cc @macwiatrak for discussion

adamgayoso commented 1 year ago

I expect that we'll only have minimal issues with non-Pyro models. This is due to updates in Lightning that automatically wrap custom dataloader samplers like the ones we have.

In the case of Pyro, we have a somewhat hacky solution that fuses it with a lightning module. I expect significantly more engineering work to get this right. A hacky solution might be quicker, but we shouldn't include that in this library.

adamgayoso commented 1 year ago

To clarify, lightning should handle:

  1. Automatically creating the distributed data loader (recent updates should allow this to work with no code changes on our side)
  2. Broadcasting the params and optimizers across devices

But this is in the default pytorch case. For Pyro, which lazily initializes params, the hacky solution would involve a callback that does some of the things you see in the linked pyro tutorial.
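
A hedged sketch of what such a callback might look like (the class name, constructor arguments, and hook choice are assumptions, not the scvi-tools implementation):

```python
import torch
import pytorch_lightning as pl


class PyroParamInitCallback(pl.Callback):
    """Hypothetical callback: call the guide once so Pyro's lazily created
    parameters exist before Lightning broadcasts/wraps them for DDP."""

    def __init__(self, guide, init_batch):
        self.guide = guide
        self.init_batch = init_batch

    def on_train_start(self, trainer, pl_module):
        with torch.no_grad():
            # a single call materializes all pyro.param / PyroParam sites
            self.guide(self.init_batch)
```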

fritzo commented 1 year ago

@adamgayoso I'd love to make this easier to do in Pyro (as @vitkl has requested). What's your timeline? Could we sync the week of the Jan 23 to figure out what would be needed on the Pyro side?

adamgayoso commented 1 year ago

We don't have bandwidth to contribute much at the moment, but can review code. I think it's relatively straightforward to make this work in the nn.Module/PyroModule paradigm by altering what we call a TrainingPlan to use vanilla torch optimizers instead of Pyro optimizers. This will allow lightning to do almost all the work.

In other words, we can create a LowerLevelPyroTrainingPlan, using this lower-level pattern internally:

```python
import pyro
import torch

# compute a differentiable ELBO while capturing the params Pyro creates lazily
loss_fn = lambda model, guide: pyro.infer.Trace_ELBO().differentiable_loss(model, guide, X_train, y_train)
with pyro.poutine.trace(param_only=True) as param_capture:
    loss = loss_fn(model, guide)
params = set(site["value"].unconstrained()
             for site in param_capture.trace.nodes.values())
# hand the captured params to a vanilla torch optimizer
optimizer = torch.optim.Adam(params, lr=0.001, betas=(0.90, 0.999))
```

Then lightning will do everything it needs to do with handling backprop.

We already have a callback to run the param initialization, which can then be changed to reset the optimizer.
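
A hedged sketch of how that pattern could sit inside a LightningModule (an illustration only, not the actual scvi-tools class; it assumes `model` and `guide` are PyroModules, so their parameters show up in `self.parameters()` once the initialization callback has run):

```python
import pyro
import torch
import pytorch_lightning as pl


class LowLevelPyroPlanSketch(pl.LightningModule):
    def __init__(self, model, guide, lr=1e-3):
        super().__init__()
        self.model, self.guide = model, guide
        self.elbo = pyro.infer.Trace_ELBO()
        self.lr = lr

    def training_step(self, batch, batch_idx):
        # a differentiable loss lets Lightning handle backward() and DDP
        loss = self.elbo.differentiable_loss(self.model, self.guide, batch)
        self.log("elbo", loss)
        return loss

    def configure_optimizers(self):
        # a vanilla torch optimizer over params registered on the PyroModules
        return torch.optim.Adam(self.parameters(), lr=self.lr)
```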

adamgayoso commented 1 year ago

@fritzo what is ELBOModule? This seems useful. I can maybe put up a draft PR with my idea

fritzo commented 1 year ago

@eb8680 can tell you more about ELBOModule

adamgayoso commented 1 year ago

It seems @vitkl has made it work via the new lower level trainingplan https://github.com/scverse/scvi-tools/pull/1845

vitkl commented 1 year ago

It appears that the Lightning "ddp_notebook" strategy + https://github.com/scverse/scvi-tools/pull/1845 doesn't allow saving the trained model, because Lightning fails to load checkpoints saved by the worker GPU. It looks like the model parameters don't exist yet when the state_dict is loaded. I tried creating a callback that does a forward pass through the model on_fit_end, on_train_end, and on_load_checkpoint; however, none of that changed anything. Maybe this means that the strategy or the training plan needs to be modified, but I don't understand what can be done next. Any tips would be appreciated @fritzo @eb8680 @adamgayoso

I have not tested the standard "ddp" strategy in a script yet.

<details>
<summary>Traceback (most relevant frames)</summary>

```python
File /nfs/team205/vk7/sanger_projects/my_packages/cell2state/cell2state/models/_programme_model.py:839, in ProgrammeModel.train(self, max_epochs, use_gpu, train_size, validation_size, batch_size, gene_batch_size, drop_last, early_stopping, lr, num_particles, scale_elbo, training_plan, plan_kwargs, dl_kwargs, simple_progress_bar, accumulate_grad_batches, accelerator, **trainer_kwargs)
--> 839 return runner()

File /nfs/team205/vk7/sanger_projects/my_packages/cell2state/cell2state/models/base/trainrunner.py:81, in TrainRunnerLowLevel.__call__(self)
--> 81 self.trainer.fit(self.training_plan, self.data_splitter)

File /nfs/team283/vk7/software/miniconda3farm5/envs/horovod_scvi19_cuda113/lib/python3.9/site-packages/scvi/train/_trainer.py:187, in Trainer.fit(self, *args, **kwargs)
--> 187 super().fit(*args, **kwargs)

File /nfs/team283/vk7/software/miniconda3farm5/envs/horovod_scvi19_cuda113/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:603, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
--> 603 call._call_and_handle_interrupt(
            self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
        )

File /nfs/team283/vk7/software/miniconda3farm5/envs/horovod_scvi19_cuda113/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py:36, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
--> 36 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)

File /nfs/team283/vk7/software/miniconda3farm5/envs/horovod_scvi19_cuda113/lib/python3.9/site-packages/pytorch_lightning/strategies/launchers/multiprocessing.py:123, in _MultiProcessingLauncher.launch(self, function, trainer, *args, **kwargs)
--> 123 self._recover_results_in_main_process(worker_output, trainer)

File /nfs/team283/vk7/software/miniconda3farm5/envs/horovod_scvi19_cuda113/lib/python3.9/site-packages/pytorch_lightning/strategies/launchers/multiprocessing.py:156, in _MultiProcessingLauncher._recover_results_in_main_process(self, worker_output, trainer)
--> 156 trainer.lightning_module.load_state_dict(ckpt)

File /nfs/team283/vk7/software/miniconda3farm5/envs/horovod_scvi19_cuda113/lib/python3.9/site-packages/torch/nn/modules/module.py:1497, in Module.load_state_dict(self, state_dict, strict)
-> 1497 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
            self.__class__.__name__, "\n\t".join(error_msgs)))

RuntimeError: Error(s) in loading state_dict for LowLevelPyroTrainingPlan:
    Unexpected key(s) in state_dict: "module._model.accessibility_bias_mean_prior", ... "module._guide.scales.weights.tf_dna_binding_preference.weights.motif_weight_unconstrained" ...
```

</details>

vitkl commented 1 year ago

@martinkim0 Nice work! Great to have this supported.

Is it possible to combine this approach with DeviceBackedDataSplitter? I am interested in loading data subsets once, simply distributing different cells or spatial locations across GPUs.