psunlpgroup / CONTaiNER

Code for ACL 2022 paper "CONTaiNER: Few-Shot Named Entity Recognition via Contrastive Learning"
MIT License

What is the purpose of the cnts variable in utils.py? #15

Closed: xxmyf closed this issue 1 year ago

xxmyf commented 1 year ago

```python
import torch

def nt_xent(loss, num, denom, temperature = 1):
    loss = torch.exp(loss/temperature)
    cnts = torch.sum(num, dim = 1)
    loss_num = torch.sum(loss * num, dim = 1)
    loss_denom = torch.sum(loss * denom, dim = 1)
    # sanity check
    nonzero_indexes = torch.where(cnts > 0)
    loss_num, loss_denom, cnts = loss_num[nonzero_indexes], loss_denom[nonzero_indexes], cnts[nonzero_indexes]

    loss_final = -torch.log2(loss_num) + torch.log2(loss_denom) + torch.log2(cnts)
    return loss_final
```
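Reading the code alone (this interpretation is not confirmed by the authors), and assuming `num` and `denom` are 0/1 masks, `cnts` is the number of numerator (positive) entries in each row. Because `-log2(a) + log2(b) + log2(c) = -log2((a / c) / b)`, each returned value equals `-log2((loss_num / cnts) / loss_denom)`. In other words, the `+ log2(cnts)` term turns the *summed* positive probability into an *average* over the positives. The `cnts > 0` filter drops rows with no positives, where `log2(0)` would be undefined. The plain-Python sketch below checks this identity for one row. `nt_xent_row` and `mean_positive_form` are hypothetical helpers, not part of the repo, and the scores are made up.

```python
import math
from typing import List, Optional


def nt_xent_row(scores: List[float], pos_mask: List[int], denom_mask: List[int],
                temperature: float = 1.0) -> Optional[float]:
    """Hypothetical plain-Python mirror of one row of nt_xent (not from the repo)."""
    exp_scores = [math.exp(s / temperature) for s in scores]
    cnt = sum(pos_mask)  # plays the role of cnts for this row (assumes a 0/1 mask)
    if cnt == 0:
        # nt_xent filters such rows out via torch.where(cnts > 0); log2(0) is undefined
        return None
    num = sum(e * m for e, m in zip(exp_scores, pos_mask))
    denom = sum(e * m for e, m in zip(exp_scores, denom_mask))
    return -math.log2(num) + math.log2(denom) + math.log2(cnt)


def mean_positive_form(scores: List[float], pos_mask: List[int], denom_mask: List[int],
                       temperature: float = 1.0) -> float:
    """The same row written as -log2 of the *average* positive probability."""
    exp_scores = [math.exp(s / temperature) for s in scores]
    denom = sum(e * m for e, m in zip(exp_scores, denom_mask))
    probs = [e / denom for e, m in zip(exp_scores, pos_mask) if m]
    return -math.log2(sum(probs) / len(probs))


if __name__ == "__main__":
    scores = [2.0, 1.0, 0.0, -1.0]  # made-up similarity scores for one anchor
    pos = [1, 1, 0, 0]              # two positives, so cnt = 2
    denom = [1, 1, 1, 1]
    print(nt_xent_row(scores, pos, denom))           # ~1.183
    print(mean_positive_form(scores, pos, denom))    # same value
    print(nt_xent_row(scores, [0, 0, 0, 0], denom))  # None: row would be dropped
```

If the `log2(cnts)` term were removed, each row would instead get `-log2` of the summed positive probability. That value is lower by exactly `log2(cnts)`. So for the same average positive probability, a row with more positives would receive a smaller loss purely because it has more positives.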