Closed: xxmyf closed this issue 1 year ago.
def nt_xent(loss, num, denom, temperature=1):
    """Compute a per-row NT-Xent style contrastive loss using base-2 logarithms.

    For each row ``i`` this returns::

        -log2(sum_j exp(s_ij / T) * num_ij)
        + log2(sum_j exp(s_ij / T) * denom_ij)
        + log2(sum_j num_ij)

    where ``s = loss`` and ``T = temperature``. Rows whose ``num`` sum is not
    positive are dropped from the output.

    The value is evaluated in log space with ``torch.logsumexp`` rather than
    ``exp`` -> ``sum`` -> ``log2``. The naive form overflows to ``inf`` (and
    then produces ``nan``) once ``s / T`` exceeds roughly 88 in float32 or 11
    in float16. It also underflows to ``-log2(0) = inf`` for very negative
    scores.

    Args:
        loss: Score/similarity tensor. Despite the name, these are the values
            that get exponentiated. Assumed 2-D ``(rows, candidates)``; TODO
            confirm against callers.
        num: Numerator weights, broadcastable against ``loss`` and summed over
            dim 1. Presumably a 0/1 mask of positives; verify against caller.
            Entries must be non-negative and are not expected to require
            gradients, because a zero weight enters as ``log(0) = -inf``.
        denom: Denominator weights, with the same conventions as ``num``.
        temperature: Scalar divisor applied to ``loss`` before exponentiation.

    Returns:
        1-D tensor with one loss value per row that has ``sum(num) > 0``,
        kept in the original row order.
    """
    import math

    logits = loss / temperature
    cnts = torch.sum(num, dim=1)

    # Keep the broadcasting semantics of the original ``loss * num``.
    logits, num, denom = torch.broadcast_tensors(logits, num, denom)

    # Drop rows without any positive weight *before* the logsumexp.
    # An all -inf row would otherwise yield NaN gradients even after being
    # indexed away.
    keep = cnts > 0
    logits, num, denom, cnts = logits[keep], num[keep], denom[keep], cnts[keep]

    # log(sum_j exp(x_j) * w_j) == logsumexp_j(x_j + log w_j).
    # Zero weights become -inf and contribute nothing.
    log_num = torch.logsumexp(logits + torch.log(num.to(logits.dtype)), dim=1)
    log_denom = torch.logsumexp(logits + torch.log(denom.to(logits.dtype)), dim=1)

    # The original is expressed in base 2: log2(z) = ln(z) / ln(2).
    return (log_denom - log_num) / math.log(2) + torch.log2(cnts)
def nt_xent(loss, num, denom, temperature = 1):