mlcommons / training

Reference implementations of MLPerf™ training benchmarks
https://mlcommons.org/en/groups/training
Apache License 2.0

[BERT] Change the masked LM loss implementation in BERT reference #498

Closed qpjaada closed 3 years ago

qpjaada commented 3 years ago

Currently, for BERT, the loss is normalized over tokens rather than sequences: https://github.com/mlcommons/training/blob/master/language_model/tensorflow/bert/run_pretraining.py#L216

This loss implementation is inconsistent with the conventional formulation used with SGD or LAMB optimizers. In conventional usage of the LAMB optimizer, one applies the same loss function to every sequence in a batch and then averages, weighting all sequences equally.

With the loss-per-token normalization as in the reference, there are two unexpected side-effects:

The request here is to change the reference to use the loss-per-sequence formulation.
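
For concreteness, here is a minimal sketch contrasting the two formulations. This is not the reference code; `per_token_loss` and `mask_weights` are hypothetical names for the per-position cross-entropy and the 0/1 weights over the masked-prediction slots.

```python
import tensorflow.compat.v1 as tf

def per_token_normalized_loss(per_token_loss, mask_weights):
    # Current reference behavior: divide by the total number of masked tokens
    # in the batch, so sequences with more masked tokens contribute more.
    numerator = tf.reduce_sum(per_token_loss * mask_weights)
    denominator = tf.reduce_sum(mask_weights) + 1e-5
    return numerator / denominator

def per_sequence_normalized_loss(per_token_loss, mask_weights):
    # Requested behavior: first average within each sequence, then average
    # across sequences, so every sequence is weighted equally.
    per_seq_num = tf.reduce_sum(per_token_loss * mask_weights, axis=-1)
    per_seq_den = tf.reduce_sum(mask_weights, axis=-1) + 1e-5
    return tf.reduce_mean(per_seq_num / per_seq_den)
```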

sgpyc commented 3 years ago

Hi, do you mind sharing a pointer to the loss implementation of the said conventional optimizers?

As far as I know, the masked LM loss implementation is not part of the optimizers; it is part of run_pretraining.py. The original BERT and LAMB implementation is at https://github.com/google-research/bert/blob/master/run_pretraining.py#L187. I don't think the loss calculation has been changed in the TF1 reference compared to the original published model.

sgpyc commented 3 years ago

L216 is the evaluation loss and accuracy calculation. L224 applies tf.metrics.accuracy and tf.metrics.mean to the masked_lm_accuracy and masked_lm_loss; tf.metrics.mean does a cross-replica mean, and the denominator calculation is proper w.r.t. the weights given to the tf.metrics.mean function. The end result is that, for evaluation, which batch a given sequence falls into won't affect the eval loss and accuracy.
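
As a rough sketch of why that holds (a simplification with illustrative names, not the reference's exact metric_fn): tf.metrics.mean accumulates sum(values * weights) and sum(weights) as running totals across all eval batches and replicas, and only divides at the end, so the denominator is the global masked-token count rather than anything batch-local.

```python
import tensorflow.compat.v1 as tf

def metric_fn(per_token_loss, label_ids, logits, mask_weights):
    """Eval metrics; mask_weights are 1.0 at masked positions, 0.0 elsewhere."""
    predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
    masked_lm_accuracy = tf.metrics.accuracy(
        labels=label_ids, predictions=predictions, weights=mask_weights)
    # tf.metrics.mean keeps running totals of sum(loss * weights) and
    # sum(weights), so the final value is a weighted mean over all evaluated
    # tokens, independent of how sequences were batched.
    masked_lm_loss = tf.metrics.mean(values=per_token_loss, weights=mask_weights)
    return {
        "masked_lm_accuracy": masked_lm_accuracy,
        "masked_lm_loss": masked_lm_loss,
    }
```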

For training, the masked LM loss is at L276. The training loss calculation is indeed local, and which batch a given sequence falls into does affect the local loss. There is no difference between the reference and the [research model](https://github.com/google-research/bert/blob/master/run_pretraining.py#L240) that the BERT and LAMB papers used w.r.t. this.
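
A toy numeric check of that batch dependence (made-up per-token losses, plain Python, not from the reference): under per-token normalization, sequence A's share of the batch loss depends on how many masked tokens its batch-mate happens to have.

```python
seq_a = [1.0, 1.0]   # 2 masked-token losses for sequence A
seq_b = [0.5] * 8    # 8 masked-token losses for sequence B
seq_c = [0.5, 0.5]   # 2 masked-token losses for sequence C

def per_token_batch_loss(*seqs):
    # Sum all masked-token losses in the batch, divide by the token count.
    tokens = [loss for seq in seqs for loss in seq]
    return sum(tokens) / len(tokens)

print(per_token_batch_loss(seq_a, seq_b))  # 0.6  -> A's tokens are 2/10 of the denominator
print(per_token_batch_loss(seq_a, seq_c))  # 0.75 -> A's tokens are 2/4 of the denominator
```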

qpjaada commented 3 years ago

Thanks @sgpyc for the comments.

  1. We agree that the reference and the original BERT source code (for the loss) are the same.
  2. We were just confused by the fact (as you mention) that the contribution to the local loss from a given sequence depends on which batch it is assigned to. This is somewhat unusual for the SGD family of optimizers. We couldn't find any motivating reason for this loss formulation in the original BERT paper; if you are aware of any such motivation, we would be interested in learning more.

Anyway, since the preference is to stick as close as possible to the original BERT loss formulation, we can keep it that way. Closing this issue as resolved.