**Closed** · qpjaada closed this issue 3 years ago
Hi, do you mind sharing a pointer to the loss implementation used with the conventional optimizers you mention?
As far as I know, the masked LM loss implementation is not part of the optimizers; it is part of run_pretraining.py. The original BERT and LAMB implementation is at https://github.com/google-research/bert/blob/master/run_pretraining.py#L187. I don't think the loss calculation has been changed in the TF1 reference compared with the original published model.
L216 is the evaluation loss and accuracy calculation. L224 applies tf.metrics.accuracy and tf.metrics.mean to the masked_lm_accuracy and masked_lm_loss; tf.metrics.mean does a cross-replica mean, and the denominator calculation is proper w.r.t. the weights given to the tf.metrics.mean function. The end result is that, for evaluation, which batch a given sequence falls into won't affect the eval loss and accuracy.
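To illustrate why the eval metric is batch-invariant, here is a minimal NumPy model of how tf.metrics.mean streams a weighted mean (hypothetical loss values; the class is a sketch of the accumulator semantics, not the TF implementation):

```python
import numpy as np

# Sketch of tf.metrics.mean(values, weights): it accumulates
# total += sum(values * weights) and count += sum(weights) across
# update calls, then reports total / count.
class StreamingMean:
    def __init__(self):
        self.total = 0.0
        self.count = 0.0

    def update(self, values, weights):
        self.total += float((values * weights).sum())
        self.count += float(weights.sum())

    def result(self):
        return self.total / self.count

# Hypothetical per-position losses; weight 1 for real masked positions,
# weight 0 for padding positions.
values = np.array([2.0, 4.0, 1.0, 3.0, 0.0, 0.0])
weights = np.array([1.0, 1.0, 1.0, 1.0, 0.0, 0.0])

# One big eval batch vs. the same data split into two smaller batches:
m1 = StreamingMean()
m1.update(values, weights)

m2 = StreamingMean()
m2.update(values[:3], weights[:3])
m2.update(values[3:], weights[3:])

# The denominator is the accumulated weight sum, so the result is the
# same no matter how the sequences are batched.
print(m1.result(), m2.result())  # both 2.5
```

Because the denominator is the total weight rather than the batch count, splitting the eval set differently across batches or replicas cannot change the reported metric.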
For training, the masked_lm loss is at L276. The training loss calculation is indeed local, and which batch a given sequence falls into does affect the local loss. There is no difference between the reference and the [research model](https://github.com/google-research/bert/blob/master/run_pretraining.py#L240) that the BERT and LAMB papers used w.r.t. this.
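A toy NumPy sketch (hypothetical loss values and mask counts) of how the local per-token normalization makes the training loss depend on batch composition:

```python
import numpy as np

# Hypothetical per-masked-token losses for 4 sequences; sequences have
# different numbers of masked positions.
losses = [
    np.array([2.0, 2.0]),       # seq 0: 2 masked tokens
    np.array([4.0]),            # seq 1: 1 masked token
    np.array([1.0, 1.0, 1.0]),  # seq 2: 3 masked tokens
    np.array([3.0]),            # seq 3: 1 masked token
]

def per_token_loss(seqs):
    # Loss normalized over all masked tokens in the (local) batch,
    # as in the reference's training loss.
    all_tokens = np.concatenate(seqs)
    return float(all_tokens.sum() / len(all_tokens))

# Per-token loss over all 4 sequences at once:
global_loss = per_token_loss(losses)  # 14 / 7 = 2.0

# Same sequences split into two local batches, averaging the two
# per-batch losses (roughly what per-replica training does):
split_a = 0.5 * (per_token_loss(losses[:2]) + per_token_loss(losses[2:]))

# A different assignment of sequences to batches gives yet another value:
split_b = 0.5 * (per_token_loss([losses[0], losses[2]])
                 + per_token_loss([losses[1], losses[3]]))

print(global_loss, split_a, split_b)  # three different values
```

Because sequences carry different numbers of masked tokens, a sequence's effective weight in the training loss depends on the other sequences sharing its batch.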
Thanks @sgpyc for the comments.
Anyway, since the preference is to stick as close as possible to the original BERT loss formulation, we can keep it that way. Closing this issue as resolved.
Currently, for BERT, the loss is normalized over tokens rather than sequences: https://github.com/mlcommons/training/blob/master/language_model/tensorflow/bert/run_pretraining.py#L216
This loss implementation is inconsistent with the conventional usage of the SGD or LAMB optimizers. Conventionally, one applies the same loss function to each sequence in a batch and then averages, weighting all sequences equally.
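For contrast, here is a sketch (plain NumPy, same hypothetical loss values as above) of the conventional per-sequence formulation, where each sequence is first reduced to its own mean and the sequence means are then averaged with equal weight:

```python
import numpy as np

# Hypothetical per-masked-token losses for 4 sequences.
losses = [
    np.array([2.0, 2.0]),       # seq 0
    np.array([4.0]),            # seq 1
    np.array([1.0, 1.0, 1.0]),  # seq 2
    np.array([3.0]),            # seq 3
]

def per_sequence_loss(seqs):
    # Average within each sequence first, then average the per-sequence
    # means with equal weight: the "loss-per-sequence" formulation.
    return float(np.mean([s.mean() for s in seqs]))

# Over all 4 sequences: mean of [2.0, 4.0, 1.0, 3.0] = 2.5.
global_loss = per_sequence_loss(losses)

# Any split into equal-sized batches yields the same average, because
# every sequence contributes equally regardless of its mask count:
split_a = 0.5 * (per_sequence_loss(losses[:2]) + per_sequence_loss(losses[2:]))
split_b = 0.5 * (per_sequence_loss([losses[0], losses[2]])
                 + per_sequence_loss([losses[1], losses[3]]))

print(global_loss, split_a, split_b)  # all 2.5
```

With equal-sized batches, the average of the per-batch losses equals the global per-sequence loss, so batch composition no longer affects a sequence's weight.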
With the loss-per-token normalization used in the reference, there are two unexpected side effects:
The request here is to change the reference to use the loss-per-sequence formulation.