google-research / sam

Apache License 2.0

L2 regulariser and SAM #24

Open konstantinos-p opened 1 year ago

konstantinos-p commented 1 year ago

It seems to me that there might be a mistake in the way the noised state is computed in the current implementation. Specifically, in

sam/sam_jax/training_utils

in line 537, `forward_and_loss`, which includes the L2 regularization, is used to compute `grad`; this is then used in line 546 as the input to `dual_vector(grad)`.
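If I am reading the code correctly, the current behaviour corresponds to the following (the notation is mine: $w$ are the parameters, $\lambda$ the L2 coefficient, $\rho$ the SAM radius, and $L_{\text{CE}}$ the cross-entropy loss):

$$
\epsilon = \rho \, \frac{\nabla\!\left(L_{\text{CE}}(w) + \tfrac{\lambda}{2}\lVert w\rVert^2\right)}{\left\lVert \nabla\!\left(L_{\text{CE}}(w) + \tfrac{\lambda}{2}\lVert w\rVert^2\right)\right\rVert}, \qquad
g = \nabla\!\left(L_{\text{CE}} + \tfrac{\lambda}{2}\lVert\cdot\rVert^2\right)\!(w + \epsilon)
$$

i.e. both the perturbation $\epsilon$ and the final gradient $g$ include the L2 term evaluated at the noised parameters.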

I don't think this exactly matches the original SAM paper. The state shouldn't be noised for the L2 regularization, as it is now, but only for the cross-entropy loss. The L2 gradient should instead be computed separately at the clean (unperturbed) state and summed with the SAM gradient of the cross-entropy term.
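To make the proposed fix concrete, here is a minimal JAX sketch of what I mean. All names here (`cross_entropy_loss`, `l2_penalty`, `sam_grad`, the constants) are my own illustrative choices, not the repo's API; only `dual_vector` mirrors the normalization used in the repo:

```python
import jax
import jax.numpy as jnp

L2_REG = 5e-4  # hypothetical L2 coefficient (lambda)
RHO = 0.05     # hypothetical SAM radius (rho)

def cross_entropy_loss(params, x, y):
    # Toy linear model with softmax cross-entropy.
    logits = x @ params
    logp = jax.nn.log_softmax(logits)
    return -jnp.mean(jnp.sum(y * logp, axis=-1))

def l2_penalty(params):
    return 0.5 * L2_REG * jnp.sum(params ** 2)

def dual_vector(g):
    # Normalize the gradient to unit L2 norm.
    return g / (jnp.linalg.norm(g) + 1e-12)

def sam_grad(params, x, y):
    # 1) Perturbation from the *cross-entropy* gradient only
    #    (no L2 term in the noising step).
    g_ce = jax.grad(cross_entropy_loss)(params, x, y)
    noised = params + RHO * dual_vector(g_ce)
    # 2) SAM gradient of the cross-entropy at the noised state.
    g_sam = jax.grad(cross_entropy_loss)(noised, x, y)
    # 3) L2 gradient at the *clean* state, summed in afterwards.
    g_l2 = jax.grad(l2_penalty)(params)
    return g_sam + g_l2
```

The difference from the current code is steps 1 and 3: the perturbation direction and the regularization gradient are both taken without noising the state for the L2 term.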

Is there something that I'm missing?