JonasGeiping / cramming

Cramming the training of a (BERT-type) language model into limited compute.

Question about sparse token prediction #33

Closed leo-du closed 1 year ago

leo-du commented 1 year ago

Hi Jonas,

Thanks for sharing the great work! I have a small question about the paper.

Both your paper and Izsak et al. refer to RoBERTa for something called "sparse token prediction", which I couldn't find in the RoBERTa paper. From your code, it appears that "sparse token prediction" just means that you only calculate the loss from the positions that are masked. It seems that this should be the default setting for training an MLM (and appears to be the case in BERT's code). The situation where you turn off this sparse prediction doesn't quite make sense to me -- why would one want to predict the unmasked tokens? Am I missing something obvious here?
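
For concreteness, my reading of the default setup is something like the sketch below (my own illustration with made-up shapes, not your actual code): logits are computed for every position, but unmasked positions are labeled -100 and so never contribute to the loss anyway.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only, not the repo code: logits exist for every position,
# but unmasked positions get label -100, so cross-entropy already ignores them.
batch, seq_len, vocab = 2, 8, 32
logits = torch.randn(batch, seq_len, vocab)        # logits for all positions
labels = torch.randint(0, vocab, (batch, seq_len))
mask = torch.rand(batch, seq_len) < 0.15           # ~15% of tokens are masked
labels[~mask] = -100                               # unmasked tokens carry no label

loss = F.cross_entropy(logits.view(-1, vocab), labels.view(-1), ignore_index=-100)
```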

Thanks for any help!

JonasGeiping commented 1 year ago

Hi, I totally missed this.

Yes, this is strictly a performance improvement (one that was maybe not discussed enough in the original BERT paper, which is where the confusion comes from), and there is no downside to applying it (aside from dynamic tensor shapes in some implementations).

Regarding your question: even if sparse prediction is turned off, no loss is computed on the unmasked tokens either. There is just more wasted work, because logits are computed for every position and then masked out again during the cross-entropy computation.
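
To make that concrete, here is a rough sketch of the sparse variant (illustrative shapes and names only, not the actual cramming implementation): the masked positions are selected before the LM head, so logits and the loss are only computed for the tokens that actually carry a label.

```python
import torch
import torch.nn as nn

# Rough sketch with made-up shapes/names (not the actual cramming code):
# select the masked positions *before* the LM head, so logits are only
# computed for the ~15% of tokens that have a label.
hidden_dim, vocab = 64, 32
lm_head = nn.Linear(hidden_dim, vocab)

hidden_states = torch.randn(2, 8, hidden_dim)      # [batch, seq_len, hidden]
labels = torch.randint(0, vocab, (2, 8))
labels[torch.rand(2, 8) >= 0.15] = -100            # unmasked positions carry no label

flat_hidden = hidden_states.view(-1, hidden_dim)
flat_labels = labels.view(-1)
masked_idx = (flat_labels != -100).nonzero(as_tuple=True)[0]

# LM head and loss run only on the masked subset; the result matches the
# dense variant, which computes logits everywhere and masks inside the loss.
sparse_loss = nn.functional.cross_entropy(lm_head(flat_hidden[masked_idx]),
                                          flat_labels[masked_idx])
```

With ~15% masking, this skips roughly 85% of the LM-head/logit computation, which is the entire point of the trick.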