Open austinv11 opened 9 months ago
@fritzo I might be willing to try to tackle this, do you have any opinions on how to expose the functionality to the end user?
Hi @austinv11, Thanks for offering. I'd guess there are a few ways we could support AMP in Pyro:
1. ELBOModule explaining how it is created and why it is useful.
2. ELBO.__call__ method to sphinx's :special-members: list here
3. pyro.optim's wrapper class. This seems intricate and more difficult to maintain though.

Would you be interested in getting (1) or (2) working for yourself then contributing docs to show how you did it? We're happy to answer any questions about Pyro, but I think you know more about AMP than us 🙂
It looks like I might need to try option 3 since AMP-aware gradient scaling requires access to the optimizer's step() function.
I could try making it a boolean flag for PyroOptim to enable AMP. Additionally, once that is enabled the user would need to manually use PyTorch's autocast context manager within their models.
But I could see most users wanting to just activate AMP for their entire model rather than just specific portions of code. Do you think it might be worth adding a new ELBO function that autocasts the entire model for the user?
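As a rough illustration of what "autocasting the entire model" could look like, here is a minimal sketch in plain PyTorch. AutocastLoss and ToyLoss are hypothetical names made up for this example and are not part of Pyro; ToyLoss merely stands in for any loss module (such as an ELBO) whose forward pass should run under autocast. The CPU/bfloat16 defaults are an assumption chosen so the sketch runs without a GPU.

```python
import torch
from torch import nn


class AutocastLoss(nn.Module):
    """Hypothetical wrapper: runs the wrapped loss module entirely under autocast."""

    def __init__(self, loss_module, device_type="cpu", dtype=torch.bfloat16):
        super().__init__()
        self.loss_module = loss_module  # registered as a submodule, so parameters() still works
        self.device_type = device_type
        self.dtype = dtype

    def forward(self, *args, **kwargs):
        # Everything the wrapped module computes is eligible for autocasting.
        with torch.autocast(device_type=self.device_type, dtype=self.dtype):
            return self.loss_module(*args, **kwargs)


class ToyLoss(nn.Module):
    """Stand-in for a real loss module; records the dtype produced under autocast."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 2)
        self.seen_dtype = None

    def forward(self, x):
        out = self.linear(x)
        self.seen_dtype = out.dtype  # lower precision when autocast is active
        return out.float().pow(2).mean()  # reduce in float32


wrapped = AutocastLoss(ToyLoss())
loss = wrapped(torch.randn(5, 3))
```

A wrapper like this only covers the forward pass; gradient scaling would still have to happen wherever the optimizer step is taken.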
Let me try again to persuade you towards options (1) or (2) 😄, admitting I don't know your details or how AMP works.
Back in the early days of Pyro we decided to wrap PyTorch's optimizer classes so we could have more control over dynamically created parameters. In practice this made Pyro's optimization idioms incompatible with other frameworks built on top of PyTorch, e.g. lightning, horovod, AMP, new higher-order optimizers. To work around this incompatibility we've since added ways to compute differentiable losses in Pyro so that optimization can be done entirely using torch idioms, without ever using pyro.optim.
For example instead of the original pyro-idiomatic optimization
    def model(args):
        ...

    guide = AutoNormal(model)
    elbo = Trace_ELBO()
    optim = pyro.optim.Adam(...)  # <---- pyro idioms
    svi = SVI(model, guide, optim, elbo)
    for step in range(...):
        svi.step(args)
you can use torch-idiomatic optimizers
    class Model(PyroModule):
        def forward(self, args):
            ...

    model = Model()
    guide = AutoNormal(model)
    elbo = Trace_ELBO()
    loss_fn = elbo(model, guide)  # an nn.Module wrapping model and guide
    # parameters live on loss_fn, not on the Trace_ELBO object
    optimizer = torch.optim.Adam(loss_fn.parameters(), ...)  # <---- torch idioms
    for step in range(...):
        optimizer.zero_grad()
        loss = loss_fn(args)
        loss.backward()
        optimizer.step()  # <---- Can we use AMP here?
What I'm hoping is that by switching to torch-native optimizers it will be easy/trivial to support AMP.
That said, we'd still be open to adding AMP support to pyro.optim if you can find a simple maintainable way to do so 🙂.
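To make the "Can we use AMP here?" question concrete, below is a minimal sketch of the usual PyTorch AMP pattern (autocast plus GradScaler) applied to a torch-native optimization loop. It uses plain PyTorch only: net and loss_fn are stand-ins invented for this example, with a small linear regression loss taking the place of the differentiable loss that elbo(model, guide) would provide. Whether Pyro's distributions and ELBO behave well under autocast is not something this sketch verifies.

```python
import torch
from torch import nn

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
use_cuda = device == "cuda"
# float16 on GPU benefits from loss scaling; bfloat16 on CPU runs without it.
amp_dtype = torch.float16 if use_cuda else torch.bfloat16

# Stand-in for a differentiable loss module (assumption, not Pyro code).
net = nn.Linear(4, 1).to(device)
x = torch.randn(64, 4, device=device)
y = x.sum(dim=1, keepdim=True)


def loss_fn(x, y):
    # Cast the prediction back to float32 before reducing the loss.
    return nn.functional.mse_loss(net(x).float(), y)


optimizer = torch.optim.Adam(net.parameters(), lr=0.05)
# Recent PyTorch exposes torch.amp.GradScaler; older releases only have torch.cuda.amp.GradScaler.
GradScaler = getattr(torch.amp, "GradScaler", None) or torch.cuda.amp.GradScaler
scaler = GradScaler(enabled=use_cuda)  # becomes a pass-through when disabled

losses = []
for step in range(50):
    optimizer.zero_grad()
    with torch.autocast(device_type=device, dtype=amp_dtype):
        loss = loss_fn(x, y)
    scaler.scale(loss).backward()  # scale the loss so small gradients do not underflow
    scaler.step(optimizer)         # unscales gradients, skips the step on inf/NaN
    scaler.update()                # adjusts the scale factor for the next iteration
    losses.append(loss.item())
```

The key point is that scaler.step(optimizer) needs the raw torch optimizer, which is available in the torch-idiomatic loop but hidden behind the wrapper when using pyro.optim.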
Ah, I see what you mean. Am I correct in understanding that this wouldn't be compatible with the SVI trainer and would require using PyroModules then?
That is also incompatible with models/guides that dynamically create parameters during training, if I understand correctly.
@austinv11 @ilia-kats correct.
Issue Description
Better support for mixed precision training would be extremely helpful, at least for SVI. I can manually cast data into float16 or bfloat16 but I am unable to leverage PyTorch's automatic mixed precision training. This is because it requires the use of the GradScaler class during the optimization loop to properly scale gradients in a mixed-precision-aware manner. See the documentation for more info: https://pytorch.org/docs/stable/amp.html

It would be nice to have support for using this class within Pyro optimizers to allow for AMP support.