Thank you very much for reaching out to us. I'm usually not on slack but if you prefer discussing there, I could go there (I believe I had an invite that is tied to an account I can't access anymore).
I agree, this is a very cool feature and it would be very useful for skorch to support it.
From what I can tell, we would need to implement two things:
1) torch.cuda.amp.autocast for module calls
2) GradScaler.scale for backward calls
Is it true that the two should always be used in conjunction or could you imagine a situation where they aren't?
Without going too deep into implementation details, I see three places in skorch that would need to be adjusted. First, the NeuralNet.infer and NeuralNet.get_loss methods, through which we call the module and the criterion. Here we need autocast. Second, there is NeuralNet.train_step_single, which would need to scale the gradients. The simplest implementation would just use autocast and scale for those two, with the option for the user to enable or disable them (thankfully, they both have an enabled argument, so this should be easy).
I wonder, though, if this covers all possible use cases. For instance, we could just offload the autocast work to the users, since they're already in charge of implementing the module. That means they could just use the autocast decorator or context manager in their own code. One disadvantage here is that using ready-made modules (from, say, torchvision) would be more cumbersome.
In summary, I think what this boils down to is how flexible do you believe this system should be? If it's an on/off thing, the first solution will do; if not, we'd need to think of a better way.
@ottonemo @thomasjpfan I think an interesting implementation challenge this poses is the coupling of the two functionalities. Say a user overrides train_step_single and forgets to implement the scaling (or just has old code lying around from when this didn't exist) but then enables autocast; they could unwittingly run into the underflow condition. Ideally, we can somehow make sure that this won't happen.
I don't mind talking here if you don't mind non-real-time.
GradScaler.scale for backward calls
You also need to step the optimizer via scaler.step(optimizer), and, if any optimizers stepped this iteration, you need to call scaler.update() after the iteration's optimizer steps are done.
For bitwise accurate saving/restoring, you should also include scaler.state_dict and scaler.load_state_dict alongside the usual model/optimizer state_dict/load_state_dict calls.
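For reference, a minimal sketch of that full cycle using only documented torch.cuda.amp calls; the tiny model, optimizer, and synthetic loader below are placeholders for illustration, not skorch code:

```python
import torch

model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

# Synthetic stand-in for a real data loader.
loader = [(torch.randn(8, 10, device="cuda"),
           torch.randint(0, 2, (8,), device="cuda"))]

for X, y in loader:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.cross_entropy(model(X), y)
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # unscales gradients, then steps the optimizer
    scaler.update()                # adjust the scale once per iteration

# Include the scaler state for bitwise-accurate checkpointing.
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scaler": scaler.state_dict(),
}
```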
Is it true that the two should always be used in conjunction or could you imagine a situation where they aren't?
Inference/eval is one situation you might use autocast but not GradScaler, because you're not computing gradients.
In summary, I think what this boils down to is how flexible do you believe this system should be?
I think an API flag that enables autocast for (by default) the entirety of forward pass(es) + loss computation(s) covers the common case and doesn't deprive the user of flexibility. The autocast context manager is nestable. Users can rely on/understand that supplying an enable_amp arg (or whatever) to Skorch enables autocast for all of forward+loss. If they know certain regions of their model don't want autocasting, they can write those regions under with autocast(enabled=False): to override the surrounding context Skorch imposed.
Your call though.
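To illustrate the nesting: an outer autocast context (what an enable_amp-style flag would impose) applies to the whole forward pass, and a region can locally opt out with autocast(enabled=False). A minimal sketch, where MyModule is just a made-up example module:

```python
import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.first = torch.nn.Linear(10, 10)
        self.second = torch.nn.Linear(10, 2)

    def forward(self, x):
        x = self.first(x)  # runs under whatever autocast context the caller set up
        # Locally override the surrounding (library-imposed) autocast context:
        with torch.cuda.amp.autocast(enabled=False):
            x = self.second(x.float())  # ensure float32 inputs in the disabled region
        return x

module = MyModule().cuda()
with torch.cuda.amp.autocast():  # the context an enable_amp-style flag would add
    out = module(torch.randn(4, 10, device="cuda"))
```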
The autocast context manager is nestable.
This is fantastic, I agree that this means we can just proceed with a global amp option without depriving the user of flexibility.
Inference/eval is one situation you might use autocast but not GradScaler, because you're not computing gradients.
At the same time, the scaler shouldn't get in the way in this case, so we're fine.
I assume that there will always be only GradScaler or should we leave room for replacements?
Here is a list of things that would need to be implemented so that we don't forget:
- Add an argument to NeuralNet to enable amp (amp_enabled)
- Use autocast inside the infer and get_loss methods
- Use scaler.scale for backward calls (outside of the autocast context)
- Call scaler.step(optimizer) and scaler.update()
- scaler.state_dict / scaler.load_state_dict wherever you save/restore checkpoints
- Allow the user to pass their own GradScaler and to pass arguments using dunder notation: net.set_params(grad_scaler__growth_factor=4)
- Adjust GradientNormClipping to work with scaled gradients (see the sketch after this list)

If anyone has something to add, I'll update the list.
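For the GradientNormClipping item, the documented torch.cuda.amp pattern is to unscale the gradients before clipping by norm; a minimal sketch in plain PyTorch (not skorch code) of what such a callback would need to account for:

```python
import torch

model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

with torch.cuda.amp.autocast():
    loss = model(torch.randn(8, 10, device="cuda")).sum()
scaler.scale(loss).backward()

scaler.unscale_(optimizer)  # bring gradients back to their true magnitude
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

scaler.step(optimizer)  # step() detects the explicit unscale and does not repeat it
scaler.update()
```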
I assume that there will always be only GradScaler or should we leave room for replacements?
what do you mean?
what do you mean?
I mean if it's good enough to hard-code GradScaler or if it's a realistic possibility that a user might want to use their own scaler version.
GradScaler's default constructor args don't need to be altered or tuned for the vast majority of cases. They didn't need to be altered or tuned for any of our local tests so far (40ish models across several application domains). However, I think you should allow the possibility that someone could supply their own instance, e.g. with an optional scaler=None arg after which you say scaler = torch.cuda.amp.GradScaler() if scaler is None else scaler.
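A tiny sketch of that pattern with an enabled flag folded in; the helper name and its arguments are illustrative, not an existing API:

```python
import torch

def get_scaler(scaler=None, enabled=True):
    # Use the caller-supplied scaler if there is one, otherwise a default GradScaler.
    return torch.cuda.amp.GradScaler(enabled=enabled) if scaler is None else scaler

default_scaler = get_scaler()                                              # plain default
custom_scaler = get_scaler(torch.cuda.amp.GradScaler(growth_factor=4.0))  # user-supplied instance
```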
Alright, thanks, I added the requirement that the user should be able to modify the grad scaler. I'll start working on it as soon as I have time. Since I currently don't have a setup to test this, I would need someone to do this for me.
There's no particular rush. topk on master was recently broken with float16 tensors (fixed yesterday, caused incidental failures in some amp-enabled scripts that used topk). Also, I just learned amp's interaction with cudnn RNNs may be broken, because my casts don't reflatten the weights (cell-based RNNs should be ok). I'll fix this soon, it's a high priority. However, these are under-the-hood fixes: the user-facing API is stable.
My point is, you can confidently hack on the interaction of your stuff with the Amp frontend at your leisure. If it takes a few weeks, that's not a bad thing (smoother rollout as we find and fix backend rough edges, and we won't change the Python API underneath you).
I currently don't have a setup to test this
do you mean you don't have access to Volta or Turing machines?
Thanks for the heads-up. As you mentioned, as long as the API is stable, we should be able to work on implementing the feature. We will surely not rush; rushing is not how we work on this project anyway :)
do you mean you don't have access to Volta or Turing machines?
Not even to an NVIDIA GPU at all.
Not even to an NVIDIA GPU at all.
I'll talk to my group, maybe we can get you access to testing resources. Can't promise anything but I'll ask.
https://github.com/skorch-dev/skorch/issues/611#issuecomment-610031888
+ scaler.state_dict / scaler.load_state_dict wherever you save/restore checkpoints.
Other than that, skim the remaining examples to see if any cases there are relevant.
Also, for cases with multiple convergence runs (multiple networks trained to convergence separately) in the same script, you should use a fresh GradScaler instance for each run (https://github.com/NVIDIA/apex/issues/439#issuecomment-610028282).
I'll talk to my group, maybe we can get you access to testing resources. Can't promise anything but I'll ask.
That's super nice of you. Of course, I'll keep my expectations down :)
+ scaler.state_dict / scaler.load_state_dict wherever you save/restore checkpoints.
Thanks, added.
@mcarilli Is my understanding correct that the feature is not included in PyTorch 1.5 yet?
That's correct, only in master and nightly packages now. But it will be in 1.6.
As an additional resource, this article also seems to be useful.
Regarding the current state, I would probably start working on this once PyTorch 1.6 gets released (I expect the number of people using nightly is rather small). As for how to test this, maybe one of the Google Colab options is compatible with AMP.
Another difficulty that comes to mind is that we want to support older PyTorch versions. I updated the requirements to take this into account. We should give a warning when a user enables AMP but runs on PyTorch < 1.6.
The challenge becomes bigger since, ideally, we would like to have GradScaler as a direct argument to NeuralNet, but GradScaler cannot be imported in older PyTorch versions. There are ways to deal with this, but ideally we can avoid adding a bunch of code in the style of if PYTORCH_VERSION < (1, 6): do-this; else: do-that.
I think warning if the user requests Amp but Amp is not available is a good approach.
amp_available = (hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")) is a more direct alternative to checking the version number.
One way to avoid if/else statements further on in the code is to define your own dummy GradScaler class with no-op methods and instantiate that if not amp_available. Downstream code could then call scaler.xyz unconditionally. Still moderately inconvenient because it must have all the same methods as the real GradScaler. Hard to know what the easiest approach is until you start hacking on it.
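A sketch of that fallback idea; the class name and the exact set of no-op methods are just an assumption of what downstream code would call, mirroring the checklist above:

```python
import torch

amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")

class NoOpScaler:
    """Drop-in stand-in for GradScaler when torch.cuda.amp is unavailable."""

    def scale(self, loss):
        return loss  # leave the loss untouched

    def step(self, optimizer, *args, **kwargs):
        return optimizer.step(*args, **kwargs)  # plain optimizer step

    def update(self):
        pass  # nothing to adjust

    def state_dict(self):
        return {}  # nothing to checkpoint

    def load_state_dict(self, state_dict):
        pass

scaler = torch.cuda.amp.GradScaler() if amp_available else NoOpScaler()
```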
At long last, I think we can close this issue. Although we have not directly implemented AMP in skorch, we now support it via huggingface's accelerate library, see #826
Native automatic mixed precision support (torch.cuda.amp) is now in master: https://pytorch.org/docs/master/amp.html https://pytorch.org/docs/master/notes/amp_examples.html
Not sure if you ever tried Nvidia's (our) experimental Apex Amp, but I know it has many pain points (extension builds, forward/backward compatibility, DataParallel support, flaky checkpointing, I don't even know if it can be hacked to handle double backward/gradient penalty, others…). torch.cuda.amp fixes all these, the interface is more flexible and intuitive, and the tighter integration with pytorch brings more future optimizations into scope.
I think the torch.cuda.amp API is a good fit for a higher-level library because its style is more functional (as in, it doesn't statefully alter anything outside itself). The necessary torch.cuda.amp calls don't have silent/weird effects elsewhere.
If you want to talk about adding torch.cuda.amp to Skorch, with an eye towards it becoming the future-proof source of mixed precision, message me on Pytorch slack anytime (or ask me for invites if you're not signed up). I'll check this issue periodically but I'm on Pytorch slack a greater fraction of the time than I care to admit.