Hi guys, I am trying to implement SAM in MXNet and have run into a question about the grad_norm computation. When we compute the SAM perturbation e_w, should we calculate the gradient norm separately for each parameter, or calculate a single gradient norm over all parameters at once? Any advice would be appreciated.
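To make the two options concrete, here is a minimal sketch in plain NumPy (not MXNet code). The function names `e_w_global` and `e_w_per_param`, the `eps` term, and the default `rho=0.05` are my own assumptions for illustration. As far as I can tell, the SAM paper defines e_w = rho * g / ||g||_2 with g being the gradient with respect to all weights, which would correspond to the global version, but I would appreciate confirmation.

```python
import numpy as np

def global_grad_norm(grads):
    """L2 norm of all gradients treated as one flattened vector."""
    return np.sqrt(sum(np.sum(np.square(g)) for g in grads))

def e_w_global(grads, rho=0.05, eps=1e-12):
    """Option A: one norm over all parameters, shared by every tensor."""
    scale = rho / (global_grad_norm(grads) + eps)  # eps avoids division by zero
    return [scale * g for g in grads]

def e_w_per_param(grads, rho=0.05, eps=1e-12):
    """Option B: each parameter tensor is normalized by its own norm."""
    return [rho * g / (np.linalg.norm(g) + eps) for g in grads]

# Toy gradients for two parameter tensors (hypothetical values).
grads = [np.array([3.0, 0.0]), np.array([[0.0, 4.0]])]

eg = e_w_global(grads, rho=1.0)
ep = e_w_per_param(grads, rho=1.0)

print("global norm of grads:", global_grad_norm(grads))
print("total norm of e_w (global):", global_grad_norm(eg))
print("total norm of e_w (per-param):", global_grad_norm(ep))
```

With the global norm, the whole perturbation has length rho and the relative gradient magnitudes between parameters are preserved. With per-parameter norms, every tensor gets a perturbation of length rho on its own, so the total perturbation grows with the number of parameter tensors.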