UKPLab / sentence-transformers

State-of-the-Art Text Embeddings
https://www.sbert.net
Apache License 2.0

GroupByLabelBatchSampler #2782

Open waileong-leong opened 5 months ago

waileong-leong commented 5 months ago

In the GroupByLabelBatchSampler, what is the reason for using (num_samples := len(sample_indices) // 2) to determine the groups?

https://github.com/UKPLab/sentence-transformers/blob/2dee8c22f618df3f34f55d406bf97aa88383728a/sentence_transformers/sampler.py#L68-L72

If a label has 4 sentences, e.g. {"0": [0, 1, 2, 3]}, the group will only contain [0, 1].
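
The truncation can be reproduced standalone; this is a minimal, hypothetical sketch (the dict literal and variable names are illustrative, not library code):

    # Minimal reproduction of the len(sample_indices) // 2 truncation
    groups = {"0": [0, 1, 2, 3]}
    truncated = {
        label: sample_indices[:num_samples]
        for label, sample_indices in groups.items()
        if (num_samples := len(sample_indices) // 2)
    }
    print(truncated)  # {'0': [0, 1]} -- half of the label's samples are discarded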

fpgmaas commented 5 months ago

Coincidentally, I made a small PR to improve that class today: https://github.com/UKPLab/sentence-transformers/pull/2788

Part of the PR is to remove these lines of code since I also was not sure why they were there.

tomaarsen commented 5 months ago

Hello!

Ai, this is indeed a bug. The goal of this snippet of code is to ensure that there is an even number of text values under groups. The __iter__ code relies on this (and on batch_size % 2 == 0) to ensure that each batch has 2+ texts from every label that occurs in the text. E.g. if you have 17 texts from one label and a batch size of 16, then __iter__ would yield those 16 and keep the leftover 1 for the next batch, where it would be the only text with that label. This can cause issues with losses that require at least 2 texts per label.

I believe the code should instead be:

        # Round each label's sample count down to the nearest even number;
        # labels with fewer than 2 samples are dropped entirely.
        self.groups = {
            label: sample_indices[:num_samples]
            for label, sample_indices in groups.items()
            if (num_samples := len(sample_indices) // 2 * 2)
        }

This way we'll only throw away 1 sample at most per label.
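
Applied to a few hypothetical label groups (illustrative values only), the new condition behaves like this:

    # Hypothetical input: label -> sample indices
    groups = {"a": [0, 1, 2, 3], "b": [4, 5, 6], "c": [7]}
    fixed = {
        label: sample_indices[:num_samples]
        for label, sample_indices in groups.items()
        if (num_samples := len(sample_indices) // 2 * 2)
    }
    print(fixed)  # {'a': [0, 1, 2, 3], 'b': [4, 5]} -- 'c' has < 2 samples, so it is dropped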

fpgmaas commented 5 months ago

This

[...] to ensure that each batch has 2+ texts from every label that occurs in the text.

should be

[...] to ensure that each batch has 2+ texts from every label that occurs in the batch

Right? Not trying to be pedantic, just trying to confirm that my understanding of the sampler is correct :) Then I could incorporate the above in this PR. Although it can of course also be solved earlier in a separate PR that aims to fix just that.

tomaarsen commented 5 months ago

to ensure that each batch has 2+ texts from every label that occurs in the batch

Yes. That should be correct, indeed. The dataset can have labels and corresponding texts that do not get used in the batch; that should be fine. I'd love to incorporate the groups fix that I described above in your PR if possible, as I quite like the other changes that you're proposing @fpgmaas

waileong-leong commented 5 months ago

@tomaarsen What's the plan, or your view, on the BatchSampler implementation? Currently it is restricted to only the 3 types and allows minimal customisation. I have monkey patched a custom sampler to replicate the SentenceLabelDataset for supervised contrastive loss.

dadamson commented 5 months ago

I came to the Issues tab to call out these two bugs (//2 and the hard-coded 'label' column) while cobbling together my own subclass (and, unfortunately necessarily, a subclass of SentenceTransformerTrainer to override get_batch_sampler()) -- thanks for beating me to the punch!

The one extra feature I need for SentenceTransformerTrainer to use GroupByLabelBatchSampler, once this is bug-fixed, is the ability to set the sampler's valid_label_columns from the TrainingArguments -- right now SentenceTransformerTrainer takes it from the DataCollator, but I'd rather group my batches by an extra, non-label column (e.g., "grouping_id" or "topic").

tomaarsen commented 5 months ago

@dadamson I'd like to add a batch_sampler_kwargs option to the Training Arguments that overrides the default options, but I'm still brainstorming it (it's a bit complex with the multi-dataset case as you might want different kwargs for different datasets, and you might want different kwargs for train vs eval vs test).

In the meantime, feel free to override the Trainer to specify your own custom sampler: I purposefully implemented get_batch_sampler in a separate method so it could be more easily overridden.
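
A minimal sketch of such an override, assuming the GroupByLabelBatchSampler constructor arguments and the get_batch_sampler signature described in this thread (both may differ between releases; the "topic" column is a hypothetical example):

    from sentence_transformers import SentenceTransformerTrainer
    from sentence_transformers.sampler import GroupByLabelBatchSampler

    class TopicGroupingTrainer(SentenceTransformerTrainer):
        # Override the factory method so batches are grouped by a custom
        # column instead of the label column from the DataCollator.
        def get_batch_sampler(self, dataset, batch_size, drop_last, valid_label_columns=None, generator=None):
            return GroupByLabelBatchSampler(
                dataset=dataset,
                batch_size=batch_size,
                drop_last=drop_last,
                valid_label_columns=["topic"],  # hypothetical grouping column
                generator=generator,
            )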