UKPLab / sentence-transformers

State-of-the-Art Text Embeddings
https://www.sbert.net
Apache License 2.0

Question about MultipleNegativesRankingLoss and gradient accumulation steps #2916

Open DogitoErgoSum opened 2 months ago

DogitoErgoSum commented 2 months ago

How does the MultipleNegativesRankingLoss function when used with gradient accumulation steps?

According to the docs

For each a_i, it uses all other p_j as negative samples, i.e., for a_i, we have 1 positive example (p_i) and n-1 negative examples (p_j). It then minimizes the negative log-likelihood for softmax-normalized scores.

Are the negatives from other steps used (during accumulation), or are only the negatives from the samples in the current batch (per_device_train_batch_size) used?
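
For reference, the quoted behavior boils down to roughly the following simplified sketch (illustrative only, not the library's exact implementation):

```python
import torch
import torch.nn.functional as F

# Simplified sketch of the in-batch negative setup described in the docs quote above.
def mnrl_sketch(anchors: torch.Tensor, positives: torch.Tensor, scale: float = 20.0) -> torch.Tensor:
    # Similarity between every anchor and every positive in the current batch: shape (batch, batch).
    scores = F.cosine_similarity(anchors.unsqueeze(1), positives.unsqueeze(0), dim=-1) * scale
    # For anchor i, positive i is the "correct class"; all other positives in the batch act as negatives.
    labels = torch.arange(anchors.size(0), device=anchors.device)
    return F.cross_entropy(scores, labels)
```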

tomaarsen commented 2 months ago

Hello!

Great question! It's the latter: only the negatives from the samples in the current batch, i.e. the per_device_train_batch_size samples, are used. Gradient accumulation does not enlarge the pool of in-batch negatives, so it does not bring the performance benefit that a genuinely larger batch size would for in-batch negative losses.

For that, I would recommend using the Cached losses, such as CachedMultipleNegativesRankingLoss. In short, this loss is equivalent to MultipleNegativesRankingLoss, but it cleverly uses caching and mini-batches to reach a very high per_device_train_batch_size while memory usage stays constant, determined by the mini-batch size. For example, you can use CachedMultipleNegativesRankingLoss with a per_device_train_batch_size of 4096 and a mini-batch size of 64, and you'll get the same memory usage as MultipleNegativesRankingLoss with a per_device_train_batch_size of 64. You'll get a stronger training signal, at the cost of some training speed overhead (usually around 20%).
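
As a small configuration sketch of that setup (model name and output_dir are placeholders):

```python
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss

model = SentenceTransformer("all-MiniLM-L6-v2")  # placeholder model, swap in your own

# mini_batch_size bounds the memory usage; per_device_train_batch_size determines
# how many in-batch negatives each anchor is contrasted against.
loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=64)

args = SentenceTransformerTrainingArguments(
    output_dir="output",                 # placeholder output path
    per_device_train_batch_size=4096,
)
```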

DogitoErgoSum commented 2 months ago

Thank you for the fast answer! I will try the cached version.

DogitoErgoSum commented 2 months ago

Last question, how does BatchSamplers.NO_DUPLICATES work with gradient accumulation steps?

tomaarsen commented 2 months ago

The "no duplicates" works on a per-batch level, so with e.g. a per_device_train_batch_size of 16 and a gradient accumulation steps of 4, then you'll get 4 batches per loss propagation where each batch does not have duplicate samples in them. With other words, no issues due to duplicates. There's no "cross-batch communication" when doing gradient accumulation other than that the losses from each batch get added together.

If you instead use CachedMNRL with "no duplicates" and e.g. a per_device_train_batch_size of 64 and a mini-batch size of 16, then you will get just 1 batch per loss propagation. Duplicates are also avoided within this batch, so there are no issues here either.

For context, for those wondering why duplicates can be problematic for in-batch negative losses: if you have e.g. question-answer pairs, and the answer to an unrelated question Y happens to be identical to the answer to question X, then that answer text is treated as both a positive and a negative for question X, negating the usefulness of this sample.
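
As a configuration sketch of the first scenario above (values mirror the example; output_dir is a placeholder):

```python
from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers

args = SentenceTransformerTrainingArguments(
    output_dir="output",                        # placeholder output path
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,              # 4 duplicate-free batches of 16 per optimizer step
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # each batch is sampled without duplicate texts
)
```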

Does that clear it up?

DogitoErgoSum commented 2 months ago

Does that clear it up?

Yes. This raises another question: does "no duplicates" check for repeated anchors or positives?

DogitoErgoSum commented 2 months ago

And suppose I use per_device_train_batch_size = size of the training data. Will "no duplicates" delete duplicates, or split the batch into N batches with no duplicates in each?

DogitoErgoSum commented 2 months ago

Sorry for the question spam. If we use triplets instead of anchor-positive pairs, does the following still happen?

For each a_i, it uses all other p_j as negative samples, i.e., for a_i, we have 1 positive example (p_i) and n-1 negative examples (p_j). It then minimizes the negative log-likelihood for softmax-normalized scores.

pesuchin commented 2 months ago

Hello!

The following code section ensures that there are no duplicates among anchor, positive, and negative: https://github.com/UKPLab/sentence-transformers/blob/0a32ec8445ef46b2b5d4f81af4931e293d42623f/sentence_transformers/sampler.py#L146-L151

When using anchor, positive, and negative instead of anchor-positive pairs, sample_values would be {anchor, positive, negative}, and a duplication check is performed with sample_values & batch_values. Therefore, if any of a sample's texts already appears in the current batch, that sample is resampled for a later batch.

To illustrate with a specific example: in the following case, sample_values & batch_values would result in {"positive1"}, indicating a duplicate, so the sample would be resampled:

batch_values = {"anchor1", "positive1", "negative1", "anchor2", "positive2", "negative2"}
sample_values = {"anchor3", "positive1", "negative3"}
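
Expressed as a small runnable sketch of that same check (simplified, not the sampler's exact code):

```python
# Texts already placed in the batch, and the candidate sample's texts.
batch_values = {"anchor1", "positive1", "negative1", "anchor2", "positive2", "negative2"}
sample_values = {"anchor3", "positive1", "negative3"}

if sample_values & batch_values:
    # Overlap found ({"positive1"}): the sample is not added to this batch
    # and is considered again for a later batch.
    print("duplicate detected:", sample_values & batch_values)
```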

In this way, it guarantees that there are no duplicates across any of the anchor, positive, and negative texts. Therefore, I believe the answer to the following question is yes:

This raises another question: does "no duplicates" check for repeated anchors or positives?

I also think the answer to the following question is yes:

If we use triplets instead of anchor-positive pairs, does the following still happen?
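
For completeness, a hedged sketch of how the quoted in-batch behavior extends to (anchor, positive, negative) triplets: the provided negatives are added to the pool of candidates each anchor is scored against (simplified, not the library's exact code):

```python
import torch
import torch.nn.functional as F

# Sketch: with triplets, each anchor is scored against all in-batch positives
# plus all provided negatives; positive i remains the single correct candidate.
def mnrl_triplet_sketch(anchors, positives, negatives, scale: float = 20.0) -> torch.Tensor:
    candidates = torch.cat([positives, negatives], dim=0)              # (2 * batch, dim)
    scores = F.cosine_similarity(anchors.unsqueeze(1), candidates.unsqueeze(0), dim=-1) * scale
    labels = torch.arange(anchors.size(0), device=anchors.device)      # index of positive i
    return F.cross_entropy(scores, labels)
```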