zhaishengfang opened this issue 2 years ago
I don't think so; it was calculated with NumPy. I am absolutely astonished.
@zhaishengfang @zhaoyin214 Here's a version of the contrastive loss function that's implemented in PyTorch and can therefore propagate gradients. It's a line-by-line port of the existing code, so the nested for loops make it very inefficient when running on GPU; a vectorized sketch follows the port below.
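If I'm reading the code right, what it computes is the standard supervised contrastive objective: with $z_i$ the embedding of sample $i$, $\tau$ the temperature, and $P(i)$ the set of other samples in the batch sharing $i$'s label,

$$\mathcal{L} = \sum_{i:\,|P(i)| > 0} \frac{-1}{|P(i)|} \sum_{p \in P(i)} \log \frac{\exp(\operatorname{sim}(z_i, z_p)/\tau)}{\sum_{k \neq i} \exp(\operatorname{sim}(z_i, z_k)/\tau)}$$

where $\operatorname{sim}$ is cosine similarity and anchors with no positives are skipped.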
```python
import torch


def sim_matrix(a, b, eps=1e-8):
    """Pairwise cosine similarity between the rows of `a` and `b`."""
    a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
    # Clamp norms so all-zero rows don't divide by zero.
    a_norm = a / torch.clamp(a_n, min=eps)
    b_norm = b / torch.clamp(b_n, min=eps)
    sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
    return sim_mt


def contrastive_loss(temp, embedding, label):
    """Calculate the supervised contrastive loss.

    temp:      temperature scalar
    embedding: (n, d) tensor of embeddings (stays on the autograd graph)
    label:     (n,) tensor of integer class labels
    """
    # Cosine similarity between embeddings.
    cosine_sim = sim_matrix(embedding, embedding)
    n = cosine_sim.shape[0]
    # Drop self-similarities; the mask must live on the same device as
    # cosine_sim, otherwise masked_select fails when running on GPU.
    off_diag = ~torch.eye(n, dtype=torch.bool, device=cosine_sim.device)
    dis = cosine_sim.masked_select(off_diag).view(n, n - 1)
    # Apply temperature, then exp, to every element.
    dis = torch.exp(dis / temp)
    cosine_sim = torch.exp(cosine_sim / temp)
    # Denominator of each anchor's softmax: sum over all k != i.
    row_sum = torch.sum(dis, -1)
    unique_labels, inverse_indices, unique_label_counts = torch.unique(
        label, sorted=False, return_inverse=True, return_counts=True)
    # Outer sum over anchors i.
    contrastive_loss = torch.tensor(0.0, dtype=embedding.dtype, device=embedding.device)
    for i in range(n):
        # Number of positives for anchor i (same label, excluding i itself).
        n_i = unique_label_counts[inverse_indices[i]] - 1
        inner_sum = torch.tensor(0.0, dtype=embedding.dtype, device=embedding.device)
        # Inner sum over the positives j of anchor i.
        for j in range(n):
            if label[i] == label[j] and i != j:
                inner_sum = inner_sum + torch.log(cosine_sim[i][j] / row_sum[i])
        # Skip anchors that have no positive in the batch.
        if n_i != 0:
            contrastive_loss += inner_sum / (-n_i)
    return contrastive_loss
```
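Since the nested loops are the bottleneck, here's a loop-free sketch of the same computation. This is my own rewrite, not the repo's code: `contrastive_loss_vectorized` is a name I made up, and it uses `logsumexp` instead of explicit exp/sum/log, so verify it against the loop version before relying on it.

```python
def contrastive_loss_vectorized(temp, embedding, label):
    """Loop-free equivalent of contrastive_loss (up to numerical precision)."""
    cosine_sim = sim_matrix(embedding, embedding) / temp
    n = cosine_sim.shape[0]
    eye = torch.eye(n, dtype=torch.bool, device=cosine_sim.device)
    # log of the denominator sum_{k != i} exp(sim_ik / temp), computed
    # stably with logsumexp after masking out the diagonal.
    log_denom = torch.logsumexp(
        cosine_sim.masked_fill(eye, float('-inf')), dim=-1, keepdim=True)
    log_prob = cosine_sim - log_denom
    # Positives: same label, excluding the anchor itself.
    pos_mask = (label[:, None] == label[None, :]) & ~eye
    n_pos = pos_mask.sum(dim=-1)
    # Sum of log-probabilities over each anchor's positives.
    pos_log_prob = (log_prob * pos_mask).sum(dim=-1)
    # Average over positives and negate, skipping anchors with no positive.
    valid = n_pos > 0
    return -(pos_log_prob[valid] / n_pos[valid]).sum()
```

The logsumexp formulation also avoids materialising the large exponentials that can overflow at low temperatures.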
```python
# Compute loss
if scl == True:
    cross_loss = loss_fn(logits, b_labels)
    contrastive_l = contrastive_loss(tem, hiden_state.cpu().detach().numpy(), b_labels)
    loss = (lam * contrastive_l) + (1 - lam) * (cross_loss)
if scl == False:
    loss = loss_fn(logits, b_labels)
```
When calculating contrastive_loss, you use `detach` (and convert to NumPy). Can the gradients still be updated?
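No: `detach()` (and the subsequent `.numpy()` conversion) removes the tensor from the autograd graph, so no gradient from the contrastive term ever reaches the encoder; that is exactly what the PyTorch port above avoids. A minimal sketch of the effect, with toy tensors, purely for illustration:

```python
import torch

x = torch.randn(4, 3, requires_grad=True)

# Plain PyTorch computation: gradients flow back to x.
loss = (x ** 2).sum()
loss.backward()
print(x.grad is not None)  # True

# Detached computation: the graph is cut, so backward() is impossible.
detached_loss = (x.detach() ** 2).sum()
print(detached_loss.requires_grad)  # False -> backward() would raise an error
```

With the port above, the round-trip can simply be dropped, e.g. `contrastive_l = contrastive_loss(tem, hiden_state, b_labels)` (reusing the variable names from the quoted snippet).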