kekmodel / FixMatch-pytorch

Unofficial PyTorch implementation of "FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence"
MIT License

FixMatch

This is an unofficial PyTorch implementation of FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence. The official TensorFlow implementation is in the google-research/fixmatch repository on GitHub.

This code implements only the RandAugment variant of FixMatch, FixMatch (RA); the CTAugment variant is not included.
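As the paper describes it, FixMatch pseudo-labels each unlabeled image from a weakly augmented view and trains the strongly augmented (RandAugment) view toward that label only when the model's confidence reaches a threshold τ. The sketch below is a minimal plain-Python illustration of that unlabeled loss, not code taken from this repository. It assumes per-image lists of class logits, the paper's threshold τ = 0.95, and hypothetical helper names (softmax, log_softmax, fixmatch_unlabeled_loss). The full objective adds this term, weighted by λu, to the ordinary cross-entropy on the labeled batch.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_total = math.log(sum(math.exp(z - m) for z in logits))
    return [z - m - log_total for z in logits]


def fixmatch_unlabeled_loss(weak_logits, strong_logits, threshold=0.95):
    """Masked pseudo-label cross-entropy, averaged over the whole unlabeled batch.

    weak_logits[b] and strong_logits[b] are the class logits for the weakly and
    strongly augmented views of the same unlabeled image b.
    """
    if not weak_logits:
        return 0.0
    total = 0.0
    for weak, strong in zip(weak_logits, strong_logits):
        probs = softmax(weak)        # treated as a fixed target (no gradient) in FixMatch
        confidence = max(probs)
        pseudo_label = probs.index(confidence)
        if confidence >= threshold:  # keep only confident pseudo-labels
            total += -log_softmax(strong)[pseudo_label]
    return total / len(weak_logits)


# Image 0 is confidently class 0; image 1 falls below the threshold and is masked out.
loss = fixmatch_unlabeled_loss([[10.0, 0.0], [0.1, 0.0]], [[0.0, 0.0], [5.0, -5.0]])
print(round(loss, 4))  # 0.3466
```

Dividing by the full unlabeled batch size, rather than by the number of confident images, follows the paper's definition: masked images simply contribute zero.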

Results

CIFAR10

Accuracy (%) by number of labeled examples:

| Labels | 40 | 250 | 4000 |
|---|---|---|---|
| Paper (RA) | 86.19 ± 3.37 | 94.93 ± 0.65 | 95.74 ± 0.05 |
| This code | 93.60 | 95.31 | 95.77 |

* November 2020: retested after fixing issues with the EMA (exponential moving average) of the model weights.
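The EMA keeps a running average of the model's weights, and the FixMatch paper reports results using these averaged weights. A minimal sketch of the per-step update, assuming a plain dict of scalar parameters and a hypothetical ema_update helper; a decay such as 0.999 is a common choice, and the repository's own EMA code may differ, for example in how it treats BatchNorm buffers.

```python
def ema_update(ema_params, model_params, decay=0.999):
    """Blend current weights into the running average: ema = decay*ema + (1-decay)*w."""
    for name, value in model_params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value
    return ema_params


ema = {"w": 0.0}
ema_update(ema, {"w": 1.0}, decay=0.5)  # ema["w"] == 0.5
ema_update(ema, {"w": 1.0}, decay=0.5)  # ema["w"] == 0.75
print(ema)
```

Evaluation would then use the averaged weights instead of the raw training weights, which smooths out step-to-step noise late in training.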

CIFAR100

Accuracy (%) by number of labeled examples:

| Labels | 400 | 2500 | 10000 |
|---|---|---|---|
| Paper (RA) | 51.15 ± 1.75 | 71.71 ± 0.11 | 77.40 ± 0.12 |
| This code | 57.50 | 72.93 | 78.12 |

* Trained with the following options: --amp --opt_level O2 --wdecay 0.001 (mixed precision at opt level O2, weight decay 0.001).
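--wdecay sets the weight-decay coefficient. The FixMatch paper trains with SGD and Nesterov momentum; assuming train.py uses torch.optim.SGD with nesterov=True (not confirmed by this README), one update step with L2 weight decay follows PyTorch's documented rule with dampening 0. The sketch below restates that rule in plain Python with a hypothetical sgd_nesterov_step helper.

```python
def sgd_nesterov_step(params, grads, momentum_buf, lr=0.03, momentum=0.9, weight_decay=0.001):
    """One SGD step with L2 weight decay and Nesterov momentum, updating lists in place."""
    for i, (p, g) in enumerate(zip(params, grads)):
        g = g + weight_decay * p                          # weight decay added to the gradient
        momentum_buf[i] = momentum * momentum_buf[i] + g  # velocity (buffer starts at zero)
        g = g + momentum * momentum_buf[i]                # Nesterov look-ahead
        params[i] = p - lr * g
    return params


params, buf = [1.0], [0.0]
sgd_nesterov_step(params, [0.0], buf, lr=0.1, momentum=0.9, weight_decay=0.5)
print(params)  # about [0.905]: even a zero gradient shrinks the weight
```

A larger --wdecay, as used for CIFAR-100, pulls weights toward zero more strongly at every step regardless of the data gradient.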

Usage

Train

Train the model on CIFAR-10 with 4000 labeled examples:

```
python train.py --dataset cifar10 --num-labeled 4000 --arch wideresnet --batch-size 64 --lr 0.03 --expand-labels --seed 5 --out results/cifar10@4000.5
```
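--num-labeled 4000 selects a small labeled subset of CIFAR-10, and the remaining images serve as unlabeled data. The sketch below shows one plausible reading of that option together with --expand-labels, under stated assumptions: the subset is class-balanced (num_labeled split evenly across classes), and --expand-labels repeats the labeled indices so the labeled sampler can supply batch_size × eval_steps draws between evaluations. The helper name and the eval_steps parameter are hypothetical, not functions or flags confirmed by this README.

```python
import math
import random


def sample_labeled_indices(labels, num_labeled, num_classes, batch_size, eval_steps,
                           expand_labels=True, seed=5):
    """Pick a class-balanced labeled subset; optionally repeat it to fill an epoch.

    Hypothetical helper: eval_steps is an assumed "batches between evaluations" count.
    """
    rng = random.Random(seed)
    per_class = num_labeled // num_classes
    labeled = []
    for c in range(num_classes):
        class_idx = [i for i, y in enumerate(labels) if y == c]
        labeled.extend(rng.sample(class_idx, per_class))  # without replacement
    if expand_labels:
        # Repeat the subset so that batch_size * eval_steps samples are available.
        repeats = math.ceil(batch_size * eval_steps / len(labeled))
        labeled = labeled * repeats
    rng.shuffle(labeled)
    return labeled


toy_labels = [i % 10 for i in range(1000)]  # 100 examples per class
idx = sample_labeled_indices(toy_labels, num_labeled=40, num_classes=10,
                             batch_size=64, eval_steps=10)
print(len(idx), len(set(idx)))  # 640 40
```

Repeating indices changes only how often each labeled image is revisited per period; the number of distinct labeled images stays at num_labeled.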

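Train the model on CIFAR-100 with 10000 labeled examples using DistributedDataParallel with 4 processes (one per GPU). --batch-size applies per process, so 4 × 16 gives the same total labeled batch of 64 as the single-GPU command: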

```
python -m torch.distributed.launch --nproc_per_node 4 ./train.py --dataset cifar100 --num-labeled 10000 --arch wideresnet --batch-size 16 --lr 0.03 --wdecay 0.001 --expand-labels --seed 5 --out results/cifar100@10000
```
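Both training commands use --lr 0.03, the base learning rate in the FixMatch paper, which decays it with the cosine schedule η·cos(7πk/(16K)) for step k out of K total steps (K = 2^20 in the paper). A minimal sketch of that schedule with an optional linear warmup; the function name and the warmup parameter are assumptions, and this README does not state exactly how train.py builds its scheduler.

```python
import math


def fixmatch_cosine_lr(step, total_steps, base_lr=0.03, warmup_steps=0):
    """Optional linear warmup, then base_lr * cos(7*pi*k / (16*K)) over the remaining steps."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * math.cos(7.0 * math.pi * progress / 16.0)


total = 2 ** 20
for k in (0, total // 2, total):
    print(k, round(fixmatch_cosine_lr(k, total), 5))
```

Because the cosine argument stops at 7π/16 rather than π/2, the rate decays from 0.03 to about 0.0059 and never reaches zero.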

Monitoring training progress

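```
tensorboard --logdir=<your out_dir>
```

Here <your out_dir> is the directory passed to --out (for example, results/cifar10@4000.5).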

Requirements



Citations

```
@misc{jd2020fixmatch,
  author = {Jungdae Kim},
  title = {PyTorch implementation of FixMatch},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/kekmodel/FixMatch-pytorch}}
}
```