pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

[FR] Support Automatic Mixed Precision training #3316

Open austinv11 opened 9 months ago

austinv11 commented 9 months ago

Issue Description

Better support for mixed precision training would be extremely helpful, at least for SVI. I can manually cast data into float16 or bfloat16 but I am unable to leverage PyTorch's automatic mixed precision training. This is because it requires the use of the GradScaler class during the optimization loop to properly scale gradients in a mixed-precision-aware manner. See the documentation for more info: https://pytorch.org/docs/stable/amp.html

It would be nice to have support for using this class within Pyro's optimizers to allow for AMP support.
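For reference, the standard PyTorch AMP recipe this request is about looks roughly like the following (the linear model and data are just illustrative; `torch.autocast` and `torch.cuda.amp.GradScaler` are the real PyTorch APIs, and `enabled=False` makes both no-ops so the same loop runs on CPU):

```python
import torch
import torch.nn as nn

# Toy model and data, purely for illustration
model = nn.Linear(4, 1)
data, target = torch.randn(64, 4), torch.randn(64, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

use_cuda = torch.cuda.is_available()
# GradScaler rescales the loss so fp16 gradients don't underflow;
# with enabled=False it degrades to a plain optimizer step
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

for step in range(5):
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda" if use_cuda else "cpu",
                        enabled=use_cuda):
        loss = nn.functional.mse_loss(model(data), target)
    scaler.scale(loss).backward()   # backward on the scaled loss
    scaler.step(optimizer)          # unscales grads, skips step on inf/nan
    scaler.update()                 # adjusts the scale factor
```

The key constraint is the last three lines: the scaler must wrap the optimizer's `step()`, which is exactly what Pyro's optimizer wrappers currently hide.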

austinv11 commented 9 months ago

@fritzo I might be willing to try to tackle this. Do you have any opinions on how to expose the functionality to the end user?

fritzo commented 9 months ago

Hi @austinv11, thanks for offering. I'd guess there are a few ways we could support AMP in Pyro:

  1. Use Pyro's ELBOModule to construct a differentiable loss function as in the lightning tutorial, then do standard PyTorch training with AMP. I think Pyro's code already supports this; we'd just need improved documentation and maybe an example:
    • Add a docstring to ELBOModule explaining how it is created and why it is useful.
    • Add ELBO.__call__ method to sphinx's :special-members: list here
    • Add an examples/svi_amp.py similar to examples/svi_lightning.py
  2. Do something similar, but with the Trace_ELBO.differentiable_loss() method.
  3. Add more native AMP support to pyro.optim's wrapper class. This seems intricate and more difficult to maintain though.

Would you be interested in getting (1) or (2) working for yourself then contributing docs to show how you did it? We're happy to answer any questions about Pyro, but I think you know more about AMP than us 🙂

austinv11 commented 9 months ago

It looks like I might need to try option 3 since AMP-aware gradient scaling requires access to the optimizer's step() function.

I could try adding a boolean flag to PyroOptim to enable AMP. Once that is enabled, the user would also need to manually use PyTorch's autocast context manager within their models.

But I could see most users wanting to just activate AMP for their entire model rather than just specific portions of code. Do you think it might be worth adding a new ELBO function that autocasts the entire model for the user?
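A whole-model autocast could be as simple as a closure that wraps the model's forward pass, sketched below (`autocast_model` is a hypothetical helper, not existing Pyro API; only `torch.autocast` is real):

```python
import torch

def autocast_model(model, **autocast_kwargs):
    """Hypothetical helper (not Pyro API): run `model` under torch.autocast."""
    def wrapped(*args, **kwargs):
        with torch.autocast(**autocast_kwargs):
            return model(*args, **kwargs)
    return wrapped

# Example: ops inside any wrapped callable run in reduced precision
low_prec_mm = autocast_model(torch.mm, device_type="cpu", dtype=torch.bfloat16)
out = low_prec_mm(torch.randn(2, 2), torch.randn(2, 2))  # bfloat16 result
```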

fritzo commented 8 months ago

Let me try again to persuade you towards options (1) or (2) 😄, admitting I don't know your details or how AMP works.

Back in the early days of Pyro we decided to wrap PyTorch's optimizer classes so we could have more control over dynamically created parameters. In practice this made Pyro's optimization idioms incompatible with other frameworks built on top of PyTorch, e.g. lightning, horovod, AMP, and newer higher-order optimizers. To work around this incompatibility we've since added ways to compute differentiable losses in Pyro so that optimization can be done entirely using torch idioms, without ever using pyro.optim.

For example instead of the original pyro-idiomatic optimization

def model(args):
    ...
guide = AutoNormal(model)
elbo = Trace_ELBO()
optim = pyro.optim.Adam(...)  # <---- pyro idioms
svi = SVI(model, guide, optim, elbo)
for step in range(...):
    svi.step(args)

you can use torch-idiomatic optimizers

class Model(PyroModule):
    def forward(self, args):
        ...
model = Model()
guide = AutoNormal(model)
elbo = Trace_ELBO()
loss_fn = elbo(model, guide)  # an ELBOModule, i.e. a torch.nn.Module
loss_fn(args)  # call once so lazily created guide parameters exist
optimizer = torch.optim.Adam(loss_fn.parameters(), ...)  # <---- torch idioms
for step in range(...):
    optimizer.zero_grad()
    loss = loss_fn(args)
    loss.backward()
    optimizer.step()  # <---- Can we use AMP here?

What I'm hoping is that by switching to torch-native optimizers it will be easy/trivial to support AMP.

That said, we'd still be open to adding AMP support to pyro.optim if you can find a simple maintainable way to do so 🙂.

austinv11 commented 8 months ago

Ah, I see what you mean. Am I correct in understanding that this wouldn't be compatible with the SVI trainer and would require using PyroModules then?

ilia-kats commented 8 months ago

That is also incompatible with models/guides that dynamically create parameters during training, if I understand correctly.

fritzo commented 8 months ago

@austinv11 @ilia-kats correct.