skorch-dev / skorch

A scikit-learn compatible neural network library that wraps PyTorch
BSD 3-Clause "New" or "Revised" License

Native automatic mixed precision for Skorch #611

Closed mcarilli closed 2 years ago

mcarilli commented 4 years ago

Native automatic mixed precision support (torch.cuda.amp) is now in master: https://pytorch.org/docs/master/amp.html https://pytorch.org/docs/master/notes/amp_examples.html

Not sure if you ever tried Nvidia's (our) experimental Apex Amp, but I know it has many pain points (extension builds, forward/backward compatibility, DataParallel support, flaky checkpointing, I don’t even know if it can be hacked to handle double backward/gradient penalty, others…). torch.cuda.amp fixes all these, the interface is more flexible and intuitive, and the tighter integration with PyTorch brings more future optimizations into scope.

I think the torch.cuda.amp API is a good fit for a higher-level library because its style is more functional (as in, it doesn't statefully alter anything outside itself). The necessary torch.cuda.amp calls don't have silent/weird effects elsewhere.

If you want to talk about adding torch.cuda.amp to Skorch, with an eye towards it becoming the future-proof source of mixed precision, message me on PyTorch Slack anytime (or ask me for invites if you're not signed up). I'll check this issue periodically but I'm on PyTorch Slack a greater fraction of the time than I care to admit.

BenjaminBossan commented 4 years ago

Thank you very much for reaching out to us. I'm usually not on slack but if you prefer discussing there, I could go there (I believe I had an invite that is tied to an account I can't access anymore).

I agree, this is a very cool feature and it would be very useful for skorch to support it.

From what I can tell, we would need to implement two things:

1) torch.cuda.amp.autocast for module calls
2) GradScaler.scale for backward calls

Is it true that the two should always be used in conjunction or could you imagine a situation where they aren't?

Without going too deep into implementation details, I see three places in skorch that would need to be adjusted. First, the NeuralNet.infer and NeuralNet.get_loss methods, through which we call the module and the criterion. Here we need autocast. Second, there is NeuralNet.train_step_single, which would need to scale the gradients. The simplest implementation would just use autocast and scale in those places, with the option for the user to enable or disable them (thankfully, they both have an enabled argument, so this should be easy).

I wonder, though, if this covers all possible use cases. For instance, we could just offload the autocast work to the users, since they're already in charge of implementing the module. That means they could just use the autocast decorator or context manager in their own code. One disadvantage here is that using ready-made modules (from, say, torchvision) would be more cumbersome.
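To illustrate the "offload it to the user" option, here is a minimal sketch of a user-written module that applies autocast itself via the decorator form. `MyModule` and `use_amp` are hypothetical names, and enabling AMP only when CUDA is available is an assumption made so the snippet also runs on CPU.

```python
import torch
from torch import nn

# Assumption: only enable AMP when a CUDA device is present.
use_amp = torch.cuda.is_available()
device = "cuda" if use_amp else "cpu"


class MyModule(nn.Module):  # hypothetical user module
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(8, 2)

    # The user, not the wrapping library, decides that forward runs under autocast.
    @torch.cuda.amp.autocast(enabled=use_amp)
    def forward(self, X):
        return self.dense(X)


module = MyModule().to(device)
out = module(torch.randn(4, 8, device=device))
```

This works for modules the user writes, but a ready-made module would need to be wrapped or subclassed to get the same effect, which is the disadvantage mentioned above.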

In summary, I think what this boils down to is how flexible do you believe this system should be? If it's an on/off thing, the first solution will do, if not, we'd need to think of a better way.

@ottonemo @thomasjpfan I think an interesting implementation challenge this poses is the coupling of the two functionalities. Say a user overrides train_step_single and forgets to implement the scaling (or just has old code lying around from when this didn't exist), but then enables autocast, they could unwittingly run into the underflow condition. Ideally, we can somehow make sure that this won't happen.

mcarilli commented 4 years ago

I don't mind talking here if you don't mind non-real-time.

> GradScaler.scale for backward calls

You also need to step the optimizer via scaler.step(optimizer), and, if any optimizers stepped this iteration, you need to call scaler.update() after the iteration's optimizer steps are done.
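Put together, the calls described so far (autocast around forward and loss, then scale, step, update) could look roughly like the sketch below. This is plain PyTorch, not skorch code; the toy model, `train_step`, and `use_amp` are illustrative assumptions, and AMP is only enabled when CUDA is available. Newer PyTorch releases may emit a deprecation warning for the `torch.cuda.amp` spelling used in this thread.

```python
import torch
from torch import nn

# Assumption: only enable AMP when a CUDA device is present.
use_amp = torch.cuda.is_available()
device = "cuda" if use_amp else "cpu"

model = nn.Linear(8, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)


def train_step(X, y):
    optimizer.zero_grad()
    # Forward pass and loss computation run under autocast.
    with torch.cuda.amp.autocast(enabled=use_amp):
        y_pred = model(X)
        loss = criterion(y_pred, y)
    # Backward on the scaled loss, outside the autocast region.
    scaler.scale(loss).backward()
    # Step through the scaler, then update the scale once per iteration.
    scaler.step(optimizer)
    scaler.update()
    return loss.detach()


X = torch.randn(16, 8, device=device)
y = torch.randint(0, 2, (16,), device=device)
loss = train_step(X, y)
```

With `enabled=False`, both autocast and the scaler become pass-throughs, so the same code path serves the non-AMP case.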

For bitwise accurate saving/restoring, you should also include scaler.state_dict and scaler.load_state_dict alongside the usual model/optimizer state_dict/load_state_dict calls.
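A minimal sketch of what including the scaler in a checkpoint might look like. The dictionary keys and the in-memory buffer are arbitrary choices for illustration; a real checkpoint would typically go to a file.

```python
import io

import torch
from torch import nn

# Assumption: only enable AMP when a CUDA device is present.
use_amp = torch.cuda.is_available()

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

# Save the scaler state alongside the usual model/optimizer state.
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scaler": scaler.state_dict(),
}
buffer = io.BytesIO()  # stand-in for a checkpoint file
torch.save(checkpoint, buffer)

# Restore all three pieces of state.
buffer.seek(0)
restored = torch.load(buffer)
model.load_state_dict(restored["model"])
optimizer.load_state_dict(restored["optimizer"])
scaler.load_state_dict(restored["scaler"])
```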

> Is it true that the two should always be used in conjunction or could you imagine a situation where they aren't?

Inference/eval is one situation you might use autocast but not GradScaler, because you're not computing gradients.
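As a sketch of that inference case: autocast wraps the forward pass, and no GradScaler is involved because nothing is backpropagated. The toy model and `use_amp` are assumptions for illustration.

```python
import torch
from torch import nn

# Assumption: only enable AMP when a CUDA device is present.
use_amp = torch.cuda.is_available()
device = "cuda" if use_amp else "cpu"

model = nn.Linear(8, 2).to(device).eval()
X = torch.randn(4, 8, device=device)

# No gradients are computed, so there is nothing for a GradScaler to scale.
with torch.no_grad(), torch.cuda.amp.autocast(enabled=use_amp):
    y_pred = model(X)
```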

> In summary, I think what this boils down to is how flexible do you believe this system should be?

I think an API flag that enables autocast for (by default) the entirety of forward pass(es) + loss computation(s) covers the common case and doesn't deprive the user of flexibility. The autocast context manager is nestable. Users can rely on/understand that supplying an enable_amp arg (or whatever) to Skorch enables autocast for all of forward+loss. If they know certain regions of their model don't want autocasting, they can write those regions under with autocast(enabled=False): to override the surrounding context Skorch imposed.
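The nesting described here can be sketched as follows. The outer `autocast` stands in for the context a library would impose; `Net` is a hypothetical user module that opts one region out. Casting the input of the opted-out region to float32 explicitly is what makes that region actually run in full precision.

```python
import torch
from torch import nn

# Assumption: only enable AMP when a CUDA device is present.
use_amp = torch.cuda.is_available()
device = "cuda" if use_amp else "cpu"


class Net(nn.Module):  # hypothetical user module
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(8, 16)
        self.out = nn.Linear(16, 2)

    def forward(self, X):
        # Runs under whatever autocast context surrounds the call.
        h = torch.relu(self.hidden(X))
        # Locally override the surrounding context for this region.
        with torch.cuda.amp.autocast(enabled=False):
            # h may be float16 here, so cast it back before the float32 region.
            return self.out(h.float())


net = Net().to(device)
# Stand-in for the autocast context the wrapping library would impose.
with torch.cuda.amp.autocast(enabled=use_amp):
    y_pred = net(torch.randn(4, 8, device=device))
```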

Your call though.

BenjaminBossan commented 4 years ago

> The autocast context manager is nestable.

This is fantastic, I agree that this means we can just proceed with a global amp option without depriving the user of flexibility.

> Inference/eval is one situation you might use autocast but not GradScaler, because you're not computing gradients.

At the same time, the scaler shouldn't get in the way in this case, so we're fine.

I assume that there will always be only GradScaler or should we leave room for replacements?

Here is a list of things that would need to be implemented so that we don't forget:

If anyone has something to add, I'll update the list.

mcarilli commented 4 years ago

> I assume that there will always be only GradScaler or should we leave room for replacements?

what do you mean?

BenjaminBossan commented 4 years ago

> what do you mean?

I mean if it's good enough to hard code GradScaler or if it's a realistic possibility that a user might want to use their own scaler version.

mcarilli commented 4 years ago

GradScaler's default constructor args don't need to be altered or tuned for the vast majority of cases. They didn't need to be altered or tuned for any of our local tests so far (40ish models across several application domains). However, I think you should allow the possibility that someone could supply their own instance, e.g. with an optional scaler=None arg, after which you say scaler = torch.cuda.amp.GradScaler() if scaler is None else scaler.
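The optional-argument pattern suggested here might look like this in a wrapper class. `Trainer` and `enable_amp` are hypothetical names used only for illustration, not skorch API.

```python
import torch


class Trainer:  # hypothetical wrapper, not part of skorch
    def __init__(self, enable_amp=False, scaler=None):
        self.enable_amp = enable_amp
        # Fall back to a default GradScaler unless the user supplied their own.
        self.scaler = (
            torch.cuda.amp.GradScaler(enabled=enable_amp) if scaler is None else scaler
        )


# Default: the wrapper creates the scaler itself.
default_trainer = Trainer(enable_amp=False)

# Custom: the user passes an instance with non-default constructor args.
custom_scaler = torch.cuda.amp.GradScaler(init_scale=2.0**10, enabled=False)
custom_trainer = Trainer(scaler=custom_scaler)
```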

BenjaminBossan commented 4 years ago

Alright, thanks, I added the requirement that the user should be able to modify the grad scaler. I'll start working on it as soon as I have time. Since I currently don't have a setup to test this, I would need someone to do this for me.

mcarilli commented 4 years ago

There's no particular rush. topk on master was recently broken with float16 tensors (fixed yesterday, caused incidental failures in some amp-enabled scripts that used topk). Also I just learned amp's interaction with cudnn RNNs may be broken, because my casts don't reflatten the weights (cell-based RNNs should be ok). I'll fix this soon, it's a high priority. However, these are under-the-hood fixes: the user-facing API is stable.

My point is, you can confidently hack on the interaction of your stuff with the Amp frontend at your leisure. If it takes a few weeks, that's not a bad thing (smoother rollout as we find and fix backend rough edges, and we won't change the Python API underneath you).

> I currently don't have a setup to test this

do you mean you don't have access to Volta or Turing machines?

BenjaminBossan commented 4 years ago

Thanks for the heads-up. As you mentioned, as long as the API is stable, we should be able to work on implementing the feature. We certainly won't rush; that's not how we work on this project :)

> do you mean you don't have access to Volta or Turing machines?

Not even to an NVIDIA GPU at all.

mcarilli commented 4 years ago

> Not even to an NVIDIA GPU at all.

I'll talk to my group, maybe we can get you access to testing resources. Can't promise anything but I'll ask.

https://github.com/skorch-dev/skorch/issues/611#issuecomment-610031888 +scaler.state_dict/scaler.load_state_dict wherever you save/restore checkpoints.

Other than that, skim the remaining examples to see if any cases there are relevant.

Also, for cases with multiple convergence runs (multiple networks trained to convergence separately) in the same script, you should use a fresh GradScaler instance for each run (https://github.com/NVIDIA/apex/issues/439#issuecomment-610028282).
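A sketch of the "fresh GradScaler per run" advice: the scaler is created inside the function that trains one network, so no scale state leaks from one convergence run into the next. `run_training`, the toy data, and `use_amp` are assumptions for illustration.

```python
import torch
from torch import nn

# Assumption: only enable AMP when a CUDA device is present.
use_amp = torch.cuda.is_available()
device = "cuda" if use_amp else "cpu"


def run_training(seed, n_steps=3):  # hypothetical helper: one convergence run
    torch.manual_seed(seed)
    model = nn.Linear(8, 1).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
    # Fresh scaler for this run; it is not shared with other runs.
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    X = torch.randn(32, 8, device=device)
    y = torch.randn(32, 1, device=device)
    for _ in range(n_steps):
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=use_amp):
            loss = nn.functional.mse_loss(model(X), y)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    return loss.item()


# Two separate networks trained in the same script, each with its own scaler.
final_losses = [run_training(seed) for seed in range(2)]
```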

BenjaminBossan commented 4 years ago

> I'll talk to my group, maybe we can get you access to testing resources. Can't promise anything but I'll ask.

That's super nice of you. Of course, I'll keep my expectations down :)

> +scaler.state_dict/scaler.load_state_dict wherever you save/restore checkpoints.

Thanks, added.

BenjaminBossan commented 4 years ago

@mcarilli Is my understanding correct that the feature is not included in PyTorch 1.5 yet?

mcarilli commented 4 years ago

That's correct, only in master and nightly packages now. But it will be in 1.6.

BenjaminBossan commented 4 years ago

As an additional resource, this article also seems to be useful.

Regarding the state, I would probably start working on this once PyTorch 1.6 gets released (I expect the number of people using nightly is rather small). As to how to test this, maybe one of the Google Colab options is compatible with AMP.

Another difficulty that comes to mind is that we want to support older PyTorch versions. I updated the requirements to take this into account. We should give a warning when a user enables AMP but runs on PyTorch < 1.6.

The challenge becomes bigger since ideally, we would like to have GradScaler as a direct argument to NeuralNet, but GradScaler cannot be imported in older PyTorch versions. There are ways to deal with this but ideally, we can avoid adding a bunch of code in the style of if PYTORCH_VERSION < (1, 6): do-this; else: do-that.

mcarilli commented 4 years ago

I think warning if the user requests Amp but Amp is not available is a good approach.

amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")

is a more direct alternative to checking the version number.
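Combining the feature check with the warning discussed above could look like the sketch below. `resolve_amp` and its `available` parameter are hypothetical; falling back to AMP being disabled (rather than raising) is one possible design choice, not something the thread settles.

```python
import warnings

import torch

# Feature detection instead of comparing version numbers.
amp_available = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast")


def resolve_amp(enable_amp, available=amp_available):  # hypothetical helper
    """Return whether AMP should actually be used, warning if it was requested but is missing."""
    if enable_amp and not available:
        warnings.warn(
            "AMP was requested but torch.cuda.amp is not available in this "
            "PyTorch installation; continuing without mixed precision."
        )
        return False
    return bool(enable_amp)


use_amp = resolve_amp(enable_amp=True)
```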

One way to avoid if/else statements further on in the code is to define your own dummy GradScaler class with no-op methods and instantiate that if not amp_available. Downstream code could then call scaler.xyz unconditionally. Still moderately inconvenient because it must have all the same methods as the real GradScaler. Hard to know what the easiest approach is until you start hacking on it.
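A rough sketch of such a dummy scaler. `NoOpGradScaler` is a hypothetical name, and it mirrors only a subset of the real GradScaler's methods (the ones mentioned in this thread plus `unscale_` and `is_enabled`); a complete stand-in would need the rest as well. `_CountingOptimizer` is a stand-in used only to show the call pattern.

```python
class NoOpGradScaler:  # hypothetical stand-in for when AMP is unavailable
    def __init__(self, *args, **kwargs):
        pass  # accept and ignore the real GradScaler's constructor args

    def scale(self, outputs):
        return outputs  # no scaling: the loss passes through unchanged

    def unscale_(self, optimizer):
        pass  # gradients were never scaled

    def step(self, optimizer, *args, **kwargs):
        return optimizer.step(*args, **kwargs)  # plain optimizer step

    def update(self, new_scale=None):
        pass  # there is no scale to update

    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        pass

    def is_enabled(self):
        return False


class _CountingOptimizer:  # hypothetical optimizer stand-in for the demo
    def __init__(self):
        self.steps = 0

    def step(self):
        self.steps += 1


# Downstream code can call the scaler unconditionally.
scaler = NoOpGradScaler()
optimizer = _CountingOptimizer()
loss = 0.5
scaled = scaler.scale(loss)
scaler.step(optimizer)
scaler.update()
```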

BenjaminBossan commented 2 years ago

At long last, I think we can close this issue. Although we have not directly implemented AMP in skorch, we now support it via huggingface's accelerate library, see #826