google-research / fixmatch

A simple method to perform semi-supervised learning with limited data.
Apache License 2.0

How's FixMatch compared to supervised methods? #52

ElimsV closed this issue 3 years ago

ElimsV commented 3 years ago

Thanks for sharing the great work.

It might be interesting to compare FixMatch with supervised learning methods, given the same labeled data and model architecture, to further demonstrate the power of semi-supervised learning (SSL).

I would appreciate it if you could comment on the idea or release experimental results for such a comparison.

Regards,

carlini commented 3 years ago

I'm not quite sure I follow the question. Do you want to know how well our methods do in comparison to a fully supervised baseline that uses, say, just 40 or just 250 images? This is given in Table 9 and Table 10.

If instead you mean fully supervised training on the entire dataset, these numbers are reported in our ReMixMatch paper, Appendix Section 4.1:

To begin, we train a fully-supervised baseline to measure the highest accuracy we could hope to obtain with our training pipeline. The experiments we perform use the same model and training algorithm, so these baselines are valid for all discussed SSL techniques. On CIFAR-10, we obtain a fully-supervised error rate of 4.25% using weak flip + crop augmentation, which drops to 3.62% using AutoAugment and 3.91% using CTAugment. Similarly, on SVHN we obtain 2.70% error using weak (flip) augmentation and 2.31% and 2.16% using AutoAugment and CTAugment respectively. While AutoAugment performs slightly better on CIFAR-10 and slightly worse on SVHN compared to CTAugment, it is not our intent to design a better augmentation strategy; just one that can be used without pre-training or tuning of hyper-parameters.

ElimsV commented 3 years ago

Actually, I meant both. Thanks for the thoughtful response.

I implemented FixMatch with RandAugment in my own codebase and applied it to a classification task on real-world data with an imbalanced class distribution (4 classes, majority:minority ratio up to 50:1). The training set consists of balanced or imbalanced labeled data plus imbalanced unlabeled data (sampled randomly from the entire dataset).

I ran a coarse grid search over hyperparameters (weight decay, tau, learning rate, lambda_u, range of augmentations) but had no luck: FixMatch did not really improve classification performance over a purely supervised model. I have a few observations:

  1. With overly strong augmentation, the model cannot learn much from the pseudo-labels. Even after many training epochs, the unlabeled loss weighted by lambda_u (L_u * lambda_u) stays high and oscillates heavily, even though the labeled loss L_s converges quickly. L_u * lambda_u and L_s start at the same order of magnitude but end up 1-2 orders of magnitude apart.

  2. With balanced labeled data in the training set, FixMatch achieves very similar precision but worse recall than naive supervised training. With imbalanced labeled data, FixMatch gives better precision but significantly worse recall.
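Since observation 2 is phrased in terms of precision and recall on a 4-class imbalanced task, it may help to report both per class rather than only as averages: a model that drifts toward the majority class shows high majority-class recall alongside low minority-class recall. A minimal sketch in plain Python, assuming integer labels 0 to num_classes - 1; the function name and the toy labels are hypothetical.

```python
from collections import Counter


def per_class_precision_recall(y_true, y_pred, num_classes):
    """Return (precision, recall) lists indexed by class id.

    precision[c] = TP_c / (number of predictions equal to c)
    recall[c]    = TP_c / (number of true labels equal to c)
    A class that is never predicted (or never present) gets 0.0.
    """
    tp = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    predicted = Counter(y_pred)
    actual = Counter(y_true)
    precision = [tp[c] / predicted[c] if predicted[c] else 0.0 for c in range(num_classes)]
    recall = [tp[c] / actual[c] if actual[c] else 0.0 for c in range(num_classes)]
    return precision, recall


# Hypothetical toy labels: class 0 is the majority class.
y_true = [0, 0, 0, 0, 1, 1, 2, 3]
y_pred = [0, 0, 0, 0, 0, 1, 2, 0]
precision, recall = per_class_precision_recall(y_true, y_pred, num_classes=4)
print(precision)  # class 0 precision is 4/6; minority class 3 is never predicted
print(recall)     # [1.0, 0.5, 1.0, 0.0]
```

Macro-averaging these lists weights every class equally, so a collapse on the minority class stays visible instead of being hidden by the majority class.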

Here are my guesses; please correct me if I'm wrong.

First, FixMatch should start training from a good, or at least reasonable, pre-trained model so that low-quality pseudo-labels do not inject too much noise into the learning process. Second, FixMatch seems to have trouble adapting to imbalanced classification tasks: the majority class is quickly overfitted, and the randomly sampled unlabeled data provides little useful information through consistency regularization and pseudo-labeling. Third, hyperparameter tuning (perhaps most importantly, tuning the augmentation range) is both vital and difficult when adapting the method.
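On the first guess: FixMatch as published already filters low-quality pseudo-labels with the confidence threshold tau. A pseudo-label from the weakly augmented view is used only when its maximum class probability is at least tau, and the resulting loss L_u is scaled by lambda_u in the total objective L_s + lambda_u * L_u. A minimal pure-Python sketch of that unlabeled term follows, assuming logits arrive as plain lists; the function names are hypothetical, and the returned mask rate is an extra diagnostic (not part of the loss) that may help when reading L_u * lambda_u curves like the ones in observation 1. The default tau=0.95 is the value used in the FixMatch paper.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def fixmatch_unlabeled_loss(weak_logits, strong_logits, tau=0.95):
    """Sketch of the FixMatch unlabeled loss L_u for one unlabeled batch.

    weak_logits[i] / strong_logits[i]: logits for the weakly and strongly
    augmented views of the same unlabeled example i.
    Returns (L_u, mask_rate), where mask_rate is the fraction of examples
    whose weak-view confidence reached the threshold tau.
    """
    total, kept = 0.0, 0
    for weak, strong in zip(weak_logits, strong_logits):
        # Prediction on the weak view; a real implementation treats it as a
        # constant (no gradient flows through the pseudo-label).
        q = [math.exp(v) for v in log_softmax(weak)]
        conf = max(q)
        if conf >= tau:
            pseudo_label = q.index(conf)  # hard pseudo-label (argmax)
            # Cross-entropy of the strong-view prediction against the pseudo-label.
            total += -log_softmax(strong)[pseudo_label]
            kept += 1
    n = len(weak_logits)
    # Averaged over the whole unlabeled batch, not only over the kept examples.
    return total / n, kept / n


# Hypothetical 4-class batch: only the first example is confident enough.
weak = [[5.0, 0.0, 0.0, 0.0], [0.5, 0.4, 0.3, 0.2]]
strong = [[2.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]
loss_u, mask_rate = fixmatch_unlabeled_loss(weak, strong, tau=0.95)
print(mask_rate)  # 0.5
```

Because unconfident examples contribute zero but still count in the denominator, a low mask rate early in training keeps L_u small rather than noisy; a pre-trained starting point mainly changes how quickly that rate rises.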

I would appreciate any comments or suggestions for improvement. Thank you very much.

carlini commented 3 years ago

Yeah, imbalanced datasets are so far something that SSL really struggles with. We've found distribution matching from ReMixMatch can help here, but I don't know of any definitive answers. My guess is this is less of an issue with FixMatch as designed, and more of an open research question of how to better handle mismatched data distributions.
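The distribution matching mentioned above is presumably what the ReMixMatch paper calls distribution alignment: each unlabeled prediction q is multiplied by the ratio of a target class prior p(y) to a running average p~(y) of the model's own predictions on unlabeled data, then renormalized. A minimal pure-Python sketch, assuming probabilities are plain lists; the class and method names are hypothetical and not the API of the fixmatch repository, and the window length is an assumed default.

```python
from collections import deque


class DistributionAlignment:
    """Sketch of ReMixMatch-style distribution alignment.

    target_dist: class prior p(y) you believe the unlabeled data follows
        (ReMixMatch uses the labeled-set class distribution).
    window: number of recent unlabeled batches averaged into p~(y); the
        default here is an assumption.
    """

    def __init__(self, target_dist, window=128):
        total = sum(target_dist)
        self.target = [t / total for t in target_dist]
        self.history = deque(maxlen=window)

    def update(self, batch_probs):
        """Record the mean predicted class distribution of one unlabeled batch."""
        n = len(batch_probs)
        k = len(self.target)
        self.history.append([sum(p[c] for p in batch_probs) / n for c in range(k)])

    def model_marginal(self):
        """Running estimate p~(y); falls back to the target before any update."""
        if not self.history:
            return list(self.target)
        h = len(self.history)
        k = len(self.target)
        return [sum(m[c] for m in self.history) / h for c in range(k)]

    def align(self, probs):
        """Return Normalize(q * p(y) / p~(y)) for one prediction q."""
        marginal = self.model_marginal()
        scaled = [q * t / max(m, 1e-12) for q, t, m in zip(probs, self.target, marginal)]
        s = sum(scaled)
        return [x / s for x in scaled]


# Hypothetical example: a uniform prior, and a model whose unlabeled
# predictions are biased toward class 0.
da = DistributionAlignment([1, 1, 1, 1])
da.update([[0.7, 0.1, 0.1, 0.1], [0.7, 0.1, 0.1, 0.1]])
aligned = da.align([0.7, 0.1, 0.1, 0.1])  # pulled back toward the uniform prior
```

In a FixMatch-style loop, one option is to apply align to the weak-view probabilities before taking the argmax and comparing against tau. The choice of p(y) matters: with balanced labeled data and imbalanced unlabeled data, using the labeled-set distribution as p(y) would push pseudo-labels toward a balanced distribution that the unlabeled data does not actually follow.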