pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

[new feature] Adaptive gradient clipping #57503

Open · jamestwebber opened this issue 3 years ago

jamestwebber commented 3 years ago

🚀 Feature

Adaptive gradient clipping (AGC), as described in "High-Performance Large-Scale Image Recognition Without Normalization" by Brock et al. and implemented here (in JAX).

I know you typically only add features that have gained a moderate amount of adoption, so AGC may not qualify yet, but I thought I'd open a feature request to gauge interest. It's a fairly low-effort feature that could be broadly useful.

Motivation

As described in the paper, AGC clips gradients based on the ratio of the gradient norm to the parameter norm, rather than against a fixed threshold on the gradient norm.
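For reference, the rule from the paper can be written as below (my transcription). Here $\lambda$ is the clipping threshold, $W_i^{\ell}$ and $G_i^{\ell}$ are the weights and gradients of unit $i$ (one output row) of layer $\ell$, and the paper uses $\epsilon = 10^{-3}$ so that zero-initialized weights don't always have their gradients clipped to zero. The paper applies this per unit, not over the whole model.

```math
G_i^{\ell} \leftarrow
\begin{cases}
\lambda \dfrac{\lVert W_i^{\ell} \rVert_F^{\star}}{\lVert G_i^{\ell} \rVert_F}\, G_i^{\ell} & \text{if } \dfrac{\lVert G_i^{\ell} \rVert_F}{\lVert W_i^{\ell} \rVert_F^{\star}} > \lambda, \\[4pt]
G_i^{\ell} & \text{otherwise,}
\end{cases}
\qquad
\lVert W_i^{\ell} \rVert_F^{\star} = \max\!\left(\lVert W_i^{\ell} \rVert_F,\ \epsilon\right)
```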

Brock et al. used AGC (along with other techniques) to train NFNets without batch normalization, but it applies more generally as an alternative to other gradient-clipping methods. It seems to work well in my hands, at least.

Pitch

It's actually super simple to write. I implemented a draft version, adapted from @vballoli's code here but written in the style of `torch.nn.utils.clip_grad_norm_`.

Implementation:

```python
# torch._six no longer exists in recent PyTorch releases; math.inf is equivalent.
from math import inf
from typing import Iterable, Union

import torch

_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]


def clip_grad_agc_(
    parameters: _tensor_or_tensors, clipping: float, norm_type: float = 2.0
) -> torch.Tensor:
    r"""Adaptive gradient clipping, implemented as a util à la ``clip_grad_norm_``.

    The norms are computed over all parameters (and over all gradients)
    together, as if each set were concatenated into a single vector.
    Gradients are modified in-place.

    Args:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        clipping (float or int): clipping factor; the maximum allowed
            gradient norm is ``clipping * param_norm``
        norm_type (float or int): type of the used p-norm. Can be ``'inf'``
            for infinity norm.

    Returns:
        Total norm of the gradients (viewed as a single vector), before clipping.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    clipping = float(clipping)
    norm_type = float(norm_type)
    if len(parameters) == 0:
        return torch.tensor(0.0)
    device = parameters[0].grad.device
    if norm_type == inf:
        param_norm = max(p.detach().abs().max().to(device) for p in parameters)
        grad_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
    else:
        # Norm of the per-tensor norms == norm of the concatenated vector.
        param_norm = torch.norm(
            torch.stack([torch.norm(p.detach(), norm_type).to(device) for p in parameters]),
            norm_type,
        )
        grad_norm = torch.norm(
            torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]),
            norm_type,
        )
    # The threshold scales with the size of the parameters themselves.
    max_norm = param_norm * clipping
    clip_coef = max_norm / (grad_norm + 1e-6)
    if clip_coef < 1:
        for p in parameters:
            p.grad.detach().mul_(clip_coef.to(p.grad.device))
    return grad_norm
```

Note: the paper's implementation (and @vballoli's) only supports the L2 norm, but I figured why not?
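The draft above computes one global ratio over all parameters. The paper instead applies the rule per unit, with the weight norm floored at `eps = 1e-3`. Below is a minimal sketch of that unit-wise variant (L2 only). It assumes PyTorch's weight layout, with output units along dim 0, and treats 1-D parameters (biases, gains) as a single unit, which is how I understand the reference JAX code to handle them. `clip_grad_agc_unitwise_` and `_unitwise_norm` are hypothetical names, not PyTorch APIs.

```python
from typing import Iterable, Union

import torch

_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]


def _unitwise_norm(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: L2 norm per output unit (dim 0), or of the whole tensor if ndim <= 1."""
    if x.ndim <= 1:
        # Scalars and vectors (biases, gains) are treated as a single unit.
        return torch.linalg.vector_norm(x)
    # Reduce over all dims except dim 0 (out_features / out_channels) and keep
    # them, so the result has shape (out, 1, ..., 1) and broadcasts against x.
    return torch.linalg.vector_norm(x, dim=tuple(range(1, x.ndim)), keepdim=True)


@torch.no_grad()
def clip_grad_agc_unitwise_(
    parameters: _tensor_or_tensors, clipping: float, eps: float = 1e-3
) -> None:
    """Hypothetical unit-wise AGC (L2 only): rescale each unit's gradient in place
    so that ||G_i|| / max(||W_i||, eps) <= clipping."""
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    for p in parameters:
        if p.grad is None:
            continue
        # Floor the weight norm at eps so zero-initialized units still get a
        # small non-zero gradient instead of being clipped to exactly zero.
        max_norm = _unitwise_norm(p).clamp_min(eps) * clipping
        grad_norm = _unitwise_norm(p.grad)
        # Units over the threshold are rescaled to norm max_norm; the others are
        # left untouched. clamp_min(1e-6) only guards the division.
        scale = torch.where(
            grad_norm > max_norm,
            max_norm / grad_norm.clamp_min(1e-6),
            torch.ones_like(grad_norm),
        )
        p.grad.mul_(scale)
```

As with `clip_grad_norm_`, this would be called after `loss.backward()` and before `optimizer.step()`. The paper also reports not clipping the final linear layer, which with this sketch just means passing every parameter except that layer's.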

Alternatives

It might be too early to add this; I don't know. I just thought I'd open the issue because it would make my life easier if it were built into PyTorch 😄

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @Lezcano @mruberry @jbschlosser

jbschlosser commented 3 years ago

Hey @jamestwebber, thanks for the great suggestion and implementation! If this becomes popular enough, we'd be happy to accept it.