xlang-ai / instructor-embedding

[ACL 2023] One Embedder, Any Task: Instruction-Finetuned Text Embeddings
Apache License 2.0

Cosine Similarity of Anchor and Negative is not taken into consideration in Loss calculation #94

Closed (ashokrajab closed this 8 months ago)

ashokrajab commented 8 months ago

In train.py, inside def compute_loss(), the cosine similarity of the query vs. its positive is first appended to all_scores, followed by the cosine similarities of the query vs. the negatives in the batch. This is done for every triplet in the batch, so all_scores ends up with shape (batch_size, batch_size + 1). Then, when the loss is calculated as loss = nn.CrossEntropyLoss()(all_scores, labels), the labels are all zeros.
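Roughly, the construction looks like the following minimal sketch (variable names are illustrative, not the exact code in train.py):

import torch
import torch.nn as nn

# For each query i, the score against its own positive sits at column 0,
# followed by the scores against every hard negative in the batch.
batch_size, dim = 4, 8
q = torch.randn(batch_size, dim)   # query embeddings
p = torch.randn(batch_size, dim)   # positive embeddings
n = torch.randn(batch_size, dim)   # hard-negative embeddings

rows = []
for i in range(batch_size):
    pos_score = torch.cosine_similarity(q[i:i+1], p[i:i+1])   # shape (1,)
    neg_scores = torch.cosine_similarity(q[i:i+1], n)         # shape (batch_size,)
    rows.append(torch.cat([pos_score, neg_scores]))            # shape (batch_size + 1,)

all_scores = torch.stack(rows)                        # (batch_size, batch_size + 1)
labels = torch.zeros(batch_size, dtype=torch.long)    # positive is always column 0
loss = nn.CrossEntropyLoss()(all_scores, labels)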

This means only the cosine similarity at index 0 of all_scores is taken into consideration in the loss computation (in other words, loss = -log(all_scores[0])). The cosine similarities between the query and the negatives do not contribute to the loss.

In my opinion, the cosine similarity values should instead be passed through a sigmoid function and a Binary Cross Entropy loss should be used, with the target depending on whether the similarity is for query vs. positive or query vs. negative.
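For example, a hypothetical sketch of that alternative (not a drop-in patch for train.py):

import torch
import torch.nn as nn

# Score each (query, positive) pair with target 1 and each (query, negative)
# pair with target 0, then apply sigmoid + binary cross-entropy per pair.
batch_size, dim = 4, 8
q = torch.randn(batch_size, dim)   # query embeddings
p = torch.randn(batch_size, dim)   # positive embeddings
n = torch.randn(batch_size, dim)   # hard-negative embeddings

pos_scores = torch.cosine_similarity(q, p)   # (batch_size,)
neg_scores = torch.cosine_similarity(q, n)   # (batch_size,)

scores = torch.cat([pos_scores, neg_scores])
targets = torch.cat([torch.ones_like(pos_scores), torch.zeros_like(neg_scores)])

# BCEWithLogitsLoss applies the sigmoid internally.
loss = nn.BCEWithLogitsLoss()(scores, targets)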

@Harry-hash @hongjin-su Could you please shed some light on this?

hongjin-su commented 8 months ago

Hi, thanks a lot for your interest in INSTRUCTOR!

I think nn.CrossEntropyLoss already applies the softmax internally, so the cosine similarities between the query and the negatives are taken into account, according to the documentation at https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html.
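Concretely, with all labels equal to 0, the loss for row i is -log( exp(s_i0) / sum_j exp(s_ij) ), so every negative score s_ij (j > 0) enters the loss through the softmax denominator.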

To further verify, you may try the following script and see that loss and loss1 are always the same:

import torch

# Cross-entropy with all-zero labels: the target class is index 0 for every row.
criterion = torch.nn.CrossEntropyLoss()
logits = torch.randn(4, 5)
labels = torch.zeros(logits.size(0)).long()

loss = criterion(logits, labels)
print(loss)

# Manual computation of the same quantity: -log softmax at index 0, averaged
# over rows. Every logit in a row appears in the softmax denominator, so the
# "negative" logits do contribute to the loss.
loss1 = 0
for logit in logits:
    loss1 += torch.log(torch.exp(logit[0]) / torch.sum(torch.exp(logit)))
loss1 /= logits.size(0)
loss1 = -loss1
print(loss1)
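As a further quick check (a small sketch, not code from train.py), the gradient on the negative logits is nonzero, so they do influence the optimization:

import torch

logits = torch.randn(4, 5, requires_grad=True)
labels = torch.zeros(logits.size(0)).long()
torch.nn.CrossEntropyLoss()(logits, labels).backward()
print(logits.grad[:, 1:])   # gradients on the negative columns are nonzero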