
FixMatch-pytorch

Unofficial PyTorch implementation of "FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence," NeurIPS'20.
This implementation reproduces the results on CIFAR10 and CIFAR100 reported in the paper.
It also includes models trained in both semi-supervised and fully supervised manners (download links below).
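FixMatch generates a pseudo-label from the model's prediction on a weakly augmented view of an unlabeled image, keeps it only if the confidence exceeds a threshold, and then trains the model to predict that label on a strongly augmented view. A minimal, framework-free sketch of the confidence-thresholding step (the function name is illustrative, not from this repo):

```python
def fixmatch_mask_and_targets(probs_weak, threshold=0.95):
    """Given softmax probabilities from weakly augmented views, return
    pseudo-label targets and a mask selecting only confident predictions."""
    # argmax over each probability vector -> pseudo-label class index
    targets = [max(range(len(p)), key=p.__getitem__) for p in probs_weak]
    # keep an example only if its max probability clears the threshold
    mask = [max(p) >= threshold for p in probs_weak]
    return targets, mask
```

The masked examples then contribute a cross-entropy loss against their pseudo-labels, computed on strongly augmented views.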

Requirements

Results: Classification Accuracy (%)

In addition to the semi-supervised results from the paper, we attach extra results for fully supervised learning (50000 labels, sup only) and fully supervised learning with consistency regularization (50000 labels, sup + consistency).
Consistency regularization improves classification accuracy even when all labels are provided.
Evaluation uses an EMA (exponential moving average) of the models along the SGD training trajectory.
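The EMA evaluation keeps a shadow copy of the weights that is blended toward the current weights at each training step. A minimal sketch (the decay value and function name are illustrative, not taken from this repo):

```python
def ema_update(ema_weights, model_weights, decay=0.999):
    """Blend current weights into the shadow copy: ema <- decay*ema + (1-decay)*w."""
    return [decay * e + (1.0 - decay) * w
            for e, w in zip(ema_weights, model_weights)]
```

At evaluation time, the EMA weights are used in place of the raw SGD weights, which typically smooths out late-training noise.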

CIFAR10

| #Labels | 40 | 250 | 4000 | sup + consistency | sup only |
| --- | --- | --- | --- | --- | --- |
| Paper (RA) | 86.19 ± 3.37 | 94.93 ± 0.65 | 95.74 ± 0.05 | - | - |
| kekmodel | - | - | 94.72 | - | - |
| valencebond | 89.63 (85.65) | 93.08 | 94.72 | - | - |
| Ours | 87.11 | 94.61 | 95.62 | 96.86 | 94.98 |
| Trained Models | checkpoint | checkpoint | checkpoint | checkpoint | checkpoint |

CIFAR100

| #Labels | 400 | 2500 | 10000 | sup + consistency | sup only |
| --- | --- | --- | --- | --- | --- |
| Paper (RA) | 51.15 ± 1.75 | 71.71 ± 0.11 | 77.40 ± 0.12 | - | - |
| kekmodel | - | - | - | - | - |
| valencebond | 53.74 | 67.3169 | 73.26 | - | - |
| Ours | 48.96 | 71.50 | 78.27 | 83.86 | 80.57 |
| Trained Models | checkpoint | checkpoint | checkpoint | checkpoint | checkpoint |

In the case of CIFAR100@400, our result does not reach the paper's result and falls outside the reported confidence interval.
That said, accuracy with a small number of labels depends heavily on label selection and other hyperparameters.
For example, we found that changing the momentum of batch normalization can give better results, close to the reported accuracies.

Evaluation of Checkpoints

Download Checkpoints

Here we attach Google Drive links that include training logs and trained models.
Because of Google Drive security restrictions,
downloading the checkpoints in the result tables with curl/wget may fail.
In that case, use gdown to download them without these issues.

All checkpoints are included in this directory

Evaluation Example

After unzipping the checkpoints into your own path, you can run:

python eval.py --load_path saved_models/cifar10_400/model_best.pth --dataset cifar10 --num_classes 10

How to Train

Important Notes

For the detailed explanations of arguments, see here.

Use a single GPU

python train.py --rank 0 --gpu [0/1/...] @@@other args@@@

Use multiple GPUs (with DataParallel)

python train.py --world-size 1 --rank 0 @@@other args@@@

Use multiple GPUs (with distributed training)

When using multiple GPUs, we strongly recommend distributed training (even on a single node) for best performance.

With V100x4 GPUs, CIFAR10 training takes about 16 hours (0.7 days), and CIFAR100 training takes about 62 hours (2.6 days).

Run Examples (with single node & multi-GPUs)

CIFAR10

python train.py --world-size 1 --rank 0 --multiprocessing-distributed --num_labels 4000 --save_name cifar10_4000 --dataset cifar10 --num_classes 10

CIFAR100

python train.py --world-size 1 --rank 0 --multiprocessing-distributed --num_labels 10000 --save_name cifar100_10000 --dataset cifar100 --num_classes 100 --widen_factor 8 --weight_decay 0.001

To reproduce the results on CIFAR100, the widen factor has to be increased to --widen_factor=8 (see this issue in the official repo) and the weight decay set to --weight_decay=0.001.

Change the backbone networks

In this repo, we use a WideResNet with LeakyReLU activations, implemented in models/net/wrn.py.
When using the WideResNet, you can change widen_factor, leaky_slope, and dropRate via command-line arguments.

For example, if you want to use ReLU, just pass --leaky_slope 0.0 in the arguments.
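Setting the slope to 0.0 works because LeakyReLU reduces to the standard ReLU in that case. A pure-Python illustration (not the repo's actual layer):

```python
def leaky_relu(x, slope=0.1):
    """LeakyReLU: identity for x >= 0, scaled by `slope` for x < 0."""
    return x if x >= 0 else slope * x
```

With slope 0.0, negative inputs are zeroed out, which is exactly ReLU.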

We also support various backbone networks from torchvision.models.
If you want to use another torchvision backbone, change the arguments to
--net [MODEL's NAME in torchvision] --net_from_name True

When --net_from_name True is set, all other model arguments except --net are ignored.
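Looking a model up by name like this typically amounts to a name-to-constructor lookup in a module namespace (e.g. torchvision.models.__dict__). A dependency-free sketch of the pattern (the registry contents and helper name are illustrative, not from this repo):

```python
# Minimal name-to-constructor registry, mimicking lookup by model name.
REGISTRY = {
    "toy_net": lambda num_classes: {"arch": "toy_net", "num_classes": num_classes},
}

def build_backbone(name, num_classes):
    """Instantiate a backbone by its registered name."""
    if name not in REGISTRY:
        raise ValueError(f"unknown backbone: {name}")
    return REGISTRY[name](num_classes=num_classes)
```

This is why only --net matters in this mode: the constructor chosen by name determines the architecture, so the WideResNet-specific arguments have nothing to configure.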

Mixed Precision Training

If you want to use mixed-precision training for a speed-up, add --amp to the arguments.
We observed that the training time of each iteration is reduced by about 20-30%.

Tensorboard

We track various metrics, including training accuracy, prefetch & run times, the mask ratio of unlabeled data, and learning rates. See the details here. You can view the metrics in TensorBoard:

tensorboard --logdir=[SAVE PATH] --port=[YOUR PORT]


Collaborator