Closed · kayuksel closed this 3 years ago
Hello, I highly recommend adding AGC (Adaptive Gradient Clipping) as well, as it is extremely helpful for training stability.
```python
import torch
from torch import optim as opt


def unitwise_norm(x):
    # Per-unit L2 norm: over dims [1, 2, 3] for 4-D conv weights, over dim 0 otherwise.
    dim = [1, 2, 3] if x.ndim == 4 else 0
    return torch.sum(x ** 2, dim=dim, keepdim=x.ndim > 1) ** 0.5


class AGC(opt.Optimizer):
    def __init__(self, params, optim: opt.Optimizer, clipping=1e-2, eps=1e-3):
        self.optim = optim
        defaults = dict(clipping=clipping, eps=eps)
        defaults = {**defaults, **optim.defaults}
        super(AGC, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                # Clip each unit's gradient so its norm stays below
                # clipping * max(unit-wise parameter norm, eps).
                param_norm = torch.max(unitwise_norm(p),
                                       torch.tensor(group['eps']).to(p.device))
                grad_norm = unitwise_norm(p.grad)
                max_norm = param_norm * group['clipping']
                trigger = grad_norm > max_norm
                clipped = p.grad * (max_norm / torch.max(
                    grad_norm, torch.tensor(1e-6).to(p.device)))
                p.grad.copy_(torch.where(trigger, clipped, p.grad))
        # The closure (if any) was already evaluated above; passing it to the
        # wrapped optimizer would recompute gradients and undo the clipping.
        self.optim.step()
        return loss
```
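For reference, a minimal usage sketch, assuming SGD as the wrapped optimizer (the toy model, learning rate, and clipping value below are just placeholders):

```python
# Hypothetical example: wrap a base optimizer with the AGC class above.
model = torch.nn.Linear(10, 2)
base = opt.SGD(model.parameters(), lr=1e-1)
optimizer = AGC(model.parameters(), base, clipping=1e-2)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()       # clips unit-wise, then delegates to SGD
optimizer.zero_grad()
```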
Hi @kayuksel - thanks for the feedback. Will add it today and update here.
Hi @kayuksel, AGC has been added! I included credit for you in the code as well. Thanks again!