FluxML / Optimisers.jl

Optimisers.jl defines many standard optimisers and utilities for learning loops.
https://fluxml.ai/Optimisers.jl
MIT License

Rule for mixed precision training #152

Open CarloLucibello opened 1 year ago

CarloLucibello commented 1 year ago

Implements a fundamental strategy for training big models. In this implementation, the high precision weights are part of the optimiser's state.

The optimiser introduced here should be coupled to a loss scaling strategy (not in this PR, and probably not ever in this repo) to obtain robust mixed precision training.
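Roughly, the idea can be sketched against Optimisers.jl's documented init/apply! custom-rule interface. This is illustrative only, not the PR's actual code; names and details are assumptions:

```julia
using Optimisers
import Optimisers: init, apply!

# Illustrative sketch: keep a high-precision copy of the parameters in the rule's
# state, let the wrapped rule step that copy, and hand back the delta that the
# outer subtract! will remove from the low-precision parameters.
struct MixedPrecision{T<:Number, R<:Optimisers.AbstractRule} <: Optimisers.AbstractRule
    rule::R
end
MixedPrecision(T::Type{<:Number}, rule) = MixedPrecision{T, typeof(rule)}(rule)

function init(o::MixedPrecision{T}, x::AbstractArray) where T
    xhi = T.(x)                        # high-precision master copy of the parameters
    return (xhi, init(o.rule, xhi))
end

function apply!(o::MixedPrecision{T}, state, x, dx) where T
    xhi, st = state
    st, dxhi = apply!(o.rule, st, xhi, T.(dx))   # inner rule works in high precision
    xhi = Optimisers.subtract!(xhi, dxhi)        # accumulate the step on the master copy
    return (xhi, st), x .- xhi                   # delta so that x ends up ≈ xhi, rounded down
end
```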


darsnack commented 1 year ago

This will need a custom adjust method too.

CarloLucibello commented 1 year ago

Good to go?

mcabbott commented 1 year ago

Would like to ask for some time to read this closely.

Haven't understood why it doesn't use OptimiserChain. Not sure that exposing MixedPrecision{Float32} as the official way to specify the type is so nice.

darsnack commented 1 year ago

It’s unfortunate that the API here doesn’t use OptimiserChain like AccumGrad, but I think it is unavoidable. We need to invoke the inner optimizer’s apply! at higher precision, which we could do by promoting x and dx as the first thing in the chain (along with #151). But then we still need subtract! to happen at higher precision, and the result synced back to the high-precision copy of x in MixedPrecision’s state. So doing it would require a big change to update!/subtract!.
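To illustrate why the chain route runs out of road, here is a hypothetical promotion rule written against the documented init/apply! interface. Promote is made up for this sketch and is not part of Optimisers.jl or this PR:

```julia
using Optimisers
import Optimisers: init, apply!

# Hypothetical rule that promotes the gradient to T as the first step of a chain.
struct Promote{T<:Number} <: Optimisers.AbstractRule end
init(::Promote, x::AbstractArray) = nothing
apply!(::Promote{T}, state, x, dx) where T = state, T.(dx)

# OptimiserChain(Promote{Float32}(), Adam()) hands Adam a Float32 gradient, but a
# chained rule only rewrites dx on its way to subtract!: it cannot point the
# downstream rules at a high-precision copy of x, nor make subtract! run at that
# precision, which is why the PR wraps the inner rule instead.
```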

Maybe I’m not seeing it, and indeed we should think about this feature carefully before releasing it. It’s an important one to get right.

It makes me think that OptimiserChain is not the best API in general. Yes, it makes writing the rule easier. But it doesn’t work well for cases like this, and it has a potential foot gun where you can place the rule in the wrong spot in a chain. Wrapping seems more intuitive as an API (“whatever this rule that I’m wrapping does, do that in precision T” vs. grokking how the update transforms in a chain).

darsnack commented 1 year ago

Apologies for the train of comments.

I guess one distinction here is a gradient transformation vs. an optimizer transformation. Something like AccumGrad is a gradient transformation. But something like a scheduler or MixedPrecision is an optimizer transformation—where we transform/augment an optimizer's state, parameters, or invocation before running the rule. The arrow goes in one direction. You can write a gradient transformation rule as an optimizer transformation rule, but not vice-versa.
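As a side-by-side, a rough sketch, assuming the MixedPrecision(T, opt) constructor form floated later in this thread:

```julia
using Optimisers

# Gradient transformation: only rewrites dx, so it slots naturally into a chain.
grad_rule = OptimiserChain(AccumGrad(4), Adam(1e-3))

# Optimizer transformation: wraps the inner rule and changes how it is invoked
# (its state, precision, and call), not just what gradient it sees.
opt_rule = MixedPrecision(Float32, Adam(1e-3))
```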

CarloLucibello commented 1 year ago

Not sure that exposing MixedPrecision{Float32} as the official way to specify the type is so nice.

It could be MixedPrecision(Float32, opt) instead.

CarloLucibello commented 1 year ago

@mcabbott good to go?

mcabbott commented 1 year ago

I have questions but haven't had time to read up more.

First, this setup starts with the low-precision model, and temporarily stores a higher-precision copy for the purpose of accumulation without overflow. Is this standard? There's no very easy way to get back the high-precision model.

Somehow I thought the default was the reverse, to start and end with the high-precision one, and treat the low-precision model as a temporary step for cheaper gradients. This could not, I think, be implemented as a rule here.

Second, I thought it was conventional to scale up the loss for the low-precision steps, and then un-scale the gradients. There's no sign of that here. Am I mis-remembering? Un-scaling the gradients could be done by composing with Descent, but I'm not entirely sure that's the right place. And perhaps, if it's standard, it should be made easy.
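For reference, a minimal sketch of that loss-scaling recipe expressed with today's pieces, using a static scale and the Descent composition mentioned above; the toy model, loss, and scale value are made up, and the placement of Descent is exactly the open question:

```julia
using Flux, Optimisers

# Toy Float16 model and data, only to make the sketch self-contained.
model = Flux.f16(Dense(4 => 2))
x, y = Flux.f16(randn(Float32, 4, 8)), Flux.f16(randn(Float32, 2, 8))
loss(m, x, y) = Flux.mse(m(x), y)

scale = 2f0^10   # static loss scale; real setups usually adjust it dynamically

# Un-scale the gradients by putting Descent(1/scale) first in the chain.
opt_state = Optimisers.setup(OptimiserChain(Descent(1/scale), Adam(1e-3)), model)

grads = Flux.gradient(m -> scale * loss(m, x, y), model)[1]  # scaled loss ⇒ scaled grads
opt_state, model = Optimisers.update!(opt_state, model, grads)
```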

mcabbott commented 1 year ago

Just to sketch another possibility, Flux could instead wrap pairs of low+high precision copies of the same model:

bimodel = MixedPrec(Float16, model32)  # makes a Float16 copy, stores a scale for loss.
bimodel(x)  # calls variant matching eltype(x)? Default Float16?

opt_state = setup(Adam(), bimodel)  # special method?
gs = gradient(loss, bimodel)  # this could be a special method which scales the loss?
update!(opt_state, bimodel, gs[1]) # this knows about the scale, updates both halves.
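A minimal sketch of how such a wrapper type might be laid out, assuming Flux's f16 conversion; every name and field here is hypothetical:

```julia
using Flux

# Hypothetical wrapper holding both precisions plus a loss scale.
struct MixedPrec{L, H}
    low::L                          # e.g. Float16 copy used for the forward/backward pass
    high::H                         # full-precision master copy that accumulates updates
    scale::Base.RefValue{Float32}   # loss scale, ideally adjusted dynamically
end

MixedPrec(::Type{Float16}, model) = MixedPrec(Flux.f16(model), model, Ref(2f0^10))

(m::MixedPrec)(x) = m.low(x)        # forward pass runs the cheap copy by default
```

The special setup/gradient/update! methods in the sketch above would then dispatch on this type to scale the loss and keep the two copies in sync.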