UKPLab / sentence-transformers

State-of-the-Art Text Embeddings
Apache License 2.0

Multi-dataset: An error occurs after a few steps #2816

Closed lanx7 closed 1 month ago

lanx7 commented 2 months ago

I am trying to train a sentence embedding model with a set of triples and a set of pairs, with 1499 and 566 examples respectively. Following the multi-dataset guide, I constructed the DatasetDict as shown below. When I start training, an error occurs after 7 of 75 steps have completed. Do you have any hints that might help?

=== Code:

train_loss = losses.CachedMultipleNegativesRankingLoss(model, mini_batch_size=args.mini_batch)

train_dataset = datasets.Dataset.from_list(data)
train_dataset2 = datasets.Dataset.from_list(data2)

train_data_all = {
    'train1': train_dataset,
    'train2': train_dataset2,
}
losses = {
    'train1': train_loss,
    'train2': train_loss,
}
trainer = SentenceTransformerTrainer(
    evaluator=evaluation.SequentialEvaluator(evaluators, main_score_function=lambda scores: np.mean(scores)),

==== Error Message ===

  File "/workspace/program/miniconda/envs/etest3/lib/python3.11/site-packages/sentence_transformers/", line 211, in __iter__
    yield [idx + sample_offset for idx in next(batch_samplers[dataset_idx])]

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/workspace/dev/etest/etest/etrain/", line 240, in <module>
  File "/workspace/program/miniconda/envs/etest3/lib/python3.11/site-packages/transformers/", line 1932, in train
    return inner_training_loop(
  File "/workspace/program/miniconda/envs/etest3/lib/python3.11/site-packages/transformers/", line 2230, in _inner_training_loop
    for step, inputs in enumerate(epoch_iterator):
  File "/workspace/program/miniconda/envs/etest3/lib/python3.11/site-packages/accelerate/", line 464, in __iter__
    next_batch = next(dataloader_iter)
  File "/workspace/program/miniconda/envs/etest3/lib/python3.11/site-packages/torch/utils/data/", line 634, in __next__
    data = self._next_data()
  File "/workspace/program/miniconda/envs/etest3/lib/python3.11/site-packages/torch/utils/data/", line 677, in _next_data
    index = self._next_index()  # may raise StopIteration
  File "/workspace/program/miniconda/envs/etest3/lib/python3.11/site-packages/torch/utils/data/", line 624, in _next_index
    return next(self._sampler_iter)  # may raise StopIteration
RuntimeError: generator raised StopIteration

tomaarsen commented 2 months ago


I've tried to reproduce this with this script:

import datasets
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    losses,
)

model = SentenceTransformer("all-MiniLM-L6-v2")

train_loss = losses.CachedMultipleNegativesRankingLoss(model, mini_batch_size=32)

data = [{"anchor": "bla", "positive": "bla"}] * 1499
data2 = [{"anchor": "bla", "negative": "bla"}] * 566

train_dataset = datasets.Dataset.from_list(data)
train_dataset2 = datasets.Dataset.from_list(data2)

train_data_all = {
    'train1': train_dataset,
    'train2': train_dataset2,
}
losses = {
    'train1': train_loss,
    'train2': train_loss,
}
args = SentenceTransformerTrainingArguments(
trainer = SentenceTransformerTrainer(

But this one seems to run well. Could you share some of your args options? In particular, per_device_train_batch_size, dataloader_drop_last, dataloader_num_workers, dataloader_persistent_workers, dataloader_pin_memory, dataloader_prefetch_factor. If you don't specify some of these, you can just let me know that, so I know that you're still on the defaults.

Apologies for the inconvenience. To me this sounds like a bug on the side of my samplers.

lanx7 commented 2 months ago

Hi Aarsen, Thank you for the prompt response. Attached are the parameters used. dataloader_drop_last, dataloader_num_workers, dataloader_persistent_workers, dataloader_pin_memory, and dataloader_prefetch_factor are not set.

==

eval_steps = int(len_train * args.eval_ratio / train_batch_size)

ecb = EarlyStoppingCallback(early_stopping_patience=5)
mcb = MyCallback(path=qp_model_save_path, name="result")

args = SentenceTransformerTrainingArguments(
    fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    # Optional tracking/debugging parameters:
    run_name="sbert_sup_train",  # Will be used in W&B if `wandb` is installed

lanx7 commented 2 months ago

I have tested your code and found that 'batch_sampler=BatchSamplers.NO_DUPLICATES' is causing the problem. When I turned that option on, an error occurred at step 0; when I turned it off, no error occurred.

tomaarsen commented 2 months ago

Thanks for sharing! I think I should be able to try and reproduce something with that. I suspect that the no duplicates batch sampler has enough samples for n batches, but due to duplicate samples it ended up only being able to return e.g. n-1 batches. This would result in a StopIteration like you experienced.

I'll try to think of a solid solution for this.
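
The RuntimeError in the traceback can be reproduced without any library code. The following is a minimal sketch, assuming a sampler that was promised more batches than it can deliver; `batches_from`, `samplers`, and `schedule` are hypothetical names for illustration and are not part of sentence-transformers. Calling next() on an exhausted iterator inside a generator raises StopIteration, which Python converts into "RuntimeError: generator raised StopIteration" (PEP 479).

```python
def batches_from(samplers, schedule):
    """Yield one batch per schedule entry, pulled from the matching sampler.

    Hypothetical stand-in for a multi-dataset batch sampler.
    """
    iterators = [iter(sampler) for sampler in samplers]
    for dataset_idx in schedule:
        # If this iterator is already exhausted, next() raises StopIteration
        # inside the generator, and Python re-raises it as RuntimeError.
        yield next(iterators[dataset_idx])


# Dataset 0 was assumed to have 2 batches but can only produce 1.
samplers = [[[0, 1]], [[2, 3], [4, 5]]]
schedule = [0, 1, 0, 1]  # built from the assumed (too high) batch counts

produced = []
error = None
try:
    for batch in batches_from(samplers, schedule):
        produced.append(batch)
except RuntimeError as exc:
    error = str(exc)

print(produced)
print(error)
```

The first two batches are delivered normally; the failure only appears once the schedule asks the short sampler for a batch it does not have.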

lanx7 commented 1 month ago

@tomaarsen Hi Aarsen, Do you have any update on this issue?

tomaarsen commented 1 month ago


Apologies for the delay, I've been recovering from surgery this last month. I haven't yet discovered a solid solution here. The core of the problem is that a DataLoader needs its batch sampler to report a total batch count, i.e. the __len__ of the batch sampler. For the NoDuplicatesBatchSampler I assume that no samples will be discarded (which is a bit naive), so the reported batch count can be 1 higher than the number of batches the sampler can actually produce.

On the other hand, if I pre-compute the batches so I don't have to estimate the batch count, then I have to load all of the data into memory right at the start of training. This is also a problem.
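
To illustrate the mismatch between the reported and the actual batch count, here is a toy sketch. `ToyNoDuplicatesSampler` is a hypothetical class written for this example, not the library's NoDuplicatesBatchSampler, and it assumes incomplete batches are dropped. Its __len__ is an estimate that assumes no sample is ever discarded, while __iter__ drops samples that cannot be placed in a full, duplicate-free batch.

```python
class ToyNoDuplicatesSampler:
    """Hypothetical sampler: no batch may contain the same text twice."""

    def __init__(self, texts, batch_size):
        self.texts = texts
        self.batch_size = batch_size

    def __len__(self):
        # Estimate only: assumes every sample ends up in some full batch.
        return len(self.texts) // self.batch_size

    def __iter__(self):
        remaining = list(range(len(self.texts)))
        while remaining:
            batch, seen = [], set()
            for idx in remaining:
                if self.texts[idx] in seen:
                    continue  # would duplicate a text already in this batch
                batch.append(idx)
                seen.add(self.texts[idx])
                if len(batch) == self.batch_size:
                    break
            chosen = set(batch)
            remaining = [i for i in remaining if i not in chosen]
            if len(batch) == self.batch_size:
                yield batch
            # Otherwise the incomplete batch is discarded with its samples.


sampler = ToyNoDuplicatesSampler(["a", "a", "a", "a", "b", "c"], batch_size=2)
batches = list(sampler)
print(len(sampler), len(batches))
```

With four copies of "a" and only two other texts, two of the "a" samples can never be paired, so the sampler reports 3 batches but yields only 2.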

A potential fix is to:

  1. Provide a batch count of n by assuming that no samples are discarded
  2. If the actual batch count is < n because samples had to be discarded (since otherwise a batch would contain duplicates), then we just give the first batch again.

In practice, the mismatch between actual and assumed batch count will likely be 1 at maximum, so at most one batch per epoch will be duplicate data. This is technically not ideal, as conceptually 1 epoch shouldn't include the same data multiple times, but it's better than the alternative of a crash.
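
The two steps above could be sketched roughly as follows. This is only an illustration of the idea as described, not code from the library; `pad_with_first_batch` and its parameters are hypothetical names, and repeating the first batch more than once when the gap is larger than 1 is an assumption of this sketch.

```python
def pad_with_first_batch(batches, advertised_len):
    """Yield every batch, then repeat the first one until the count
    matches the advertised (estimated) batch count.

    Hypothetical helper illustrating the proposed workaround.
    """
    first = None
    produced_count = 0
    for batch in batches:
        if first is None:
            first = batch
        yield batch
        produced_count += 1
    # Fewer batches than promised: give the first batch again.
    while first is not None and produced_count < advertised_len:
        yield first
        produced_count += 1


actual_batches = [[0, 4], [1, 5]]  # what the sampler could really produce
padded = list(pad_with_first_batch(actual_batches, advertised_len=3))
print(padded)
```

The consumer now receives exactly as many batches as were advertised, at the cost of seeing the first batch twice in that epoch.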

I think I will try to experiment with this solution.

tomaarsen commented 1 month ago

After some more experimentation: the issue also goes away if dataloader_drop_last=False (which is the default).

tomaarsen commented 1 month ago

I've discovered the root cause, and I've implemented a different fix than the one described above in #2877. Feel free to have a look there.

lanx7 commented 1 month ago

@tomaarsen Thank you for the update.

I'm very glad to hear that you've found a solution. I'll check out #2877.

Thank you.