davda54 / sam

SAM: Sharpness-Aware Minimization (PyTorch)

Grad norm computation #26

Closed · belerico closed this issue 3 years ago

belerico commented 3 years ago

Hi there! I was looking at your code and I have a doubt regarding the _grad_norm(self) method. In the SAM paper, the gradient norm used to compute ε̂(w) is defined by

ε̂(w) = ρ · sign(∇_w L_S(w)) · |∇_w L_S(w)|^(q−1) / ( ‖∇_w L_S(w)‖_q^q )^(1/p)

so, taking p = 2 and hence q = 2 (since 1/p + 1/q = 1), it is simply the 2-norm of the gradients. Shouldn't the _grad_norm(self) method then be something like this:

@torch.no_grad()
def _grad_norm(self):
    # put everything on the same device, in case of model parallelism
    shared_device = self.param_groups[0]["params"][0].device
    # 2-norm of all gradients: sqrt of the sum of squared entries across every parameter
    # (for the adaptive/ASAM case, each gradient is first scaled element-wise by |p|)
    norm = torch.sqrt(
        sum(
            torch.sum(
                torch.square(
                    p.grad * (torch.abs(p) if group["adaptive"] else 1.0)
                )
            ).to(shared_device)
            for group in self.param_groups
            for p in group["params"]
            if p.grad is not None
        )
    )
    return norm
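
As a quick numerical sanity check (a minimal sketch with made-up toy tensors, outside the optimizer class), this sum-of-squares formulation can be compared against the stacked per-parameter norms that the repository currently uses:

import torch

# toy stand-ins for the gradients of a few parameter tensors
grads = [torch.randn(3, 4), torch.randn(10), torch.randn(2, 2, 5)]

# proposed formulation: sqrt of the summed squared entries across all parameters
norm_a = torch.sqrt(sum(torch.sum(torch.square(g)) for g in grads))

# repository-style formulation: 2-norm of the vector of per-parameter 2-norms
norm_b = torch.norm(torch.stack([g.norm(p=2) for g in grads]), p=2)

print(torch.allclose(norm_a, norm_b))  # True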

I've already tested it with SAM, obtaining an accuracy of 97.17%, and I'm now running tests with ASAM.

davda54 commented 3 years ago

Hi, this is exactly what torch.norm(tensor, p=2) is computing. See also this related issue: #16
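
(For reference, a minimal check of that equality on a toy tensor; the variable names are just illustrative:)

import torch

t = torch.randn(5, 7)

# torch.norm with p=2 returns the 2-norm, i.e. the square root of the sum of squared entries
a = torch.norm(t, p=2)
b = torch.sqrt(torch.sum(torch.square(t)))

print(torch.allclose(a, b))  # True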

belerico commented 3 years ago

Sorry, my bad! I hadn't seen that issue or noticed the equality.