UKPLab / sentence-transformers

State-of-the-Art Text Embeddings
https://www.sbert.net
Apache License 2.0

Multi-dataset: An error occurs after a few steps #2816

Closed lanx7 closed 1 month ago

lanx7 commented 2 months ago

I am trying to train a sentence embedding model with a set of triplets and a set of pairs, whose sizes are 1499 and 566 respectively. Following the multi-dataset guide, I have constructed the DatasetDict as below, and when I start training, an error occurs after completing 7 out of 75 steps. Do you have any hints that might help?

=== Code ===

train_loss = losses.CachedMultipleNegativesRankingLoss(model, mini_batch_size=args.mini_batch)

train_dataset = datasets.Dataset.from_list(data)
train_dataset2 = datasets.Dataset.from_list(data2)

train_data_all = {
    'train1': train_dataset, 
    'train2': train_dataset2, 
}
losses = {
    'train1': train_loss,
    'train2': train_loss
}
....
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_data_all,
    #eval_dataset=dataset,
    loss=losses,
    evaluator=evaluation.SequentialEvaluator(evaluators, main_score_function=lambda scores: np.mean(scores)),
    callbacks=[ecb,mcb]
)

=== Error Message ===

  File "/workspace/program/miniconda/envs/etest3/lib/python3.11/site-packages/sentence_transformers/sampler.py", line 211, in __iter__
    yield [idx + sample_offset for idx in next(batch_samplers[dataset_idx])]
                                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
StopIteration

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/workspace/dev/etest/etest/etrain/sbert3_train_supervised.py", line 240, in <module>
    trainer.train()
  File "/workspace/program/miniconda/envs/etest3/lib/python3.11/site-packages/transformers/trainer.py", line 1932, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/workspace/program/miniconda/envs/etest3/lib/python3.11/site-packages/transformers/trainer.py", line 2230, in _inner_training_loop
    for step, inputs in enumerate(epoch_iterator):
  File "/workspace/program/miniconda/envs/etest3/lib/python3.11/site-packages/accelerate/data_loader.py", line 464, in __iter__
    next_batch = next(dataloader_iter)
                 ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/program/miniconda/envs/etest3/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 634, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/workspace/program/miniconda/envs/etest3/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 677, in _next_data
    index = self._next_index()  # may raise StopIteration
            ^^^^^^^^^^^^^^^^^^
  File "/workspace/program/miniconda/envs/etest3/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 624, in _next_index
    return next(self._sampler_iter)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: generator raised StopIteration
tomaarsen commented 2 months ago

Hello!

I've tried to reproduce this with this script:

import datasets
from sentence_transformers import (
    SentenceTransformerTrainingArguments,
    SentenceTransformer,
    SentenceTransformerTrainer,
    losses,
)
model = SentenceTransformer("all-MiniLM-L6-v2")

train_loss = losses.CachedMultipleNegativesRankingLoss(model, mini_batch_size=32)

data = [{"anchor": "bla", "positive": "bla"}] * 1499
data2 = [{"anchor": "bla", "negative": "bla"}] * 566

train_dataset = datasets.Dataset.from_list(data)
train_dataset2 = datasets.Dataset.from_list(data2)

train_data_all = {
    'train1': train_dataset, 
    'train2': train_dataset2, 
}
losses = {
    'train1': train_loss,
    'train2': train_loss
}
args = SentenceTransformerTrainingArguments(
    "tmp_trainer",
    per_device_train_batch_size=64,
)
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_data_all,
    loss=losses,
)
trainer.train()

But this one seems to run fine. Could you share some of your args options? In particular: per_device_train_batch_size, dataloader_drop_last, dataloader_num_workers, dataloader_persistent_workers, dataloader_pin_memory, and dataloader_prefetch_factor. If you don't specify some of these, just let me know, so I know that you're still on the defaults.

Apologies for the inconvenience. To me this sounds like a bug on the side of my samplers.

lanx7 commented 2 months ago

Hi Aarsen, thank you for the prompt response. Below are the parameters used. dataloader_drop_last, dataloader_num_workers, dataloader_persistent_workers, dataloader_pin_memory, and dataloader_prefetch_factor are not set.

eval_steps = int(len_train * args.eval_ratio / train_batch_size)

ecb = EarlyStoppingCallback(early_stopping_patience=5)
mcb = MyCallback(path=qp_model_save_path, name="result")

args = SentenceTransformerTrainingArguments(
    output_dir=qp_model_save_path,
    num_train_epochs=5,
    per_device_train_batch_size=128,
    learning_rate=1e-05,
    warmup_ratio=0.1,
    fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES, 
    multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=eval_steps,
    save_strategy="steps",
    save_steps=eval_steps,
    save_total_limit=3,
    logging_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="eval_sequential_score",
    run_name="sbert_sup_train",  # Will be used in W&B if `wandb` is installed
    dataloader_drop_last=True, 
)
lanx7 commented 2 months ago

I have tested your code and found that 'batch_sampler=BatchSamplers.NO_DUPLICATES' is causing the problem. When I turned on that option, an error occurred at step 0, but when I turned it off, no error occurred.
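
Concretely, the only change I made to your script was in the training arguments, roughly:

from sentence_transformers.training_args import BatchSamplers

args = SentenceTransformerTrainingArguments(
    "tmp_trainer",
    per_device_train_batch_size=64,
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # turning this on triggers the error at step 0
)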

tomaarsen commented 2 months ago

Thanks for sharing! I should be able to reproduce something with that. I suspect that the no-duplicates batch sampler has enough samples for n batches, but because of duplicate samples it can only actually return e.g. n-1 batches. That would result in a StopIteration like the one you experienced.

I'll try to think of a solid solution for this.

lanx7 commented 1 month ago

@tomaarsen Hi Aarsen, Do you have any update on this issue?

tomaarsen commented 1 month ago

Hello!

Apologies for the delay, I've been recovering from surgery this last month. I haven't yet discovered a solid solution here. The core of the problem is that a DataLoader needs its batch sampler to report a total batch count, i.e. the __len__ of the batch sampler. For the NoDuplicatesBatchSampler I assume that no samples will be discarded (which is a bit naive), so the reported batch count is 1 higher than the actual batch count that the sampler can produce.

On the other hand, if I pre-compute the batches so I don't have to estimate the batch count, then I have to load all of the data into memory right at the start of training. This is also a problem.
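
To make that mismatch concrete, here is a toy batch sampler (not the actual NoDuplicatesBatchSampler, just the same failure mode in miniature): its __len__ optimistically assumes that no sample is discarded, while its __iter__ skips in-batch duplicates and drops the final incomplete batch.

class ToyNoDuplicatesBatchSampler:
    def __init__(self, values, batch_size):
        self.values = values
        self.batch_size = batch_size

    def __len__(self):
        # Optimistic estimate: assumes every sample ends up in some batch.
        return len(self.values) // self.batch_size

    def __iter__(self):
        batch, seen = [], set()
        for idx, value in enumerate(self.values):
            if value in seen:
                continue  # would duplicate a value already in this batch, so discard it
            batch.append(idx)
            seen.add(value)
            if len(batch) == self.batch_size:
                yield batch
                batch, seen = [], set()
        # The leftover incomplete batch is dropped, like dataloader_drop_last=True.


sampler = ToyNoDuplicatesBatchSampler(["a", "a", "b", "c", "d", "e"], batch_size=2)
print(len(sampler))        # 3 batches promised by __len__
print(len(list(sampler)))  # only 2 batches can actually be produced

# The multi-dataset batch sampler trusts len() and calls next() that many times;
# in the library that next() happens inside a generator, so Python (PEP 479)
# turns the StopIteration into the "RuntimeError: generator raised StopIteration"
# from the traceback above.
it = iter(sampler)
for _ in range(len(sampler)):
    next(it)  # StopIteration on the third call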

A potential fix is to:

  1. Provide a batch count of n by assuming that no samples are discarded
  2. If the actual batch count is < n because samples had to be discarded (as otherwise there's duplicates in a batch), then we just give the first batch again.

In practice, the mismatch between the actual and assumed batch count will likely be 1 at most, so roughly one batch per epoch would be duplicated data. This is technically not ideal, as conceptually one epoch shouldn't include the same data multiple times, but it's better than the alternative of a crash. A rough sketch of the idea follows below.
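
Applied to the toy sampler from the snippet above, the idea would look roughly like this (just a sketch, not the actual implementation):

# Sketch of option 2: if the sampler runs out of batches before reaching the
# count promised by __len__, re-yield the first batch instead of letting a
# StopIteration escape into the DataLoader.
def iter_with_padding(batch_sampler):
    first_batch = None
    produced = 0
    for batch in batch_sampler:
        if first_batch is None:
            first_batch = batch
        produced += 1
        yield batch
    # Pad up to the promised length by repeating the first batch.
    while first_batch is not None and produced < len(batch_sampler):
        produced += 1
        yield first_batch


padded = list(iter_with_padding(sampler))  # `sampler` from the previous snippet
assert len(padded) == len(sampler)         # no StopIteration mid-epoch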

I think I will try to experiment with this solution.

tomaarsen commented 1 month ago

After some more experimentation: the issue also goes away if dataloader_drop_last=False (which is the default).
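
In the meantime, keeping dataloader_drop_last at its default should work around the crash. Based on the arguments you shared earlier, something along these lines ought to avoid it (untested on your exact setup):

args = SentenceTransformerTrainingArguments(
    output_dir=qp_model_save_path,
    per_device_train_batch_size=128,
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
    dataloader_drop_last=False,  # the default; avoids the StopIteration for now
    # ... keep the rest of your options unchanged
)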

tomaarsen commented 1 month ago

I've discovered the root cause, and I've implemented a different fix than the one described above in #2877. Feel free to have a look there.

lanx7 commented 1 month ago

@tomaarsen Thank you for the update.

I'm very glad to hear that you've found a solution. I'll check out #2877.

Thank you.