Closed: adamgayoso closed this issue 1 year ago.
I would also be very interested in multi-GPU training of Pyro models, specifically a full-data training mode where a large dataset is split between different GPUs.
It is actually fairly straightforward to do data parallelism in pyro using horovod (https://github.com/pyro-ppl/pyro/blob/dev/examples/svi_horovod.py). This would mainly require 1) a new training plan (to use DistributedSampler, a different optimizer) and 2) a different device-backed data loader (to split data across devices).
There are issues with using this for models with both global and local cell-specific parameters (all parameters live on all devices).
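To make those two pieces concrete, here is a minimal sketch in the spirit of that Horovod example; dataset, model, guide, num_epochs and the hyperparameters are placeholders, and this is not scvi-tools API, just an illustration of (1) a Horovod-wrapped optimizer and (2) sharding the data across devices:

```python
import horovod.torch as hvd
import pyro
import torch

hvd.init()
torch.cuda.set_device(hvd.local_rank())

# (2) split data across devices: each rank sees a disjoint shard of cells
sampler = torch.utils.data.distributed.DistributedSampler(
    dataset, num_replicas=hvd.size(), rank=hvd.rank())
loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=128)

# (1) a Horovod-wrapped Pyro optimizer that all-reduces gradients across ranks
optim = pyro.optim.HorovodOptimizer(pyro.optim.Adam({"lr": 1e-3}))
svi = pyro.infer.SVI(model, guide, optim, loss=pyro.infer.Trace_ELBO())

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  # reshuffle the shards each epoch
    for (x,) in loader:
        svi.step(x.cuda())
```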
It is likely less straightforward to make this work with PyTorch Lightning, and it might require substantial work to make it work generally for Pyro. In particular, we'd have to look more into how device-backed data loaders would work in this case.
I see. I also got some feedback from @fritzo that for models with local parameters this might not increase the available memory as much as one would hope, because the cell-specific parameters are already quite large for just a few 100k cells (though it could still give 4-5x more space). I don't quite understand PyTorch Lightning, so if you solve this I would be very keen to try it.
Does numpyro+jax support multi-GPU training more natively? If yes, this could be a way to go.
What I am specifically interested in is data and model parameter parallelism where the data and model parameters for different cells (denoted by a plate) are distributed to different GPU devices. Maybe this is also possible with pyro.
Also cc @fehiepsi @fritzo @martinjankowiak
[As mentioned above] Pyro can use Horovod for data parallelism across GPUs and machines in a cluster, but I believe parameters would be replicated on all nodes. NumPyro might be the way to go. @fehiepsi?
Current NumPyro SVI does not support that pattern, but it might be possible to do it with JAX directly. Something like:

```python
import jax
from jax import random
from numpyro import handlers
from numpyro.infer import Trace_ELBO

def loss_fn(batch, params):
    global_params, local_params = params
    # globals are shared by every device; locals are sharded per device
    model_g = handlers.substitute(model, data=global_params)
    guide_g = handlers.substitute(guide, data=global_params)

    def get_loss_local(data, local_params):
        model_l = handlers.substitute(model_g, data=local_params)
        guide_l = handlers.substitute(guide_g, data=local_params)
        # rng key / param map handling simplified here
        return Trace_ELBO().loss(random.PRNGKey(0), {}, model_l, guide_l, data)

    # each device gets its own shard of the batch and of the local params
    return jax.pmap(get_loss_local)(batch, local_params)

# then use jaxopt to optimize loss_fn over params:
# https://jaxopt.github.io/stable/stochastic.html#optax-solvers
```
though it still seems a bit tricky to cover many use cases (e.g., when there are both global and local variables, we need to apply a reduced sum over the local variables).
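To illustrate that reduced-sum point with a toy, self-contained example (not tied to NumPyro): inside a pmapped function, per-device contributions such as the local-variable terms of the ELBO can be combined with jax.lax.psum, while global terms should only be counted once.

```python
import jax
import jax.numpy as jnp

def shard_loss(local_loss):
    # sum the per-device (local-variable) contributions across all devices
    return jax.lax.psum(local_loss, axis_name="devices")

# one dummy local loss per available device
per_device_losses = jnp.arange(jax.device_count(), dtype=jnp.float32)
print(jax.pmap(shard_loss, axis_name="devices")(per_device_losses))
```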
Thanks for your thoughts!
My models always have both local and global variables. Do you see any way to define the device split along a pyro plate? Maybe that could be provided as an option in numpyro?
Does pure data parallelism work with the current numpyro and scvi-tools? (i.e., loading different minibatches of data onto different devices, 8 minibatches in parallel, with different cells in each training iteration)
@adamgayoso sorry, do you have a recipe if I want to enable multi-GPU training of the scVI model before it is released in scvi-tools? I haven't done multi-GPU training before, so I'm asking where to start. Can I just apply a patch from #1357?
@adamgayoso If you ignore device-backed data loaders for now, what is the main roadblock to implementing the Pyro+horovod solution? https://pyro.ai/examples/svi_horovod.html
Does this boil down to implementing an equivalent to torch.utils.data.distributed.DistributedSampler and modifying the training plan to use horovod? Or is there more to it?
Is the problem in writing a general solution that works for any model OR is the problem that this won't work for any model?
Also cc @macwiatrak for discussion
I expect that we will only have minimal issues with non-Pyro models, due to updates in lightning that automatically wrap custom dataloader samplers like the ones we have.
In the case of Pyro, we have a somewhat hacky solution that fuses it with a lightning module. I expect significantly more engineering work to get this right. A hacky solution might be quicker, but we shouldn't include that in this library.
To clarify, lightning should handle:
But this is in the default PyTorch case. For Pyro, which lazily initializes parameters, the hacky solution would involve a callback that does some of the things you see in the linked pyro tutorial.
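Roughly, such a callback could look like the sketch below (class and attribute names are hypothetical, not the actual scvi-tools callback, and it assumes a datamodule is attached to the trainer): run one pass through the model and guide before training so the lazily-created pyro.param sites exist on every rank before DDP wraps the module.

```python
import pyro
from pytorch_lightning import Callback

class PyroParamInitCallback(Callback):
    # Hypothetical sketch: materialize lazily-initialized Pyro params before
    # distributed training starts, by evaluating the model/guide on one batch.
    def on_fit_start(self, trainer, pl_module):
        batch = next(iter(trainer.datamodule.train_dataloader()))
        pl_module.module.guide(batch)
        pl_module.module.model(batch)
```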
@adamgayoso I'd love to make this easier to do in Pyro (as @vitkl has requested). What's your timeline? Could we sync the week of Jan 23 to figure out what would be needed on the Pyro side?
We don't have bandwidth to contribute much at the moment, but can review code. I think it's relatively straightforward to make this work in the nn.Module/PyroModule paradigm by altering what we call a TrainingPlan to use vanilla torch optimizers instead of Pyro optimizers. This will allow lightning to do almost all the work.
In other words, we can create a LowerLevelPyroTrainingPlan, using this lower-level pattern internally.
```python
import pyro
import torch

# trace one evaluation of the differentiable ELBO so the lazily-initialized
# Pyro params are created, then collect them for a vanilla torch optimizer
loss_fn = lambda model, guide: pyro.infer.Trace_ELBO().differentiable_loss(model, guide, X_train, y_train)
with pyro.poutine.trace(param_only=True) as param_capture:
    loss = loss_fn(model, guide)
params = set(site["value"].unconstrained()
             for site in param_capture.trace.nodes.values())
optimizer = torch.optim.Adam(params, lr=0.001, betas=(0.90, 0.999))
```
Then lightning will do everything it needs to do with handling backprop.
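A rough sketch of what such a LowerLevelPyroTrainingPlan could look like (hypothetical names, assuming the Pyro model/guide live in an nn.Module-style pyro_module; not the final scvi-tools implementation):

```python
import pyro
import torch
import pytorch_lightning as pl

class LowerLevelPyroTrainingPlan(pl.LightningModule):
    def __init__(self, pyro_module):
        super().__init__()
        self.module = pyro_module
        self.loss_fn = pyro.infer.Trace_ELBO().differentiable_loss

    def training_step(self, batch, batch_idx):
        # differentiable_loss returns a plain tensor, so Lightning handles
        # backward() and the optimizer step as usual
        return self.loss_fn(self.module.model, self.module.guide, batch)

    def configure_optimizers(self):
        # Pyro params are created lazily, so in practice the optimizer may
        # need to be (re)built after the param-initialization callback runs
        return torch.optim.Adam(self.module.parameters(), lr=1e-3)
```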
We already have a callback to run the param initialization, which can then be changed to also reset the optimizer.
@fritzo what is ELBOModule? This seems useful. I can maybe put up a draft PR with my idea
@eb8680 can tell you more about ELBOModule
It seems @vitkl has made it work via the new lower-level training plan: https://github.com/scverse/scvi-tools/pull/1845
It appears that the lightning "ddp_notebook" strategy + https://github.com/scverse/scvi-tools/pull/1845 doesn't allow saving the trained model, because lightning fails to load checkpoints saved by the worker GPU. It looks like when the state_dict is loaded, the model parameters don't exist. I tried creating a callback that would do a forward pass through the model on_fit_end, on_train_end and on_load_checkpoint - however, none of that changed anything. Maybe this means that the strategy or the training plan needs to be modified, but I don't understand what can be done next. Any tips would be appreciated @fritzo @eb8680 @adamgayoso
I have not tested the standard "ddp" strategy in a script yet.
@martinkim0 Nice work! Great to have this supported.
Is it possible to combine this approach with DeviceBackedDataSplitter? I am interested in loading data subsets once - simply distributing different cells or spatial locations across GPUs.
Probably what's needed is a DistributedSampler that also takes as input the overall set of indices to pull data from (i.e., the train, validation, or test set indices). Probably just need to add a few lines of code to __init__, call super().__init__(), and then write a custom __iter__ method.
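A rough sketch of that idea (class name hypothetical; it only illustrates subsetting plus sharding across ranks, without the padding/drop_last handling the real DistributedSampler does):

```python
import torch
from torch.utils.data.distributed import DistributedSampler

class SubsetDistributedSampler(DistributedSampler):
    """Hypothetical: a DistributedSampler restricted to a given index subset
    (e.g. the train/val/test split of cells), sharded across ranks."""

    def __init__(self, dataset, indices, **kwargs):
        super().__init__(dataset, **kwargs)
        self.indices = list(indices)
        # each rank draws from the subset, not the full dataset
        self.num_samples = len(self.indices) // self.num_replicas
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        order = torch.randperm(len(self.indices), generator=g).tolist()[: self.total_size]
        # interleave the shuffled subset across ranks
        shard = order[self.rank : self.total_size : self.num_replicas]
        return iter(self.indices[i] for i in shard)

    def __len__(self):
        return self.num_samples
```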