rubensmau opened this issue 1 year ago
Hello!
Many papers regarding few-shot methods (i.e. methods with not a lot of training samples per class) consider K-shot learning in their results, where K is some fixed integer, commonly 8 or 16 and sometimes 4, 32 or 64. The SetFit paper is no exception. One ubiquitous finding is that a higher K always results in higher performance, assuming that K stays within reason, i.e. K is sufficiently low that we may still speak of "few-shot". See the following screenshot from the SetFit paper for an example:
As you can see, the performance universally increases as K increases from 8 to 64.
Additionally, this can be verified by using the very first script from the README, modifying `num_samples` as if it's K, and plotting the results. I've done exactly that.
This results in the following graph:
Note that this script samples `num_samples` elements from the much larger `train_dataset` to simulate a situation with little training data. This means that, especially at lower K, the quality of the few training samples that are drawn has a large impact on the performance, which is likely what explains the large drop from K=6 to K=7. As a result, taking averages over multiple runs for the same K would be required to get a smoother and more useful graph, but I don't have time for that. Beyond that, this graph should show my point decently well regardless.
As you can see, moving from 1 or 2 samples to e.g. 5 will cause notable improvements, while increases from e.g. 16 to 18 won't have as much of an influence. I suspect that the shape of this graph will be the same for all problems and datasets, but that the "sweet spot" of labelling time versus performance gain will differ depending on the situation. As a result, if you have e.g. 8 samples per class, it may be interesting to plot the performance when you "pretend" to have fewer samples. The slope of the resulting graph may give you an indication of the performance gain you'd get from labelling another 2 samples per class. (That said, perhaps at that point it's better to not spend time writing a plotting script, but just spend that time on labelling instead, hah!)
Please recognize that this script and the graph are very rudimentary :)
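For anyone who wants to try this themselves, here is a minimal sketch of that experiment. It loosely follows the README example and assumes the `SetFitTrainer` API with the `paraphrase-mpnet-base-v2` checkpoint; the loop over `num_samples` and the result collection are my own additions, not part of the README.

```python
from datasets import load_dataset
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, SetFitTrainer, sample_dataset

# Load the full dataset; only a small, sampled subset is used for training
# to simulate the few-shot regime.
dataset = load_dataset("sst2")
eval_dataset = dataset["validation"]

results = {}
for num_samples in [2, 4, 8, 16, 32]:
    # Sample `num_samples` examples per class, i.e. K-shot training data.
    train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=num_samples)

    model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
    trainer = SetFitTrainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss_class=CosineSimilarityLoss,
        metric="accuracy",
        batch_size=16,
        num_iterations=20,  # number of text pairs generated per training sample
        column_mapping={"sentence": "text", "label": "label"},
    )
    trainer.train()
    results[num_samples] = trainer.evaluate()["accuracy"]

print(results)  # plot num_samples vs. accuracy from here, e.g. with matplotlib
```

Averaging several runs per `num_samples` value (with different seeds) would smooth out the sampling noise mentioned above.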
Thanks for your answer, Tom. But I failed to detail my question properly: do you know of any similar studies on how the number of classes affects the overall performance? For instance, whether we need to increase the number of samples per class if more classes are added.
I'm currently training a setfit model with 4500 classes, 10 samples per class (using a proprietary dataset).
I think it is still generating pairs though? I just see endless tqdm bars haha
I can share the end accuracy once it gets there :)
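For a rough sense of why pair generation takes so long at that scale, here is a back-of-the-envelope estimate. It assumes the default `num_iterations=20` and that roughly `2 * num_iterations` contrastive pairs are generated per training example; that per-example count is my reading of the pair-generation step, not a documented guarantee.

```python
num_classes = 4500
samples_per_class = 10
num_iterations = 20  # SetFit default

num_examples = num_classes * samples_per_class      # 45,000 training sentences
approx_pairs = num_examples * 2 * num_iterations    # roughly 1.8 million sentence pairs
print(f"{num_examples:,} examples -> roughly {approx_pairs:,} pairs")
```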
Do you know of any similar studies on how the number of classes affects the overall performance? For instance, whether we need to increase the number of samples per class if more classes are added.
This is a good question, @rubensmau. I suspect that it's also impossible to answer in general, as it depends on how well the classes are separated. SetFit tries to organize the embeddings so that different classes are well separated, so I would expect performance to stay relatively stable as the number of classes grows, as long as the texts themselves separate the classes well.
Personally, I don't know of any dataset-independent studies like that. You can cook up examples where you don't need to increase your samples with the number of classes, and others where you do need to do so. 🤔
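If you want a quick feel for how well your own classes are separated before any fine-tuning, one rough diagnostic (my own sketch, not part of the SetFit library; the example texts and labels below are placeholders) is to embed a handful of examples per class with the base sentence transformer and compare class centroids:

```python
import numpy as np
from sentence_transformers import SentenceTransformer

# Placeholder data: a few labelled texts per class from your own dataset.
texts_by_class = {
    "billing": ["I was charged twice", "My invoice looks wrong"],
    "shipping": ["Where is my package?", "The delivery is late"],
}

model = SentenceTransformer("sentence-transformers/paraphrase-mpnet-base-v2")

# Compute one centroid embedding per class.
centroids = {
    label: model.encode(texts, normalize_embeddings=True).mean(axis=0)
    for label, texts in texts_by_class.items()
}

# Cosine similarity between centroids: lower values suggest better-separated classes.
labels = list(centroids)
for i, a in enumerate(labels):
    for b in labels[i + 1:]:
        sim = float(
            np.dot(centroids[a], centroids[b])
            / (np.linalg.norm(centroids[a]) * np.linalg.norm(centroids[b]))
        )
        print(f"{a} vs {b}: cosine similarity = {sim:.3f}")
```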
I've noticed that an increase in classes makes it harder for the SetFit model to properly separate the classes in the embedding space. My experiments suggest that datasets with more classes keep benefiting from additional data after datasets with fewer classes have already plateaued. For example, I have seen binary classification tasks where increasing from 16 to 32 labels per class gives marginal improvements, while classification tasks with 5 classes improve fairly significantly when moving from 16 to 32.
In fact, I can run an experiment with exactly that:
python .\scripts\setfit\run_fewshot.py --datasets sst2 sst5 --sample_sizes 2 4 8 16 32 --batch_size 64
This results in:

| dataset | measure  | 2_avg       | 4_avg       | 8_avg       | 16_avg      | 32_avg      |
|---------|----------|-------------|-------------|-------------|-------------|-------------|
| sst2    | accuracy | 71.5% (9.1) | 77.2% (5.6) | 86.2% (3.3) | 90.5% (0.8) | 91.0% (0.9) |
| sst5    | accuracy | 32.9% (2.5) | 38.5% (2.6) | 42.6% (2.6) | 46.2% (1.8) | 48.1% (1.3) |
Each of these experiments was run 10 times; the table shows the average accuracies with standard deviations in parentheses.
Tom Aarsen
Hi, sorry to interrupt, but I am facing a similar problem: I am dealing with multi-class classification, and my dataset is quite large, with 15k examples and 401 categories. I will probably sub-sample it and plan to experiment with the top 100 categories. The maximum number of examples per category is around 700, while the minimum is 30.
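A minimal sketch of that sub-sampling step, assuming a Hugging Face `datasets.Dataset` with `text` and `label` columns (the file name and column names below are placeholders):

```python
from collections import Counter
from datasets import load_dataset

# Placeholder: load your own dataset with "text" and "label" columns.
dataset = load_dataset("csv", data_files="my_data.csv")["train"]

# Find the 100 most frequent categories.
label_counts = Counter(dataset["label"])
top_labels = {label for label, _ in label_counts.most_common(100)}

# Keep only examples belonging to those categories.
subset = dataset.filter(lambda example: example["label"] in top_labels)
print(f"Kept {len(subset)} of {len(dataset)} examples across {len(top_labels)} categories")
```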
I'm currently training a setfit model with 4500 classes, 10 samples per class (using a proprietary dataset).
I think it is still generating pairs though? I just see endless tqdm bars haha
I can share the end accuracy once it gets there :)
Hi @logan-markewich - did you get useful results out of this? It's very similar to what I am trying but I get ~40%+ accuracy on the training data and ~0% accuracy on the evaluation data.
@grofte yea it never worked well for me either. I think the dataset is just too big haha
I have found that increasing the number of categories reduces the accuracy. Has anyone studied how increasing the number of samples per category affects the results?