Hi there! I was looking through your code and I have a question regarding the _grad_norm(self) method: in the SAM paper, the grad norm used to compute epsilon_hat(w) is defined by the following equation:

epsilon_hat(w) = rho * sign(∇L(w)) * |∇L(w)|^(q-1) / ( ||∇L(w)||_q^q )^(1/p)

Supposing p = 2, and hence q = 2 (since 1/p + 1/q = 1), this norm would simply be the 2-norm of the gradients. So, shouldn't the _grad_norm(self) method be something like this:
@torch.no_grad()
def _grad_norm(self):
    # put everything on the same device, in case of model parallelism
    shared_device = self.param_groups[0]["params"][0].device
    # sum the squared gradient entries over all parameters, then take the
    # square root: the global 2-norm of the (optionally |w|-scaled) gradients
    norm = torch.sqrt(
        sum(
            [
                torch.sum(
                    torch.square(
                        p.grad
                        # ASAM: scale each gradient entry element-wise by |w|
                        * (torch.abs(p) if group["adaptive"] else 1.0)
                    )
                ).to(shared_device)
                for group in self.param_groups
                for p in group["params"]
                if p.grad is not None
            ]
        )
    )
    return norm
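
Just as a quick sanity check of the non-adaptive case (independent of the optimizer class, with toy tensors standing in for the p.grad values), the sqrt-of-summed-squares above is indeed the plain 2-norm of all gradients flattened together:

import torch

# toy stand-ins for the per-parameter gradients p.grad (shapes are arbitrary)
grads = [torch.randn(3, 4), torch.randn(10), torch.randn(2, 2, 5)]

# the expression used in _grad_norm above (non-adaptive case)
proposed = torch.sqrt(sum(torch.sum(torch.square(g)) for g in grads))

# the 2-norm of all gradient entries concatenated into one vector
flat_l2 = torch.cat([g.flatten() for g in grads]).norm(p=2)

print(torch.allclose(proposed, flat_l2))  # prints True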
I've already tested it with SAM, obtaining an accuracy of 97.17%, and I'm now running tests with ASAM.
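
For context, this is roughly how I picture the norm entering epsilon_hat(w) in the ascent step; it is only a sketch that follows the formula above (assuming each param group stores "rho" and "adaptive", and using w^2 as the ASAM preconditioner), not necessarily your exact implementation:

import torch

@torch.no_grad()
def first_step(self, zero_grad=False):
    grad_norm = self._grad_norm()
    for group in self.param_groups:
        # SAM (p = q = 2): epsilon_hat(w) = rho * grad / ||grad||_2;
        # ASAM additionally preconditions the gradient with T_w^2 = diag(w^2)
        scale = group["rho"] / (grad_norm + 1e-12)  # small constant guards against a zero norm
        for p in group["params"]:
            if p.grad is None:
                continue
            e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
            p.add_(e_w)  # perturb the weights to w + epsilon_hat(w)
    if zero_grad:
        self.zero_grad()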