waileong-leong opened this issue 5 months ago:

In the `GroupByLabelBatchSampler`, what is the reason for doing `(num_samples := len(sample_indices) // 2)` to determine the groups?

https://github.com/UKPLab/sentence-transformers/blob/2dee8c22f618df3f34f55d406bf97aa88383728a/sentence_transformers/sampler.py#L68-L72

If a label has 4 sentences, e.g. `{"0": [0, 1, 2, 3]}`, the group will only contain `[0, 1]`.
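For reference, the linked lines amount to a dict comprehension along these lines (a reconstruction inferred from the fix quoted later in this thread, not copied verbatim from that commit):

```python
# groups maps each label to the list of dataset indices carrying that label;
# `// 2` keeps only the first half of each group and drops labels that end up empty.
self.groups = {
    label: sample_indices[:num_samples]
    for label, sample_indices in groups.items()
    if (num_samples := len(sample_indices) // 2)
}
```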
Coincidentally, I made a small PR to improve that class today: https://github.com/UKPLab/sentence-transformers/pull/2788
Part of the PR is to remove these lines of code, since I was also not sure why they were there.
Hello!

Ah, this is indeed a bug. The goal of this snippet of code is to ensure that there is an even number of text values under `groups`. The `__iter__` code relies on this (and on `batch_size % 2 == 0`) to ensure that each batch has 2+ texts from every label that occurs in the text. E.g. if you have 17 texts from one label and a batch size of 16, then `__iter__` would yield those 16 and keep the leftover 1 to use in the next batch. This can cause issues with losses where there must be at least 2 texts for that label.
I believe the code should instead be:

```python
self.groups = {
    label: sample_indices[:num_samples]
    for label, sample_indices in groups.items()
    if (num_samples := len(sample_indices) // 2 * 2)
}
```
This way we'll throw away at most 1 sample per label.
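To make the difference concrete, here is a small self-contained sketch (toy data, not taken from the sampler itself) comparing the current `// 2` trimming with the proposed `// 2 * 2` fix:

```python
def trim_groups(groups, keep):
    # Keep the first keep(n) indices of each label group,
    # dropping any label whose kept count is 0.
    return {
        label: idxs[:n]
        for label, idxs in groups.items()
        if (n := keep(len(idxs)))
    }

groups = {"a": [0, 1, 2, 3, 4], "b": [0, 1, 2, 3]}

# Current behavior: halves every group.
print(trim_groups(groups, lambda n: n // 2))
# {'a': [0, 1], 'b': [0, 1]}

# Proposed fix: round down to the nearest even count.
print(trim_groups(groups, lambda n: n // 2 * 2))
# {'a': [0, 1, 2, 3], 'b': [0, 1, 2, 3]}
```

Both variants preserve the even-count invariant that `__iter__` relies on, but the second discards at most 1 sample per label instead of half of them.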
This

> [...] to ensure that each batch has 2+ texts from every label that occurs in the text.

should be

> [...] to ensure that each batch has 2+ texts from every label that occurs in the batch.
Right? Not trying to be pedantic, just trying to confirm that my understanding of the sampler is correct :) Then I could incorporate the above in this PR. Although it can of course also be solved earlier in a separate PR that aims to fix just that.
> to ensure that each batch has 2+ texts from every label that occurs in the batch

Yes, that should be correct, indeed. The dataset can have labels and corresponding texts that do not get used in the batch; that should be fine.
I'd love to incorporate the `groups` fix that I described above into your PR if possible, as I quite like the other changes that you're proposing, @fpgmaas.
@tomaarsen what's your plan or view on the batch sampler implementation? Currently it is restricted to only the 3 types and allows minimal customisation; I have monkey-patched a custom sampler to replicate the `SentenceLabelDataset` for supervised contrastive loss.
I came to the Issues tab to call out these two bugs (the `// 2` and the hard-coded `'label'` column) while cobbling together my own subclass (and, unfortunately necessarily, a subclass of `SentenceTransformerTrainer` to override `get_batch_sampler()`) -- thanks for beating me to the punch!
The one extra feature I need for the `SentenceTransformerTrainer` to use `GroupByLabelBatchSampler` once this is bugfixed is the ability to set the sampler's `valid_label_columns` from the `TrainingArguments` -- right now `SentenceTransformerTrainer` takes it from the `DataCollator`, but I'd rather group my batches by an extra, not-the-label column (e.g. "grouping_id" or "topic").
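For illustration, constructing the sampler directly with a custom column might look like the sketch below. It assumes the constructor keeps its `dataset` / `batch_size` / `drop_last` / `valid_label_columns` parameters; the "topic" column and the toy data are made up:

```python
from datasets import Dataset
from sentence_transformers.sampler import GroupByLabelBatchSampler

# Hypothetical dataset grouped by a "topic" column rather than "label".
dataset = Dataset.from_dict({
    "anchor": [f"q{i}" for i in range(8)],
    "positive": [f"d{i}" for i in range(8)],
    "topic": [0, 0, 0, 0, 1, 1, 1, 1],
})

batch_sampler = GroupByLabelBatchSampler(
    dataset=dataset,
    batch_size=4,
    drop_last=False,
    valid_label_columns=["topic"],  # group batches by "topic"
)

for batch_indices in batch_sampler:
    print(batch_indices)
```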
@dadamson I'd like to add a `batch_sampler_kwargs` option to the Training Arguments that overrides the default options, but I'm still brainstorming it (it's a bit complex in the multi-dataset case, as you might want different kwargs for different datasets, and different kwargs for train vs eval vs test).

In the meantime, feel free to override the Trainer to specify your own custom sampler: I purposefully implemented `get_batch_sampler` as a separate method so it could be more easily overridden.
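As a rough sketch of that override (assuming `get_batch_sampler` takes `dataset`, `batch_size`, `drop_last`, `valid_label_columns`, and `generator`; the "topic" column is a hypothetical stand-in):

```python
from sentence_transformers import SentenceTransformerTrainer
from sentence_transformers.sampler import GroupByLabelBatchSampler


class TopicGroupingTrainer(SentenceTransformerTrainer):
    # Ignore the collator-provided valid_label_columns and always
    # group batches by a custom column instead.
    def get_batch_sampler(self, dataset, batch_size, drop_last, valid_label_columns=None, generator=None):
        return GroupByLabelBatchSampler(
            dataset=dataset,
            batch_size=batch_size,
            drop_last=drop_last,
            valid_label_columns=["topic"],  # hypothetical grouping column
            generator=generator,
        )
```

The trainer then uses this sampler for every dataloader it builds, which sidesteps the `DataCollator` wiring mentioned above.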