Closed yuanphoenix closed 1 year ago
def another_simcse_sup_loss_v2(y_pred: torch.Tensor) -> torch.Tensor:
    """Compute a supervised SimCSE-style contrastive loss, summed over anchors.

    Rows of ``y_pred`` are consumed as consecutive triplets: row ``3k`` is an
    anchor and row ``3k + 1`` is its positive. For every anchor the candidate
    set is *all* non-anchor rows of the batch (rows whose index is not a
    multiple of 3), so row ``3k + 2`` and the rows of other triplets only ever
    act as negatives. The per-anchor loss is::

        -log( exp(cos(a, p) / t) / sum_c exp(cos(a, c) / t) )

    with temperature ``t = 0.05``.

    Args:
        y_pred: Encoder (e.g. BERT) output of shape ``[batch_size * 3, hidden]``.

    Returns:
        A 0-d tensor holding the *sum* (not the mean) of the per-anchor losses.
        An empty batch yields a zero tensor.

    Raises:
        ValueError: If ``y_pred`` is not 2-D or its row count is not a
            multiple of 3.
    """
    temperature = 0.05

    if y_pred.dim() != 2 or y_pred.shape[0] % 3 != 0:
        raise ValueError(
            "y_pred must have shape [batch_size * 3, hidden], got "
            f"{tuple(y_pred.shape)}"
        )

    num_rows = y_pred.shape[0]
    if num_rows == 0:
        return y_pred.new_zeros(())

    # Build every helper tensor on the input's own device instead of relying
    # on a module-level device constant that may not match the input.
    device = y_pred.device
    is_candidate = torch.arange(num_rows, device=device) % 3 != 0

    anchors = y_pred[0::3]
    candidates = y_pred[is_candidate]

    # Only anchor-vs-candidate similarities are needed. An anchor is never its
    # own candidate, so no diagonal masking is required.
    logits = F.cosine_similarity(
        anchors.unsqueeze(1), candidates.unsqueeze(0), dim=-1
    ) / temperature

    # Candidates are ordered 1, 2, 4, 5, 7, 8, ...; the positive of anchor 3k
    # (row 3k + 1) therefore sits at candidate position 2k.
    targets = 2 * torch.arange(anchors.shape[0], device=device)

    # cross_entropy == -log(softmax(logits)[target]), computed via a stable
    # log-sum-exp; reduction="sum" keeps the original summed (not averaged) loss.
    return F.cross_entropy(logits, targets, reduction="sum")
最终的结果和楼主的大致符合:
我有一个疑问,像这种奇奇怪怪的公式,大佬们是怎么找出来的?
感觉上是通过设计公式来达到对比学习的目的。
最终的结果和楼主的大致符合: