Closed mcarilli closed 4 years ago
Hi there! No, we don't always monitor the PyTorch slack, so thanks for reaching out here! The problem is that it was designed by basically altering every line of code of the base training loop, which is completely incompatible with the way fastai usually tweaks the training loop (we change the training loop via callbacks and always keep the same core). So I'm not sure we will actually be able to use it (the same way we were using the helper functions from apex but not actually apex).
I have to dig behind the front-end layer to see how it could fit with our callback way of doing things, but we have a lot on our plates right now, so it won't happen for a few weeks. Can we continue the discussion on the fastai forum? Closing this issue since we keep issues for bugs in the library only; features are usually discussed there (and we may find people able to help with the conversion).
Ah! I just got a crazy idea that would allow us to use the tweaked training loop! Will look at this over the weekend and report back here (note that this won't be for v1 in any case, so we can leave the issue in the fastai repo closed).
Has the PR made its way into the nightlies yet?
Great!! It should be in nightlies. But FYI, unrelated to amp, torch.topk is broken with fp16 on master right now. Fix is in flight (https://github.com/pytorch/pytorch/pull/35734, https://github.com/pytorch/pytorch/pull/35986) but topk is pretty common so i'd give master a week to stabilize.
If you want, I can describe the design in detail on slack. You may find it's less invasive than you fear (certainly less than Apex Amp). I can also describe the lower-level torch functions the documented API calls into, if those are more suitable for fastai's design.
Should I still move this to the forum? If no response by EOD Monday, i'll assume that's a yes, open a topic in the fastai users + fastai-v2 categories, and leave a link here.
There is a topic there already. I have added the basic support in this commit, but it will need some additional work for some other callbacks like gradient clipping. Next week I'll start testing that it works as nicely as what we have.
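For readers wondering why gradient clipping needs extra care with native amp: the gradients produced by the scaled backward pass are still multiplied by the loss scale, so they have to be unscaled before a norm threshold means anything. A minimal sketch of that ordering, following the pattern in the PyTorch amp examples; the toy model, data, and `max_norm` value are assumptions for illustration and are not fastai's callback code. It falls back to plain FP32 when no GPU is present, and newer PyTorch releases may warn that the `torch.cuda.amp` spelling is deprecated.

```python
import torch
from torch import nn

use_amp = torch.cuda.is_available()      # assumption: disable amp when there is no GPU
device = "cuda" if use_amp else "cpu"

torch.manual_seed(0)
model = nn.Linear(8, 2).to(device)       # toy model, params stay FP32
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
max_norm = 1.0                           # illustrative clipping threshold

xb = torch.randn(16, 8, device=device)
yb = torch.randn(16, 2, device=device)

opt.zero_grad()
with torch.cuda.amp.autocast(enabled=use_amp):
    loss = nn.functional.mse_loss(model(xb), yb)
scaler.scale(loss).backward()            # grads are scaled at this point

scaler.unscale_(opt)                     # divide grads by the scale, in place
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

scaler.step(opt)                         # skips the step if grads contain inf/nan
scaler.update()

# Norm of the gradients after clipping, for inspection
clipped_norm = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))
```

Because `scaler.unscale_` records that the optimizer's gradients were already unscaled, the later `scaler.step` does not unscale them a second time.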
Fantastic! I looked over the code and my first question is: why does it need to be so complex? With the native integration, the model should not be cast to .half(). model params are FP32 and == params stepped by the optimizer, as with non-mixed-precision training. There's no need to maintain master params distinct from the model params (which was an irritating, confusing, and brittle aspect of our initial mixed precision formulation. We're learning as we go...).
These
https://github.com/fastai/fastai2/commit/60d63c3e27fb8ea5cf1e8d105b862bceb8fd612d#diff-0477c8e3e5e8d3059dafd449dfd17fe1R157-R164
https://github.com/fastai/fastai2/commit/60d63c3e27fb8ea5cf1e8d105b862bceb8fd612d#diff-0477c8e3e5e8d3059dafd449dfd17fe1R179
(which look correct) are all you need to do in the training loop, and their effects are local to the body of each iteration (aside from scaler itself, which updates its internal scale over time).
You also need to construct the GradScaler instance once at the beginning of training (which you are).
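Put together, the per-iteration calls described above can be sketched roughly as follows. This is a toy loop under stated assumptions (a small linear model, random data, and a hypothetical `one_batch` helper that is not fastai's function of the same name), not the code from the linked commit. It falls back to plain FP32 when no GPU is available.

```python
import torch
from torch import nn

use_amp = torch.cuda.is_available()      # assumption: disable amp when there is no GPU
device = "cuda" if use_amp else "cpu"

torch.manual_seed(0)
model = nn.Linear(8, 2).to(device)       # no .half(): params stay FP32
opt = torch.optim.SGD(model.parameters(), lr=0.1)   # steps the model params directly
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)  # constructed once, before training

def one_batch(xb, yb):
    """Hypothetical helper: one training iteration with native amp."""
    opt.zero_grad()
    with torch.cuda.amp.autocast(enabled=use_amp):   # forward pass picks per-op dtypes
        loss = nn.functional.mse_loss(model(xb), yb)
    scaler.scale(loss).backward()        # backward on the scaled loss
    scaler.step(opt)                     # unscales grads, skips the step on inf/nan
    scaler.update()                      # adjusts the scale for the next iteration
    return loss.item()

xb = torch.randn(16, 8, device=device)
yb = torch.randn(16, 2, device=device)
losses = [one_batch(xb, yb) for _ in range(5)]
```

The only state that persists across iterations is inside `scaler`; the model and optimizer are used exactly as in FP32 training.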
Warning about something I did not see:
For bitwise accurate saving/restoring, you should include scaler.state_dict() and scaler.load_state_dict() alongside the usual model/optimizer state dict handling.
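A minimal sketch of what that checkpointing might look like, assuming a toy model and an in-memory buffer in place of a file. The dictionary keys (`"model"`, `"opt"`, `"scaler"`) are arbitrary names chosen for this example.

```python
import io
import torch
from torch import nn

use_amp = torch.cuda.is_available()      # assumption: disable amp when there is no GPU

model = nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

# Save: the scaler's state goes next to the usual model/optimizer state.
checkpoint = {
    "model": model.state_dict(),
    "opt": opt.state_dict(),
    "scaler": scaler.state_dict(),
}
buf = io.BytesIO()                       # stands in for a checkpoint file
torch.save(checkpoint, buf)

# Restore into freshly constructed objects.
buf.seek(0)
restored = torch.load(buf)
model2 = nn.Linear(4, 1)
opt2 = torch.optim.SGD(model2.parameters(), lr=0.1)
scaler2 = torch.cuda.amp.GradScaler(enabled=use_amp)
model2.load_state_dict(restored["model"])
opt2.load_state_dict(restored["opt"])
scaler2.load_state_dict(restored["scaler"])
```

Without the scaler entry, a resumed run would start again from the default initial scale rather than the scale reached at save time.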
Still not sure what your preferred mode of communication is :P Should I ship all this discussion to the forum? Comment on the WIP PR? Continue here?
Thanks for your comments!
All that is before mixed_precision_one_batch is the current implementation of mixed precision, which we won't remove until there has been a release of PyTorch and we have confirmed native mixed precision works. So the implementation of native mixed precision is just the altered one_batch function and the callback; you can ignore the rest of the callback.fp16 module.
We have not implemented any saving/restoring on the callbacks yet. We just rely on pickling the whole thing for now.
For communication, the preferred mode is the forum since that is where our community is and will see the whole thing (I mostly pinged you on the issue to let you know what I had done since I did not see you on the forum).
Native automatic mixed precision support (torch.cuda.amp) is finally merged:
https://pytorch.org/docs/master/amp.html
https://pytorch.org/docs/master/notes/amp_examples.html
Apex Amp has many known pain points (extension builds, forward/backward compatibility, DataParallel support, flaky checkpointing, i don’t even know if it can be hacked to handle double backward/gradient penalty, others…). torch.cuda.amp fixes all these, the interface is more flexible and intuitive, and the tighter integration with pytorch brings more future optimizations into scope.
I think the torch.cuda.amp API is a good fit for a higher-level library like fastai because its style is more functional (as in, it doesn't statefully alter anything outside itself). The necessary torch.cuda.amp calls don't have silent/weird effects elsewhere.
If you want to talk about adding torch.cuda.amp to fastai, with an eye towards it becoming the future-proof source of mixed precision, message me on Pytorch slack anytime. I pinged you there as well but I’m not sure if you monitor it habitually.