jxhe / efficient-knnlm

PyTorch implementation of the paper "Efficient Nearest Neighbor Language Models" (EMNLP 2021)
MIT License

Question on get_knn_log_prob #3

Closed jiaqing23 closed 1 year ago

jiaqing23 commented 1 year ago

Hi, I am reading the code. In knnlm.py, there is a line (https://github.com/jxhe/efficient-knnlm/blob/main/fairseq/knnlm.py#L267):

index_mask = torch.eq(torch.from_numpy(self.vals[knns]).long().cuda().squeeze(-1), tgt[knn_mask].unsqueeze(-1)).float()

May I know the purpose of this line? Does tgt refer to the prediction target tokens? If so, why is the target available during testing?

Thanks!

jiaqing23 commented 1 year ago

Closing this question. Found the answer after reading the code again.

This line computes the kNN probability of the ground-truth token, and it is only used for evaluation. Since the reported metric is perplexity, the probability assigned to the ground truth is all that is needed.
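
For reference, here is a minimal sketch (not the repository code) of the computation described above: the retrieved neighbors are turned into a distribution via a softmax over negative distances, and the mask produced by the torch.eq(...) line selects the neighbors whose stored token matches the ground truth, so their weights sum to the kNN probability of the target. Names and shapes (dists, vals, tgt) are illustrative assumptions.

```python
import torch

def knn_log_prob_of_target(dists, vals, tgt):
    """Sketch of the ground-truth kNN probability (assumed shapes):
       dists: (B, K) negative distances to the K retrieved neighbors
       vals:  (B, K) vocabulary ids stored with those neighbors
       tgt:   (B,)   ground-truth next tokens
    """
    # Softmax over neighbors turns negative distances into weights.
    probs = torch.softmax(dists, dim=-1)                    # (B, K)
    # Mask neighbors whose stored token equals the ground truth,
    # mirroring the torch.eq(...) line quoted in the question.
    index_mask = torch.eq(vals, tgt.unsqueeze(-1)).float()  # (B, K)
    # kNN probability of the ground-truth token = sum of matching weights.
    p_tgt = (probs * index_mask).sum(dim=-1)                # (B,)
    return torch.log(p_tgt.clamp_min(1e-10))

# Example usage with random retrieval results (hypothetical data):
B, K, V = 2, 8, 100
dists = -torch.rand(B, K)            # negative L2 distances
vals = torch.randint(0, V, (B, K))   # neighbor value tokens
tgt = torch.randint(0, V, (B,))      # ground-truth targets
print(knn_log_prob_of_target(dists, vals, tgt))
```

Since perplexity only requires log p(ground truth), this per-target probability is sufficient; no full-vocabulary kNN distribution needs to be materialized for evaluation.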