codertimo / BERT-pytorch

Google AI 2018 BERT pytorch implementation
Apache License 2.0
6.09k stars 1.29k forks

why specify `ignore_index=0` in the NLLLoss function in BERTTrainer? #98

Open Jasmine969 opened 2 years ago

Jasmine969 commented 2 years ago

trainer/pretrain.py

class BERTTrainer:
    def __init__(self, ...):
        ... 
        # Using Negative Log Likelihood Loss function for predicting the masked_token
        self.criterion = nn.NLLLoss(ignore_index=0)
        ...

I cannot understand why `ignore_index=0` is specified when calculating NLLLoss. In the NSP task, if the ground truth of is_next is False (label = 0) but BERT predicts True, that sample is ignored, so it contributes nothing to the loss (and the loss becomes nan if every label in the batch is 0). So what is the aim of `ignore_index=0`?
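A minimal sketch of the behaviour described above, assuming PyTorch's default `reduction='mean'`. The probabilities and labels are made up for illustration and do not come from the repository.

```python
import math

import torch
import torch.nn as nn

criterion = nn.NLLLoss(ignore_index=0)

# Hypothetical NSP log-probabilities for 2 samples over classes [not_next, is_next].
log_probs = torch.log(torch.tensor([[0.1, 0.9],
                                    [0.2, 0.8]]))

# Mixed batch: the label-0 sample is skipped, so only the label-1 sample counts.
mixed_loss = criterion(log_probs, torch.tensor([0, 1]))

# Every label is 0: no sample is left to average over, so the mean is nan.
all_zero_loss = criterion(log_probs, torch.tensor([0, 0]))

print(mixed_loss.item(), all_zero_loss.item())
```

In the mixed batch the wrong prediction on the label-0 sample is never penalised, which is the concern raised here.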

====================

Well, I've found that `ignore_index=0` is useful for the MLM task, but I still don't think the NSP task should share the same NLLLoss instance with MLM.
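One way to express this is to give each task its own criterion. The following is only a sketch, not the repository's code: the names `mlm_criterion` and `nsp_criterion`, the tensor shapes, and the assumption that label 0 marks MLM positions to skip are all hypothetical.

```python
import torch
import torch.nn as nn

# Assumption: label 0 marks MLM positions that should not contribute to the loss.
mlm_criterion = nn.NLLLoss(ignore_index=0)
# NSP labels 0 and 1 are both real classes, so nothing is ignored here.
nsp_criterion = nn.NLLLoss()

vocab_size = 5  # hypothetical tiny vocabulary

# MLM output: (batch, seq_len, vocab); only two positions carry a real label.
mlm_log_probs = torch.log_softmax(torch.randn(2, 4, vocab_size), dim=-1)
mlm_labels = torch.tensor([[0, 3, 0, 0],
                           [0, 0, 2, 0]])

# NSP output: (batch, 2) with one label per sample.
nsp_log_probs = torch.log_softmax(torch.randn(2, 2), dim=-1)
nsp_labels = torch.tensor([0, 1])

# NLLLoss expects the class dimension at index 1, hence the transpose.
mlm_loss = mlm_criterion(mlm_log_probs.transpose(1, 2), mlm_labels)
nsp_loss = nsp_criterion(nsp_log_probs, nsp_labels)
loss = mlm_loss + nsp_loss
```

With this split, NSP samples labelled 0 are counted in `nsp_loss`, while MLM positions labelled 0 are still skipped.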

MingchangLi commented 1 year ago

See #32: change `self.criterion = nn.NLLLoss(ignore_index=0)` to `self.criterion = nn.NLLLoss()`.