huggingface / setfit

Efficient few-shot learning with Sentence Transformers
https://hf.co/docs/setfit

SetFit sampling trick #258

Closed · danstan5 closed this issue 1 year ago

danstan5 commented 1 year ago

Background

I've enjoyed using the SetFit library to get great results on text-classification tasks, thank you core developers for your work on this 👍 However, with the larger datasets I'm working on (>100 classes, >20k samples) I found training times to be slow, especially when hyperparameter tuning as well.

Proposal

The vast majority of the training time is spent in the contrastive learning stage, which scales with the number of sentence pairs generated. There are probably many ways to engineer the contrastive sentence pairs, but while looking to improve training times, a simple solution I fell upon (that fits nicely into the existing API) was just to remove duplicate pairs by adding a remove_duplicate_samples parameter to the SetFit train method, as sketched below.
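For illustration, a minimal sketch of the deduplication idea (not the actual implementation, see #259 for that; the (text_a, text_b, label) pair representation here is an assumption):

```python
from typing import List, Tuple

Pair = Tuple[str, str, float]  # (text_a, text_b, label) - assumed pair representation

def remove_duplicate_pairs(pairs: List[Pair]) -> List[Pair]:
    """Drop repeated sentence pairs, treating (a, b) and (b, a) as the same pair."""
    seen = set()
    unique = []
    for text_a, text_b, label in pairs:
        key = (frozenset((text_a, text_b)), label)
        if key not in seen:
            seen.add(key)
            unique.append((text_a, text_b, label))
    return unique

# Example: three generated pairs collapse to two unique ones.
pairs = [
    ("great movie", "loved it", 1.0),
    ("loved it", "great movie", 1.0),
    ("great movie", "terrible", 0.0),
]
print(len(remove_duplicate_pairs(pairs)))  # 2
```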

Testing

Using the run_fewshot.py script (with some alterations to set up the no-duplicate runs), I ran this comparison on the test_set datasets with the following parameters:

Commands for reproducibility

`Original`

```
python scripts/setfit/run_fewshot.py --num_iterations=20 --batch_size=32 --train_time=true --is_test_set=true --add_data_augmentation=true
```

`no_duplicate_samples`

```
python scripts/setfit/run_fewshot.py --num_iterations=40 --batch_size=32 --train_time=true --is_test_set=true --add_data_augmentation=true --remove_duplicate_samples=true --exp_name=remove-dups
```

Analysis

Accuracy (16 samples per class):

| | ag_news (acc) | emotion (acc) | enron_spam (acc) | SentEval-CR (acc) | sst5 (acc) | amazon_counterfactual_en (matthews correlation) |
|---|---|---|---|---|---|---|
| SetFit Original | 86.3 (1.0) | 63.8 (2.3) | 93.0 (1.8) | 88.9 (0.9) | 47.9 (2.3) | 35.3 (7.1) |
| SetFit no_duplicate_samples | 86.0 (0.9) | 59.5 (2.7) | 92.5 (1.3) | 90.8 (1.4) | 45.7 (2.5) | 39.4 (8.3) |

Training time (16 samples per class):

| | ag_news (s) | emotion (s) | enron_spam (s) | SentEval-CR (s) | sst5 (s) | amazon_counterfactual_en (s) |
|---|---|---|---|---|---|---|
| SetFit Original | 41.0 (5.2) | 34.7 (0.8) | 54.3 (0.6) | 11.8 (1.4) | 28.8 (2.0) | 13.9 (1.8) |
| SetFit no_duplicate_samples | 14.3 (2.1) | 13.9 (0.4) | 11.7 (0.4) | 2.8 (0.4) | 10.8 (0.8) | 3.4 (0.5) |

SetFit % change in accuracy: Original → no_duplicate_samples

SetFit training time difference: Original / no_duplicate_samples

Highlights

Other comments

domitix commented 1 year ago

Hi, I'm very interested in your trick. How did you modify the code to add remove_duplicate_samples? Thank you in advance!

tomaarsen commented 1 year ago

@domitix I believe #259 implements the changes to reproduce these results.

@danstan5 Great analysis! I've personally always had some issues with the pair generation's naive approach. It's very interesting to see these results in practice.

danstan5 commented 1 year ago

@tomaarsen following on from your comments in #259, I'm inclined to explore your idea of generating the "proper" list further, but perhaps go even more fundamental:

In this case, as we start to duplicate samples, is this not the same as over-fitting to these randomly selected extra pairs? Would it be better to just do away with the sampling-iterations concept, take the true total number of combinations as the "no. of samples", and then increase num_epochs or learning_rate for fitting?

I'd be keen to know your thoughts, although I think testing will help to validate this!

tomaarsen commented 1 year ago

In my example, num_samples=16 was the total number of samples across the two classes. It was inspired by a test case from the first code block in the README. Having 16 total labeled sentences results in 136 unique pairs (or 120 without identical pairs), despite 640 samples being generated. I believe this is quite backwards, as I don't think a single epoch should train the same samples multiple times.
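For concreteness, the counts above can be reproduced with a few lines of arithmetic (my own illustration, assuming the pairs-per-epoch count of num_iterations * num_sentences * 2 that yields the 640 figure):

```python
from math import comb

num_sentences = 16   # total labeled sentences across the two classes
num_iterations = 20  # value that reproduces the 640 figure above

unique_pairs_incl_identical = comb(num_sentences, 2) + num_sentences  # 120 + 16 = 136
unique_pairs_excl_identical = comb(num_sentences, 2)                  # 120
generated_pairs = num_iterations * num_sentences * 2                  # 640

print(unique_pairs_incl_identical, unique_pairs_excl_identical, generated_pairs)
# 136 120 640
```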

If we implement unique_pairs, then I believe that we should take the total number of unique pairs as a strict maximum on the number of samples/steps per epoch. In our example, that would be 120 (or 136). If extra training is required, then the number of epochs should be incremented. That way, no pairs get extra weight and we lose the odd behaviour that training is frequently done in 1 epoch containing duplicate training samples.


Optionally, we may give a warning if num_iterations * num_samples * 2 > n_unique_pairs, and limit the number of samples to n_unique_pairs. I'm wary of this however, as this warning is pretty much impossible to avoid if you want to train for all unique pairs.
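A rough sketch of what that cap and warning could look like (purely illustrative; capped_samples_per_epoch and its parameters are placeholders, not the actual SetFit API):

```python
import warnings

def capped_samples_per_epoch(num_iterations: int, num_sentences: int,
                             n_unique_pairs: int) -> int:
    """Placeholder logic: cap the generated pairs per epoch at the number of unique pairs."""
    requested = num_iterations * num_sentences * 2
    if requested > n_unique_pairs:
        warnings.warn(
            f"Requested {requested} pairs but only {n_unique_pairs} unique pairs exist; "
            f"limiting to {n_unique_pairs}. Increase num_epochs for more training."
        )
        return n_unique_pairs
    return requested

# With the example above: 20 * 16 * 2 = 640 requested, capped to 120 unique pairs.
print(capped_samples_per_epoch(num_iterations=20, num_sentences=16, n_unique_pairs=120))  # 120
```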

cc: @lewtun thoughts on an optional unique_pairs parameter? See also #259 for additional discussion.

danstan5 commented 1 year ago

Closing this, as the advantages that come from better sampling and natural limits have been addressed in #268.

Note: in hindsight, a lot of the speed-up that came from removing duplicate samples was because add_data_augmentation adds as many new samples as there are existing samples (this doubles the dataset size with lots of duplicates!). Therefore, what the analysis really shows is that at higher sample sizes, adding lots of duplicate augmented data has no impact on accuracy but takes significantly longer to train.
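Rough illustration of that point (not measured data, assuming the num_iterations * num_sentences * 2 pair-generation scheme discussed above): doubling the dataset roughly doubles the generated pairs, and hence the contrastive training time, while a deduplicated run stays bounded by the unique-pair count.

```python
# Back-of-the-envelope numbers (hypothetical dataset sizes, not the benchmark data).
num_iterations = 20
base_sentences = 64                        # e.g. 16 samples per class over 4 classes
augmented_sentences = base_sentences * 2   # add_data_augmentation doubles the dataset

pairs_without_aug = num_iterations * base_sentences * 2    # 2560
pairs_with_aug = num_iterations * augmented_sentences * 2  # 5120

print(pairs_without_aug, pairs_with_aug)  # 2560 5120
```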