sebastian-hofstaetter / matchmaker

Training & evaluation library for text-based neural re-ranking and dense retrieval models built with PyTorch
https://neural-ir-explorer.ec.tuwien.ac.at/
Apache License 2.0
261 stars · 30 forks

typo in KLDivTeacherList? #18

Open jihyukkim-nlp opened 2 years ago

jihyukkim-nlp commented 2 years ago

Thank you for sharing the code. I wonder whether there is a typo in the implementation of the KLDivTeacherList class.

The implementation is

class KLDivTeacherList(nn.Module):
    def __init__(self):
        super(KLDivTeacherList, self).__init__()
        self.kl = torch.nn.KLDivLoss(reduction="batchmean")
    def forward(self, scores, labels):
        loss = self.kl(scores.softmax(-1),labels.softmax(-1)) # is this a typo?
        return loss

However, the PyTorch documentation for KLDivLoss (https://pytorch.org/docs/1.9.1/generated/torch.nn.KLDivLoss.html) says that the input is expected to contain log-probabilities, while the target is interpreted as probabilities.

So, from what I understand, the forward function should be

def forward(self, scores, labels):
    # before: softmax of scores passed as the input
    # loss = self.kl(scores.softmax(-1), labels.softmax(-1))

    # after: log-softmax of scores passed as the input
    loss = self.kl(torch.nn.functional.log_softmax(scores, dim=-1), torch.nn.functional.softmax(labels, dim=-1))
    return loss

I wonder if this is a typo or if I'm missing something. Thanks in advance.
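To illustrate the point, here is a small self-contained sketch (the tensors are made-up toy inputs, not from the library) that compares the two variants against KL divergence computed by hand. Only the log-softmax version matches the manual formula, since KLDivLoss expects log-probabilities as its first argument:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical student scores and teacher labels: (batch, num_candidates)
scores = torch.randn(2, 4)
labels = torch.randn(2, 4)

kl = torch.nn.KLDivLoss(reduction="batchmean")

# Variant as in the repo: probabilities passed as the input argument.
loss_softmax = kl(scores.softmax(-1), labels.softmax(-1))

# Variant per the docs: log-probabilities passed as the input argument.
loss_logsoftmax = kl(F.log_softmax(scores, dim=-1), F.softmax(labels, dim=-1))

# Manual KL(q || p): sum q * (log q - log p), averaged over the batch
# (matches reduction="batchmean").
log_p = F.log_softmax(scores, dim=-1)
q = F.softmax(labels, dim=-1)
kl_manual = (q * (q.log() - log_p)).sum(-1).mean()

print(loss_softmax.item(), loss_logsoftmax.item(), kl_manual.item())
```

The log-softmax variant agrees with the manual computation, while the softmax variant does not, which supports reading the original line as a bug rather than an intentional choice.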