UKPLab / sentence-transformers

State-of-the-Art Text Embeddings
https://www.sbert.net
Apache License 2.0

Handling Explicit Negative Examples in `MultipleNegativesRankingLoss` #2260

Open Alec-Stashevsky opened 1 year ago

Alec-Stashevsky commented 1 year ago

Dear maintainers,

Thank you for the great work on the Sentence Transformers library.

I am writing to ask for clarification regarding the `MultipleNegativesRankingLoss` class. The class documentation mentions that we can provide explicit hard negatives per anchor-positive pair by structuring the data as (a_1, p_1, n_1), (a_2, p_2, n_2), etc.
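
For context, here is a minimal sketch of how I am constructing such triplets, using the standard `InputExample` / `model.fit` training API (the example strings and model name are just placeholders):

from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

# Each InputExample carries (anchor, positive, hard negative).
train_examples = [
    InputExample(texts=[
        "What is the capital of France?",      # anchor
        "Paris is the capital of France.",     # positive
        "Berlin is the capital of Germany.",   # hard negative
    ]),
    InputExample(texts=[
        "How tall is Mount Everest?",
        "Mount Everest is 8,849 m tall.",
        "K2 is the second-highest mountain.",
    ]),
]

model = SentenceTransformer("all-MiniLM-L6-v2")
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
train_loss = losses.MultipleNegativesRankingLoss(model)
model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1)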

However, after inspecting the source code, specifically the `forward` method, I couldn't find an explicit mechanism that handles these additional negative examples:

def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
    reps = [self.model(sentence_feature)['sentence_embedding'] for sentence_feature in sentence_features]
    embeddings_a = reps[0]
    embeddings_b = torch.cat(reps[1:])
    scores = self.similarity_fct(embeddings_a, embeddings_b) * self.scale
    labels = torch.tensor(range(len(scores)), dtype=torch.long, device=scores.device)  # Example a[i] should match with b[i]
    return self.cross_entropy_loss(scores, labels)

From what I understand, the method treats all sentence pairs in the batch as positive pairs (a_i, p_i) and uses all p_j (for j ≠ i) as negative examples for anchor a_i. The explicit negative examples mentioned in the class documentation do not seem to be treated differently.

If explicit negative examples are intended to be supported, I'm thinking we might need to modify the `forward` method. For instance, something like:

def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
    reps = [self.model(sentence_feature)['sentence_embedding'] for sentence_feature in sentence_features]
    embeddings_a = reps[0]  # anchors, shape (bs, dim)
    if len(reps) == 3:  # check if triplets (anchor, positive, negative) are provided
        positives, negatives = reps[1], reps[2]
        embeddings_b = torch.cat([positives, negatives])  # shape (2 * bs, dim)
    else:  # pairs (anchor, positive) only
        embeddings_b = reps[1]  # shape (bs, dim)
    scores = self.similarity_fct(embeddings_a, embeddings_b) * self.scale
    labels = torch.tensor(range(len(scores)), dtype=torch.long, device=scores.device)  # Example a[i] should match with b[i]
    return self.cross_entropy_loss(scores, labels)

Can you please provide some insight on this? Is my understanding correct, and would the suggested modification be appropriate?

Any guidance or clarification would be greatly appreciated.

Best regards, Alec

aqx95 commented 1 year ago

The implementation is actually correct. `embeddings_b = torch.cat(reps[1:])` concatenates both the positive and the negative embeddings along `dim=0`. So `embeddings_a` has shape `(bs, dim)` and `embeddings_b` has shape `(2 * bs, dim)`. After running `similarity_fct`, the score matrix has shape `(bs, 2 * bs)`: with a batch size of 64, columns 0-63 hold the similarity of each positive text in the batch to the queries, and columns 64-127 hold the similarity of each negative text. For row 0, for example, the only positive is `scores[0][0]`; every other column, i.e. the remaining in-batch positives p_i (for i ≠ 0) and all of the hard negatives, is treated as a negative. The generated ground-truth label for row 0 is accordingly 0, which points at `scores[0][0]`.
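
Here is a quick sketch to illustrate the shapes and labels, assuming a batch size of 4 and random tensors standing in for real model outputs:

import torch

bs, dim = 4, 8
anchors   = torch.randn(bs, dim)   # reps[0]
positives = torch.randn(bs, dim)   # reps[1]
negatives = torch.randn(bs, dim)   # reps[2]

embeddings_a = anchors                            # (4, 8)  == (bs, dim)
embeddings_b = torch.cat([positives, negatives])  # (8, 8)  == (2 * bs, dim)

# Cosine similarity of every anchor against every candidate,
# standing in for similarity_fct.
a = torch.nn.functional.normalize(embeddings_a, dim=-1)
b = torch.nn.functional.normalize(embeddings_b, dim=-1)
scores = a @ b.T                                  # (4, 8)  == (bs, 2 * bs)

# Row i's correct column is i (its own positive); the other positives
# and all columns bs..2*bs-1 (the hard negatives) act as negatives.
labels = torch.arange(bs)                         # tensor([0, 1, 2, 3])
loss = torch.nn.functional.cross_entropy(scores, labels)
print(scores.shape, labels, loss)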