vdogmcgee / SimCSE-Chinese-Pytorch

A reproduction of SimCSE for Chinese, supervised + unsupervised
MIT License

Thanks to the author: I wrote the supervised loss function following the formula in the original paper #13

Closed: yuanphoenix closed this issue 1 year ago

yuanphoenix commented 1 year ago
import torch
import torch.nn.functional as F

def another_simcse_sup_loss_v2(y_pred):
    """
    y_pred (tensor): BERT output, [batch_size * 3, 768],
        rows ordered as (anchor, positive, hard negative) per triple.
    """
    temperature = 0.05
    y_label = torch.arange(0, y_pred.shape[0], device=y_pred.device)
    # indices of all non-anchor rows (positives and hard negatives)
    y_other = torch.where(y_label % 3 != 0)

    # pairwise cosine similarity matrix, [batch_size * 3, batch_size * 3]
    sim = F.cosine_similarity(y_pred.unsqueeze(1), y_pred.unsqueeze(0), dim=-1)
    # mask out self-similarity on the diagonal
    sim = sim - torch.eye(y_pred.shape[0], device=y_pred.device) * 1e12
    loss = 0.0

    # one loss term per anchor (every third row)
    for index in range(0, sim.shape[0], 3):
        # positive pair: anchor vs. its entailment sentence at index + 1
        numerator = torch.exp(sim[index][index + 1] / temperature)
        # denominator: all positives and hard negatives in the batch
        denominator = torch.sum(torch.exp(sim[index][y_other] / temperature))
        loss += -1 * torch.log(numerator / denominator)
    return loss
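
For reference, the supervised loss with hard negatives from the original SimCSE paper (Gao et al., 2021), which the function above follows, is

$$
\ell_i = -\log \frac{e^{\mathrm{sim}(h_i,\, h_i^{+})/\tau}}{\sum_{j=1}^{N} \left( e^{\mathrm{sim}(h_i,\, h_j^{+})/\tau} + e^{\mathrm{sim}(h_i,\, h_j^{-})/\tau} \right)}
$$

where $h_i$ is the anchor embedding, $h_i^{+}$ its entailment (positive), $h_i^{-}$ its contradiction (hard negative), $\mathrm{sim}$ is cosine similarity, and $\tau$ is the temperature (0.05 here). The loop's denominator over all rows with index % 3 != 0 is exactly the sum over every positive and hard negative in the batch.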

The final results roughly match yours: (results screenshot attached)
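
As a side note, the per-anchor loop can be collapsed into a single cross-entropy call. Below is a minimal vectorized sketch, assuming the same (anchor, positive, hard negative) row ordering; the function name simcse_sup_loss_vectorized and the choice of reduction='sum' (to match the summed loop loss) are my own, not part of the repo:

import torch
import torch.nn.functional as F

def simcse_sup_loss_vectorized(y_pred, temperature=0.05):
    # y_pred: [batch_size * 3, hidden], rows ordered (anchor, positive, hard negative)
    device = y_pred.device
    idx = torch.arange(y_pred.shape[0], device=device)
    anchor_idx = idx[idx % 3 == 0]   # rows 0, 3, 6, ...
    other_idx = idx[idx % 3 != 0]    # all positives and hard negatives

    # cosine similarity of every anchor against every non-anchor row: [N, 2N]
    sim = F.cosine_similarity(
        y_pred[anchor_idx].unsqueeze(1),
        y_pred[other_idx].unsqueeze(0),
        dim=-1,
    ) / temperature

    # the positive of anchor k (row 3k + 1) sits at column 2k of `sim`,
    # right before its hard negative (row 3k + 2) at column 2k + 1
    labels = torch.arange(anchor_idx.shape[0], device=device) * 2

    # cross-entropy = -log(exp(sim[k, 2k]) / sum_j exp(sim[k, j])),
    # summed over anchors to match the loop version above
    return F.cross_entropy(sim, labels, reduction='sum')

Because the loop's denominator already ranges over all non-anchor rows (including the positive itself), the cross-entropy term is the same -log(numerator / denominator); a quick check with something like torch.randn(12, 768) should give the same value as the loop version up to floating-point error.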

yuanphoenix commented 1 year ago

I do have one question: how do people come up with odd-looking formulas like this in the first place?

vdogmcgee commented 1 year ago

My sense is that the formula is just a means to the contrastive-learning objective: pull the anchor and its positive together while pushing the anchor away from everything else in the batch.