sl-93 / SUPERVISED-CONTRASTIVE-LEARNING-FOR-PRE-TRAINED-LANGUAGE-MODEL-FINE-TUNING

In this project, I've implemented the Facebook AI paper on fine-tuning RoBERTa with a supervised contrastive loss.

Can gradients in CL-loss “contrastive_loss(tem, hiden_state.cpu().detach().numpy(), b_labels)” be updated? #2

Open zhaishengfang opened 2 years ago

zhaishengfang commented 2 years ago

# Compute loss
if scl == True:
    cross_loss = loss_fn(logits, b_labels)
    contrastive_l = contrastive_loss(tem, hiden_state.cpu().detach().numpy(), b_labels)
    loss = (lam * contrastive_l) + (1 - lam) * (cross_loss)
if scl == False:
    loss = loss_fn(logits, b_labels)

When calculating contrastive_loss, you use "detach" (and a conversion to numpy). Can gradients still be updated through this loss?
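To make the concern concrete, here is a minimal sketch (a toy tensor, not code from this repo) showing that the .detach().numpy() round trip severs the autograd graph:

import torch

x = torch.randn(4, 8, requires_grad=True)

# .detach() removes the tensor from the computation graph, and .numpy()
# leaves PyTorch entirely, so anything computed from y is a constant
# as far as autograd is concerned.
y = x.detach().numpy()
loss = torch.tensor((y ** 2).sum())

print(loss.requires_grad)  # False: calling backward() here could never update x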

zhaoyin214 commented 2 years ago

I don't think so; it was calculated with numpy, so it is cut off from the autograd graph. I am absolutely astonished.

mrcabbage972 commented 2 years ago

@zhaishengfang @zhaoyin214 Here's a version of the contrastive loss function implemented in PyTorch, so gradients can flow through it. It's a line-by-line port of the existing code, which means the nested for loops make it very inefficient on GPU.

import torch


def sim_matrix(a, b, eps=1e-8):
    """Pairwise cosine similarity between the rows of a and b."""
    a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
    a_norm = a / torch.clamp(a_n, min=eps)  # eps guards against division by zero
    b_norm = b / torch.clamp(b_n, min=eps)
    sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
    return sim_mt

def contrastive_loss(temp, embedding, label):
    """Calculate the supervised contrastive (SupCon) loss for a batch."""
    # cosine similarity between embeddings
    cosine_sim = sim_matrix(embedding, embedding)
    n = cosine_sim.shape[0]
    # drop the diagonal (self-similarity), keeping the n - 1 other samples per row;
    # build the mask on the same device as the similarities to avoid a CPU/GPU mismatch
    dis = cosine_sim.masked_select(
        ~torch.eye(n, dtype=torch.bool, device=cosine_sim.device)
    ).view(n, n - 1)

    # apply temperature to elements
    dis = dis / temp
    cosine_sim = cosine_sim / temp
    # apply exp to elements
    dis = torch.exp(dis)
    cosine_sim = torch.exp(cosine_sim)

    # calculate row sum
    row_sum = torch.sum(dis, -1)

    # per-sample positive counts: how many other samples share each label
    unique_labels, inverse_indices, unique_label_counts = torch.unique(
        label, sorted=False, return_inverse=True, return_counts=True
    )
    # calculate outer sum over anchors
    contrastive_loss = torch.tensor(0, dtype=embedding.dtype, device=embedding.device)
    for i in range(n):
        n_i = unique_label_counts[inverse_indices[i]] - 1  # number of positives for anchor i
        inner_sum = torch.tensor(0, dtype=embedding.dtype, device=embedding.device)
        # calculate inner sum over the positives of anchor i
        for j in range(n):
            if label[i] == label[j] and i != j:
                inner_sum = inner_sum + torch.log(cosine_sim[i][j] / row_sum[i])
        if n_i != 0:
            contrastive_loss += inner_sum / (-n_i)
    return contrastive_loss
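As a quick sanity check (toy data with hypothetical shapes; it assumes the two functions above are in scope), gradients now flow all the way back to the embeddings:

embeddings = torch.randn(8, 16, requires_grad=True)
labels = torch.randint(0, 3, (8,))  # 8 samples over 3 classes: pigeonhole guarantees some positives

loss = contrastive_loss(temp=0.5, embedding=embeddings, label=labels)
loss.backward()

print(embeddings.grad is not None)  # True: the loss is differentiable end to end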