lyconghk opened this issue 10 months ago
Agree with you @lyconghk. Have you come up with any better solution to apply `mlm_weights_X` in the `mlm_l` calculation? The `weight` parameter of PyTorch's `CrossEntropyLoss` does not support `mlm_weights_X` the way the MXNet loss does: it is a per-class rescaling vector, not a per-token weight. I guess that is why the PyTorch version of `_get_batch_loss_bert` calculates `mlm_l` in this way. It tries to reduce the impact of padded tokens on `mlm_l`, but it does not use `mlm_weights_X` correctly.
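For reference, the asymmetry is that Gluon losses accept a per-example `sample_weight` directly in the call, while the `weight` argument of PyTorch's `CrossEntropyLoss` is indexed by class id. A minimal sketch of the two (the shapes are illustrative assumptions, and the Gluon part is shown as comments to keep the snippet runnable without MXNet):

```python
import torch
from torch import nn

vocab_size = 10
# PyTorch: `weight` must have one entry per CLASS (vocabulary id), so it
# cannot carry per-position weights like mlm_weights_X.
loss_pt = nn.CrossEntropyLoss(weight=torch.ones(vocab_size))

# MXNet/Gluon, by contrast, accepts per-example weights in the call itself:
#   from mxnet.gluon import loss as gloss
#   loss_mx = gloss.SoftmaxCrossEntropyLoss()
#   mlm_l = loss_mx(pred, label, sample_weight)  # one weight per token
```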
How about just using `torch.nn.functional` to calculate the two cross-entropy losses of MLM and NSP, and removing the input parameter `loss` from the function `_get_batch_loss_bert`?
```python
from torch.nn import functional as F

mlm_l = F.cross_entropy(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1), reduction='none')
nsp_l = F.cross_entropy(nsp_Y_hat, nsp_Y, reduction='mean')
```
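To make the proposal concrete, here is a minimal sketch of the revised function. The forward-call signature is assumed to match the one in the d2l chapter, and the `1e-8` guard against an all-zero weight batch is carried over from the existing code:

```python
from torch.nn import functional as F

def _get_batch_loss_bert(net, vocab_size, tokens_X, segments_X, valid_lens_x,
                         pred_positions_X, mlm_weights_X, mlm_Y, nsp_y):
    # Forward pass through the BERT model
    _, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,
                                  valid_lens_x.reshape(-1), pred_positions_X)
    # Per-token masked language model losses
    mlm_l = F.cross_entropy(mlm_Y_hat.reshape(-1, vocab_size),
                            mlm_Y.reshape(-1), reduction='none')
    # Zero out padded positions, then average over the real tokens only
    mlm_l = (mlm_l * mlm_weights_X.reshape(-1)).sum() / (mlm_weights_X.sum() + 1e-8)
    # Next sentence prediction loss, averaged over the batch
    nsp_l = F.cross_entropy(nsp_Y_hat, nsp_y, reduction='mean')
    return mlm_l, nsp_l, mlm_l + nsp_l
```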
In my opinion, the BERT pretraining batch loss in the function `_get_batch_loss_bert` is not correct. The details:

The `CrossEntropyLoss` is initialized with the default reduction `'mean'`:

```python
loss = nn.CrossEntropyLoss()
```

In the function `_get_batch_loss_bert`, the MLM loss and the NSP loss are computed with this same `loss` instance:

```python
mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) * mlm_weights_X.reshape(-1, 1)
```

Since `reduction='mean'`, the result of `loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1))` is a scalar tensor that already averages over every position, padding included. The position-wise product with `mlm_weights_X` then just broadcasts that scalar, so the weights cannot remove the padded tokens' contribution from the MLM loss.
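A small self-contained check makes the problem visible (the shapes and values are made up for illustration, and the `sum() / (weights.sum() + 1e-8)` normalization is the one used in the chapter's code): with `reduction='mean'` the weights cancel out of the normalization, so padded positions still contribute, whereas `reduction='none'` gives the genuinely padding-free average.

```python
import torch
from torch import nn
from torch.nn import functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)               # 4 predicted positions, vocab of 10
targets = torch.tensor([2, 5, 0, 0])      # last two positions are padding
weights = torch.tensor([1., 1., 0., 0.])  # mlm_weights_X: 0 masks padding

# Current approach: reduction='mean' already returns a scalar average
loss = nn.CrossEntropyLoss()
mlm_l = loss(logits, targets) * weights.reshape(-1, 1)
buggy = mlm_l.sum() / (weights.sum() + 1e-8)   # equals the unweighted mean again

# Weighting before reduction: only the two real tokens are averaged
per_token = F.cross_entropy(logits, targets, reduction='none')
fixed = (per_token * weights).sum() / (weights.sum() + 1e-8)

print(buggy.item(), fixed.item())  # the two values differ
```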