It seems to me that there might be a mistake in the way the noised state is computed in the current implementation. Specifically, in sam/sam_jax/training_utils, line 537, forward_and_loss, which includes the L2 regularization, is used to compute grad; this is then used in line 546 as input to dual_vector(grad).

I think this is not exactly correct given the original SAM paper. The state shouldn't be noised for the L2 regularization, as it is now, but only for the cross-entropy loss. A separate gradient for the L2 regularization should instead be computed at the clean state and summed with the SAM gradient.
Is there something that I'm missing?