d2l-ai / d2l-en

Interactive deep learning book with multi-framework code, math, and discussions. Adopted at 500 universities from 70 countries including Stanford, MIT, Harvard, and Cambridge.
https://D2L.ai

The MLM loss computation in the function _get_batch_loss_bert seems wrong in the d2l PyTorch code #2582

Open lyconghk opened 10 months ago

lyconghk commented 10 months ago

In my opinion, the BERT pretraining batch loss computed in the function `_get_batch_loss_bert` is not correct. The details are as follows:

The `CrossEntropyLoss` is initialized with the default reduction `'mean'`:

```python
loss = nn.CrossEntropyLoss()
```

In the function `_get_batch_loss_bert`, the MLM loss and the NSP loss are computed with this same `loss` instance:

```python
mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size),
             mlm_Y.reshape(-1)) * mlm_weights_X.reshape(-1, 1)
```

Since `reduction='mean'`, the result of `loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1))` is already a scalar tensor: the mean over all prediction positions, padded ones included. Multiplying that scalar position-wise by `mlm_weights_X` does not weight the per-position losses at all; it only broadcasts the same averaged value across the weight tensor, so the MLM loss computation is wrong.
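Here is a toy example (the shapes and values are made up) that shows the problem. After the `sum() / (mlm_weights_X.sum() + 1e-8)` normalization that follows in the book's code, the buggy version just gives back the scalar mean over all positions, padding included, while the masked version averages over the real tokens only:

```python
import torch
from torch import nn

vocab_size = 4
# Hypothetical batch: 6 flattened prediction positions, last 3 are padding.
mlm_Y_hat = torch.randn(6, vocab_size)
mlm_Y = torch.randint(0, vocab_size, (6,))
mlm_weights_X = torch.tensor([1., 1., 1., 0., 0., 0.])

loss = nn.CrossEntropyLoss()  # reduction='mean'
# Scalar: mean over all 6 positions, padding included.
scalar = loss(mlm_Y_hat, mlm_Y)
# Broadcasting the scalar over the weights only rescales it; after the
# normalization it reduces back to the scalar, so padding still dilutes it.
buggy = (scalar * mlm_weights_X.reshape(-1, 1)).sum() / (mlm_weights_X.sum() + 1e-8)

# Correct masked mean over the 3 real tokens only.
per_pos = nn.CrossEntropyLoss(reduction='none')(mlm_Y_hat, mlm_Y)
correct = (per_pos * mlm_weights_X).sum() / (mlm_weights_X.sum() + 1e-8)
print(buggy.item(), correct.item())  # the two values generally differ
```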

gab-chen commented 9 months ago

Agree with you @lyconghk. Have you come up with a better solution for applying `mlm_weights_X` in the `mlm_l` calculation?

The `weight` parameter of PyTorch's `CrossEntropyLoss` does not seem to support `mlm_weights_X` the way the MXNet version does. I guess that is why the PyTorch version of `_get_batch_loss_bert` calculates `mlm_l` this way: it tries to reduce the impact of padded tokens on `mlm_l`, but it does not use `mlm_weights_X` correctly.
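For reference, a minimal sketch of the difference (the shapes and values here are illustrative): PyTorch's `weight` argument is one weight per *class*, fixed at construction time, so a per-token pad mask has to be applied manually on the unreduced losses instead:

```python
import torch
from torch import nn

vocab_size = 5
logits = torch.randn(8, vocab_size)                 # 8 token positions
targets = torch.randint(0, vocab_size, (8,))
pad_mask = torch.tensor([1., 1., 1., 1., 1., 0., 0., 0.])  # 1 = real token

# PyTorch: `weight` must have size vocab_size (one weight per class),
# so it cannot express a per-token mask like mlm_weights_X.
class_weighted = nn.CrossEntropyLoss(weight=torch.ones(vocab_size))
_ = class_weighted(logits, targets)

# MXNet's SoftmaxCELoss takes a per-example sample_weight at call time;
# the PyTorch equivalent is unreduced losses plus manual masking.
per_token = nn.CrossEntropyLoss(reduction='none')(logits, targets)
masked_mean = (per_token * pad_mask).sum() / (pad_mask.sum() + 1e-8)
```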

lyconghk commented 9 months ago

How about just using `torch.nn.functional` to compute the two cross-entropy losses for MLM and NSP, and removing the input parameter `loss` from the function `_get_batch_loss_bert`?

```python
from torch.nn import functional as F

# Per-position MLM losses; these still need to be multiplied by
# mlm_weights_X and averaged over the non-padding positions.
mlm_l = F.cross_entropy(mlm_Y_hat.reshape(-1, vocab_size),
                        mlm_Y.reshape(-1), reduction='none')

# NSP makes one prediction per sequence, so the plain mean is fine.
nsp_l = F.cross_entropy(nsp_Y_hat, nsp_Y, reduction='mean')
```
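Putting it together, a sketch of the whole corrected function (argument names follow the book's `_get_batch_loss_bert`, with the `loss` parameter dropped as proposed; treat this as a suggestion rather than the final upstream fix):

```python
from torch.nn import functional as F

def _get_batch_loss_bert(net, vocab_size, tokens_X, segments_X,
                         valid_lens_x, pred_positions_X, mlm_weights_X,
                         mlm_Y, nsp_Y):
    # Forward pass of the BERT model
    _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,
                                  valid_lens_x.reshape(-1),
                                  pred_positions_X)
    # Per-position MLM losses, then a masked mean over real tokens only
    mlm_l = F.cross_entropy(mlm_Y_hat.reshape(-1, vocab_size),
                            mlm_Y.reshape(-1), reduction='none')
    mlm_l = ((mlm_l * mlm_weights_X.reshape(-1)).sum()
             / (mlm_weights_X.sum() + 1e-8))
    # NSP loss: one prediction per sequence
    nsp_l = F.cross_entropy(nsp_Y_hat, nsp_Y, reduction='mean')
    l = mlm_l + nsp_l
    return mlm_l, nsp_l, l
```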