UKPLab / sentence-transformers

Multilingual Sentence & Image Embeddings with BERT
https://www.SBERT.net
Apache License 2.0

Using DataLoader in v3? #2707

AleksanderMarek commented 2 months ago

Hi,

I wanted to embark on my first adventure of fine-tuning an embedding model, having seen that v3 was released recently. I've started looking into the documentation and found that NoDuplicatesDataLoader is the recommended DataLoader for MNRL loss, particularly when there are multiple positives for the same anchor, i.e. (A1, P1), (A1, P2), ... But I can't find a way to use this DataLoader within the SentenceTransformerTrainer training paradigm. From what I can tell, in v2 the DataLoader would be passed to SentenceTransformer.fit(), but SentenceTransformerTrainer takes only a Dataset and generates a DataLoader internally (with no control over the type of DataLoader).

Any tips on how to incorporate multiple positive examples into the training process would be highly appreciated!
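
To make the data shape concrete, here is a hypothetical sketch of the layout I mean (column names follow the usual (anchor, positive) convention):

from datasets import Dataset

# Hypothetical data: the same anchor appears with several positives, so a
# random batch can contain A1 twice, turning one of its positives into a
# false negative for MultipleNegativesRankingLoss.
train_dataset = Dataset.from_dict({
    "anchor":   ["A1", "A1", "A2", "A2"],
    "positive": ["P1", "P2", "P3", "P4"],
})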

tomaarsen commented 2 months ago

Hello!

Indeed, the Trainer accepts a Dataset and creates a BatchSampler internally. Right now, there are three options (the BatchSamplers enum):

- BatchSamplers.BATCH_SAMPLER: the default; batches like a regular (shuffled) DataLoader.
- BatchSamplers.NO_DUPLICATES: ensures that no two identical samples end up in the same batch; recommended for in-batch negatives losses such as MNRL.
- BatchSamplers.GROUP_BY_LABEL: ensures each batch contains multiple samples with the same label; useful for losses like BatchHardTripletLoss.
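
For reference, a sampler is picked through the training arguments; a minimal sketch (the output directory is a placeholder):

from sentence_transformers import SentenceTransformerTrainingArguments
from sentence_transformers.training_args import BatchSamplers

args = SentenceTransformerTrainingArguments(
    output_dir="output",  # placeholder
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # one of the three options above
)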

Sadly, I think none of these will work for you. The closest is BatchSamplers.NO_DUPLICATES, but that one only prevents exact duplicates, i.e. the entire (anchor, positive) pair must be identical. Perhaps a natural extension is to add a batch_sampler_kwargs option to the SentenceTransformerTrainingArguments, and extend the NoDuplicatesBatchSampler so you can supply a list of columns in which you want to guarantee there are no duplicates.

Then, in your case, we could use e.g.:

args = SentenceTransformerTrainingArguments(
    ...,
    batch_sampler="no_duplicates",
    batch_sampler_kwargs={"columns": ["anchor"]},
)

And then the code will only check for duplicates in the "anchor" column.

In the meantime, I've implemented get_batch_sampler and get_multi_dataset_batch_sampler as separate methods on purpose, allowing you to easily override them. E.g.:

from typing import List, Optional

import torch
from datasets import Dataset
from torch.utils.data import BatchSampler

from sentence_transformers import SentenceTransformerTrainer
from sentence_transformers.sampler import NoDuplicatesBatchSampler
from sentence_transformers.training_args import BatchSamplers

class MyNoDuplicatesBatchSampler(NoDuplicatesBatchSampler):
    def __iter__(self):
        """
        Iterate over the remaining non-yielded indices. For each index, check if the sample values are already in the
        batch. If not, add the sample values to the batch and keep going until the batch is full. If the batch is full,
        yield the batch indices and continue with the next batch.
        """
        if self.generator and self.seed:
            self.generator.manual_seed(self.seed + self.epoch)

        remaining_indices = set(torch.randperm(len(self.dataset), generator=self.generator).tolist())
        while remaining_indices:
            batch_values = set()
            batch_indices = []
            for index in remaining_indices:
                sample_values = {self.dataset[index]["anchor"]}  # <- This line differs: only the "anchor" column is checked
                if sample_values & batch_values:
                    continue

                batch_indices.append(index)
                if len(batch_indices) == self.batch_size:
                    yield batch_indices
                    break

                batch_values.update(sample_values)

            else:
                # NOTE: some indices might still have been ignored here
                if not self.drop_last:
                    yield batch_indices

            remaining_indices -= set(batch_indices)

class MySentenceTransformerTrainer(SentenceTransformerTrainer):
    def get_batch_sampler(
        self,
        dataset: "Dataset",
        batch_size: int,
        drop_last: bool,
        valid_label_columns: Optional[List[str]] = None,
        generator: Optional[torch.Generator] = None,
    ) -> BatchSampler:
        if self.args.batch_sampler == BatchSamplers.NO_DUPLICATES:
            return MyNoDuplicatesBatchSampler(
                dataset=dataset,
                batch_size=batch_size,
                drop_last=drop_last,
                valid_label_columns=valid_label_columns,
                generator=generator,
            )
        # Fall back to the default behaviour for every other batch_sampler setting
        return super().get_batch_sampler(
            dataset=dataset,
            batch_size=batch_size,
            drop_last=drop_last,
            valid_label_columns=valid_label_columns,
            generator=generator,
        )

Note: the script above is untested.
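
For completeness, a usage sketch of the override, equally untested; the model name, output directory, and train_dataset are placeholders:

from sentence_transformers import SentenceTransformer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

model = SentenceTransformer("all-MiniLM-L6-v2")  # placeholder base model
args = SentenceTransformerTrainingArguments(
    output_dir="output",  # placeholder
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # routes to MyNoDuplicatesBatchSampler
)
trainer = MySentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,  # an (anchor, positive) Dataset
    loss=MultipleNegativesRankingLoss(model),
)
trainer.train()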