UKPLab / sentence-transformers

State-of-the-Art Text Embeddings
https://www.sbert.net
Apache License 2.0

Replicating all-MiniLM-L6-v2 with SentenceTransformerTrainer #2975

Open claeyzre opened 1 month ago

claeyzre commented 1 month ago

I came across the training script for the sentence-transformers model all-MiniLM-L6-v2. I am trying to replicate it using the SentenceTransformerTrainer, but a few problems arose quickly:

One could think that the recent releases would easily solve this problem by creating a DatasetDict of datasets with affiliated losses, BUT the datasets are extremely large, so in my case I would create IterableDatasets and use the interleave function to mimic the "weight" argument of this script. Unfortunately, the trainer doesn't take a dict of iterable datasets, only Dataset objects. And I cannot create one big IterableDataset, as there will be both pairs and triplets, so the loss would need to change with respect to what's inside the batch (pair/triplet).
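For illustration, the interleaving I had in mind looks roughly like this (file names and probabilities are placeholders), which only works as long as every dataset has the same columns:

from datasets import load_dataset, interleave_datasets

# two (anchor, positive) pair datasets, streamed as IterableDatasets
qa_pairs = load_dataset("json", data_files="gooaq_pairs.jsonl.gz", split="train", streaming=True)
title_body = load_dataset("json", data_files="stackexchange_title_body.jsonl.gz", split="train", streaming=True)

# probabilities would play the role of the "weight" column in the original data config
mixed = interleave_datasets([qa_pairs, title_body], probabilities=[0.7, 0.3], seed=42)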

I ended up doing exactly what is done in this script, and I was wondering if I missed something in the SentenceTransformerTrainer's capabilities to handle iterable datasets with multi-task training.

Thank you very much,

tomaarsen commented 1 month ago

Hello!

Indeed, a DatasetDict of IterableDatasets isn't supported, because training with multiple datasets at once is implemented via a ConcatDataset, but this requires datasets with known lengths.

As for the two losses, the script uses MultipleNegativesRankingLoss, once with pairs and once with triplets. Sadly, this doesn't help us too much, as we still can't create one IterableDataset because the datasets have different columns.

The model was finetuned with 100k steps using a batch size of 1024, so this ends up being 100M samples whereas the entire dataset is 1.17B samples. Also, datasets.Dataset does not load everything into memory. In other words, it might be possible to load the 100M samples (depending on your device) with a dictionary of Dataset objects (or a DatasetDict):

🤗 Datasets uses Arrow for its local caching system. It allows datasets to be backed by an on-disk cache, which is memory-mapped for fast lookup. This architecture allows for large datasets to be used on machines with relatively small device memory.

Have you tried starting a training job with a really large dataset? Even if it's just some synthetic data, e.g. writing {"anchor": "What are pandas?", "positive": "The giant panda, also known as the panda bear or simply panda, is a bear species endemic to China."} to a file 100M times over. If you load it with load_dataset("json", data_files="data.json"), your memory usage might stay manageable.
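Something like this quick sketch (the content is made up, and writing 100M lines will take a while):

import json
from datasets import load_dataset

# write the same synthetic pair 100M times, one JSON object per line
line = json.dumps({
    "anchor": "What are pandas?",
    "positive": "The giant panda, also known as the panda bear or simply panda, is a bear species endemic to China.",
})
with open("data.json", "w") as f:
    for _ in range(100_000_000):
        f.write(line + "\n")

# Arrow-backed and memory-mapped, so this should not need the whole file in RAM
dataset = load_dataset("json", data_files="data.json", split="train")
print(dataset)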

claeyzre commented 1 month ago

Indeed, a DatasetDict of IterableDatasets isn't supported, because training with multiple datasets at once is implemented via a ConcatDataset, but this requires datasets with known lengths.

Interesting, would it be possible to use the number of lines specified in the data_config to overcome this limitation?

The model was finetuned with 100k steps using a batch size of 1024, so this ends up being 100M samples whereas the entire dataset is 1.17B samples.

I didn't have this detail. This would mean I could load all the datasets at once into a DatasetDict and mimic the weight trick with a custom Sampler that follows the "weights" distribution. Do you think that would be feasible? Because even though the dataset is not a problem for RAM, I still want to replicate the "weight" aspect of this training.

Thank you very much

tomaarsen commented 1 month ago

Interesting, would it be possible to use the number of lines specified in the data_config to overcome this limitation?

In theory, yes. You'd have to override the dataloader methods: https://github.com/UKPLab/sentence-transformers/blob/78553270abc74f44c1504db0e29f79591af6b697/sentence_transformers/trainer.py#L567-L657

I didn't have this detail. This would mean I could load all the datasets at once into a DatasetDict and mimic the weight trick with a custom Sampler that follows the "weights" distribution. Do you think that would be feasible?

Yes, I believe so. I would go this route personally. You should be able to override this: https://github.com/UKPLab/sentence-transformers/blob/78553270abc74f44c1504db0e29f79591af6b697/sentence_transformers/trainer.py#L530-L565 which is in charge of sampling batches from the ConcatDataset, a concatenation of your individual datasets. You also get a list of batch samplers for each of the individual datasets, so all you have to do is use some weight to determine which batch sampler to sample from, and then add the dataset-specific offset in the larger ConcatDataset.

For example, the RoundRobinBatchSampler: https://github.com/UKPLab/sentence-transformers/blob/78553270abc74f44c1504db0e29f79591af6b697/sentence_transformers/sampler.py#L244-L258

This one first builds a sample_offsets list that maps each dataset id to where that dataset starts in the ConcatDataset. Afterwards, it cyclically loops through the batch samplers until one of them is exhausted.
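As a toy illustration of that offset bookkeeping (lengths made up):

from itertools import accumulate

dataset_lengths = [1000, 250, 4000]   # len() of each individual dataset

# sample_offsets[i] = index at which dataset i starts inside the ConcatDataset
sample_offsets = [0] + list(accumulate(dataset_lengths))[:-1]
print(sample_offsets)                 # [0, 1000, 1250]

# a batch of local indices from dataset 2 maps into the ConcatDataset like this:
local_batch = [3, 17, 42]
global_batch = [sample_offsets[2] + i for i in local_batch]
print(global_batch)                   # [1253, 1267, 1292]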

The ProportionalBatchSampler is a bit more complex, but it's rather close to what you're looking for: https://github.com/UKPLab/sentence-transformers/blob/78553270abc74f44c1504db0e29f79591af6b697/sentence_transformers/sampler.py#L288-L304

It also computes sample_offsets, but rather than cycling through the batch samplers until one is empty, we create a list of dataset IDs based on how many batches each dataset has. You can very easily update this to create a list based on num_batches * weight instead of just num_batches:

dataset_indices = [idx for idx, length in enumerate(num_batches) for _ in range(length) for _ in range(weight_mapping[idx])]

I think that's pretty much the only change you need.
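To make that concrete, a toy example of what that comprehension produces (numbers made up):

num_batches = [3, 2]            # batches available per dataset
weight_mapping = {0: 2, 1: 1}   # dataset id -> weight

dataset_indices = [
    idx
    for idx, length in enumerate(num_batches)
    for _ in range(length)
    for _ in range(weight_mapping[idx])
]
print(dataset_indices)          # [0, 0, 0, 0, 0, 0, 1, 1]: dataset 0 gets drawn 6 times, dataset 1 twice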

Just like the original training script, this might be a smidge naive, because some of the weights in there are as high as ~247, resulting in a massive number of steps per epoch, but it should be fine if you set max_steps.

Feel free to keep me in the loop!

claeyzre commented 1 month ago

which is in charge of sampling batches from the ConcatDataset, a concatenation of your individual datasets. You also get a list of batch samplers for each of the individual datasets, so all you have to do is use some weight to determine which batch sampler to sample from, and then add the dataset-specific offset in the larger ConcatDataset.

Do you mean duplicating the batch samplers for each of the individual datasets? Because if I do not, even though I iterate over the correct number of dataset indices in the batch sampler, I will reach a StopIteration from an individual dataset's batch sampler and end up with the wrong number of batches in the end. I guess a good solution would be to adapt the number of batch samplers before entering the WeightedBatchSampler class's __iter__, so here: https://github.com/UKPLab/sentence-transformers/blob/78553270abc74f44c1504db0e29f79591af6b697/sentence_transformers/trainer.py#L607

Which is what I am doing, but now I find myself with another problem: how can I ensure that I will not have duplicates? From what I understood of the original train_script.py, it's not really solved there either, no?

Another issue is that I would have to introduce a new train_multi_dataset_batch_sampler option in the args, as it makes sense to weight the datasets for training, but not so much for testing or eval.

tomaarsen commented 1 month ago

I will reach a StopIteration from an individual dataset's batch sampler and end up with the wrong number of batches in the end.

Oh, oops, you're totally right. My idea was a bit too simple. I think the easiest fix is to "reuse" the same batch sampler. E.g.:

# create weights[idx] independent iterators over each dataset's batch sampler
batch_samplers = [
    iter(sampler)
    for idx, sampler in enumerate(self.batch_samplers)
    for _ in range(weights[idx])
]

You'll also have to change the sample_offsets, but now you should be able to sample based on weight.
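Putting the pieces together, an untested sketch of what such a weighted sampler could look like. It is a standalone class modeled on the ProportionalBatchSampler linked above; the weights argument, the missing epoch/seed handling, and the wiring into get_multi_dataset_batch_sampler are all mine and would need adjusting:

from itertools import accumulate
from torch.utils.data import ConcatDataset, SubsetRandomSampler

class WeightedBatchSampler:
    """Draws batches from dataset i proportionally to its number of batches times weights[i]."""

    def __init__(self, dataset: ConcatDataset, batch_samplers, weights, generator=None):
        self.dataset = dataset
        self.batch_samplers = batch_samplers
        self.weights = weights      # one integer weight per dataset, e.g. from data_config.json
        self.generator = generator

    def __iter__(self):
        # index at which each underlying dataset starts inside the ConcatDataset
        dataset_offsets = [0] + list(accumulate(len(d) for d in self.dataset.datasets))[:-1]

        # one *independent* iterator per "copy" of a dataset: a weight of 3 means
        # three passes over that dataset's batches within a single epoch
        copies = []        # (dataset_idx, batch_iterator) per copy
        copy_indices = []  # copy index repeated once per batch that copy will yield
        for dataset_idx, sampler in enumerate(self.batch_samplers):
            for _ in range(self.weights[dataset_idx]):
                copies.append((dataset_idx, iter(sampler)))
                copy_indices.extend([len(copies) - 1] * len(sampler))

        # pick which copy the next batch comes from, at random, proportional to its batch count
        for copy_idx in SubsetRandomSampler(copy_indices, generator=self.generator):
            dataset_idx, batch_iterator = copies[copy_idx]
            try:
                # shift the batch's local indices by the dataset's offset in the ConcatDataset
                yield [dataset_offsets[dataset_idx] + i for i in next(batch_iterator)]
            except StopIteration:
                continue

    def __len__(self):
        return sum(len(sampler) * weight for sampler, weight in zip(self.batch_samplers, self.weights))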

As for the other issues:

  1. You will still have duplicates with this approach. The original script also had duplicates, right? I think there's no way around that, except to "shrink" datasets based on their weights rather than "duplicating" them. However, you can use args.batch_sampler=BatchSamplers.NO_DUPLICATES to make sure that there are no duplicates within the same batch (but not across batches in the same epoch). This is useful for MultipleNegativesRankingLoss, but not so much for reducing the duplication that leads to some samples being trained on multiple times.
  2. As for the testing/eval: good thinking, I hadn't thought about that either. Ideally, you would probably just apply a normal RoundRobinBatchSampler or ProportionalBatchSampler to the evaluation/test datasets, but then you'd have to know which dataset you're currently setting the batch sampler for in get_multi_dataset_batch_sampler, so it's a bit tricky. A hack is probably to check whether self._train_dataloader has been defined (if not: WeightedBatchSampler; if it has: RoundRobinBatchSampler or ProportionalBatchSampler), roughly as sketched below, but you can also override the get_..._dataloader methods to get there.
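For what it's worth, a rough sketch of that hack. The get_multi_dataset_batch_sampler signature and the ProportionalBatchSampler constructor arguments are my guesses from the linked trainer code, and WeightedBatchSampler is the hypothetical sampler sketched earlier, so treat this as pseudo-code to adapt:

from sentence_transformers import SentenceTransformerTrainer
from sentence_transformers.sampler import ProportionalBatchSampler

class WeightedSentenceTransformerTrainer(SentenceTransformerTrainer):
    def __init__(self, *args, dataset_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.dataset_weights = dataset_weights  # one weight per training dataset

    def get_multi_dataset_batch_sampler(self, dataset, batch_samplers, generator=None, seed=0):
        # the train dataloader has not been built yet -> we are sampling for training: weight the datasets
        if getattr(self, "_train_dataloader", None) is None:
            return WeightedBatchSampler(
                dataset=dataset,
                batch_samplers=batch_samplers,
                weights=self.dataset_weights,
                generator=generator,
            )
        # afterwards (eval/test dataloaders): fall back to the plain proportional sampler
        return ProportionalBatchSampler(
            dataset=dataset, batch_samplers=batch_samplers, generator=generator, seed=seed
        )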