Closed atawari closed 10 months ago
Thanks for your good question. Since the output has been processed by L2 normalization, the MSE loss is equivalent to the cosine distance (following MILAN). I did not choose the name carefully in the code, so it may be a little misleading.
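The equivalence claimed above can be checked numerically: for unit vectors, the squared L2 error equals 2·(1 − cosine similarity), so the two losses differ only by a constant scale. A minimal sketch (vector size 768 is an arbitrary illustrative choice):

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal(768)
b = rng.standard_normal(768)

# L2-normalize both vectors, as the model does before the loss
a_n = a / np.linalg.norm(a)
b_n = b / np.linalg.norm(b)

# For unit vectors: ||a - b||^2 = 2 - 2 * cos(a, b)
sq_err = np.sum((a_n - b_n) ** 2)
cos_dist = 2.0 * (1.0 - np.dot(a_n, b_n))

assert np.isclose(sq_err, cos_dist)
```

So minimizing MSE on the normalized pairs and minimizing cosine distance drive the model toward the same optimum.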
Makes sense! Thank you.
PS: I ran into NaN loss when using the MSE loss function with fp16 training, so the "l2" (cosine) loss is also more numerically stable.
Great! For the NaN issue, maybe you can try bf16, which is more numerically stable.
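The fp16 instability mentioned above is easy to reproduce: fp16 has a maximum value of about 65504, so squaring moderately large un-normalized activations overflows, whereas L2-normalized values are bounded in [−1, 1]. A small sketch (the value 300.0 is just an illustrative magnitude):

```python
import numpy as np

# fp16 max is ~65504, so 300^2 = 90000 overflows to inf;
# this is one route to inf/NaN in un-normalized MSE under fp16
# (bf16 keeps roughly fp32's dynamic range and avoids this).
x = np.float16(300.0)
sq = np.square(x)
print(np.isinf(sq))  # overflowed

# After L2 normalization, values are in [-1, 1], so squared
# errors stay far inside the fp16 representable range.
y = np.float16(0.5)
print(np.isinf(np.square(y)))
```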
In the GitHub scripts, the "l2" loss is used, which is a cosine alignment loss, but the paper states that MSE loss is used. I am curious which is correct?
Excerpt from the paper:
"... We select the corresponding unmasked token from the student and teacher, and compute the mean squared error (MSE) between the normalized pairs..."
From the scripts it uses
clip_loss_type = "l2"
and clip_norm_type = "l2"
The above configuration uses cosine distance, not MSE. Which loss was used for the best stage-1 pre-trained checkpoint?