Unofficial PyTorch code for "FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence," NeurIPS'20.
This implementation reproduces the CIFAR10 and CIFAR100 results reported in the paper.
It also includes models trained in both semi-supervised and fully supervised manners (download links below).
In addition to the semi-supervised results from the paper, we attach extra results for fully supervised learning (50000 labels, sup only) and fully supervised learning with consistency regularization (50000 labels, sup + consistency).
Consistency regularization improves classification accuracy even when all labels are provided.
Evaluation uses an EMA (exponential moving average) of the model weights along the SGD training trajectory.
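For reference, here is a minimal sketch of the EMA update used for evaluation (the function name and the decay value of 0.999, the paper's default, are assumptions; the repo's EMA implementation may differ in detail):

```python
import torch

@torch.no_grad()
def update_ema(model: torch.nn.Module, ema_model: torch.nn.Module, decay: float = 0.999) -> None:
    # Keep an exponential moving average of the training weights; evaluation
    # is done with ema_model instead of the raw SGD weights.
    for p, ema_p in zip(model.parameters(), ema_model.parameters()):
        ema_p.mul_(decay).add_(p.detach(), alpha=1.0 - decay)
    # BN buffers (running mean/var) are simply copied from the current model.
    for b, ema_b in zip(model.buffers(), ema_model.buffers()):
        ema_b.copy_(b)
```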
CIFAR10

| #Labels | 40 | 250 | 4000 | sup + consistency | sup only |
| --- | --- | --- | --- | --- | --- |
| Paper (RA) | 86.19 ± 3.37 | 94.93 ± 0.65 | 95.74 ± 0.05 | - | - |
| kekmodel | - | - | 94.72 | - | - |
| valencebond | 89.63 (85.65) | 93.08 | 94.72 | - | - |
| Ours | 87.11 | 94.61 | 95.62 | 96.86 | 94.98 |
| Trained Models | checkpoint | checkpoint | checkpoint | checkpoint | checkpoint |
CIFAR100

| #Labels | 400 | 2500 | 10000 | sup + consistency | sup only |
| --- | --- | --- | --- | --- | --- |
| Paper (RA) | 51.15 ± 1.75 | 71.71 ± 0.11 | 77.40 ± 0.12 | - | - |
| kekmodel | - | - | - | - | - |
| valencebond | 53.74 | 67.3169 | 73.26 | - | - |
| Ours | 48.96 | 71.50 | 78.27 | 83.86 | 80.57 |
| Trained Models | checkpoint | checkpoint | checkpoint | checkpoint | checkpoint |
In the case of CIFAR100@400, our result does not reach the paper's result and lies outside the reported confidence interval.
That said, accuracy with a small number of labels depends strongly on which labels are selected and on other hyperparameters.
For example, we find that changing the momentum of batch normalization can give better results, closer to the reported accuracies.
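As an illustration, BN momentum can be adjusted on an existing model like this (the value 0.001 is only a placeholder, not the setting used for the reported numbers):

```python
import torch.nn as nn

def set_bn_momentum(model: nn.Module, momentum: float = 0.001) -> None:
    # Overwrite the momentum of every BatchNorm layer in the network.
    for m in model.modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm):
            m.momentum = momentum
```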
Here we attach Google Drive links that include the training logs and the trained models.
Because of Google Drive's security checks,
downloading the checkpoints in the result tables with curl/wget may fail.
In that case, use gdown to download them without issues.
All checkpoints are included in this directory.
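For example, a checkpoint archive can be fetched with gdown's Python API (the file id and output paths below are placeholders):

```python
import zipfile
import gdown

# Replace FILE_ID with the id from the Google Drive link in the tables above.
url = "https://drive.google.com/uc?id=FILE_ID"
gdown.download(url, "checkpoints.zip", quiet=False)
with zipfile.ZipFile("checkpoints.zip") as zf:
    zf.extractall("saved_models")  # unzip into your own path
```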
After unzipping the checkpoints into your own path, you can run
python eval.py --load_path saved_models/cifar10_400/model_best.pth --dataset cifar10 --num_classes 10
For detailed explanations of the arguments, see here.
Trained models are saved at `os.path.join(args.save_dir, args.save_name)` after a new directory is created. If the path already exists, the code raises an error to prevent overwriting trained models by mistake. If you want to overwrite the files, add `--overwrite`.
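The behavior described above amounts to a check like the following sketch (the function name is illustrative, not the repo's code):

```python
import os

def prepare_save_path(save_dir: str, save_name: str, overwrite: bool = False) -> str:
    save_path = os.path.join(save_dir, save_name)
    if os.path.exists(save_path) and not overwrite:
        # Refuse to reuse an existing directory unless --overwrite is given.
        raise Exception(f"{save_path} already exists; pass --overwrite to reuse it")
    os.makedirs(save_path, exist_ok=True)
    return save_path
```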
If you want to use soft (sharpened) pseudo-labels instead of hard ones, set `--hard_label False`. You can also adjust the sharpening temperature with `--T [YOUR_OWN_VALUE]`.
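A minimal sketch of how the two options differ, following the standard FixMatch recipe (the function and argument names are illustrative, not the repo's exact code):

```python
import torch

def make_pseudo_label(logits_w: torch.Tensor, p_cutoff: float = 0.95,
                      hard_label: bool = True, T: float = 0.5):
    # logits_w: outputs on weakly augmented unlabeled images.
    probs = torch.softmax(logits_w.detach(), dim=-1)
    max_probs, targets = probs.max(dim=-1)
    mask = (max_probs >= p_cutoff).float()  # keep only confident predictions
    if hard_label:
        return targets, mask                # hard pseudo-labels (argmax classes)
    # soft pseudo-labels: sharpen the predicted distribution with temperature T
    return torch.softmax(logits_w.detach() / T, dim=-1), mask
```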
To resume training, use `--resume --load_path [YOUR_CHECKPOINT_PATH]`. The checkpoint is then loaded into the model, and training continues from the iteration at which it stopped (see here and the related method).
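Conceptually, resuming looks like the following sketch (the checkpoint key names are assumptions; check the repo's save/load methods for the actual ones):

```python
import torch
from torch import nn, optim

def resume_from(path: str, model: nn.Module, optimizer: optim.Optimizer) -> int:
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])            # assumed key name
    optimizer.load_state_dict(ckpt["optimizer"])    # assumed key name
    return ckpt.get("it", 0)                        # iteration to continue from
```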
The default `DataLoader` settings assume distributed training with a single node having V100 GPUs x 4.
To change the confidence threshold for pseudo-labeling, use `--p_cutoff`.
We apply `torch.distributed.all_reduce` to the BN buffers before this line.
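Synchronizing BN buffers across processes looks roughly like this sketch (averaging by world size is an assumption about the reduction used):

```python
import torch.distributed as dist
import torch.nn as nn

def sync_bn_buffers(model: nn.Module, world_size: int) -> None:
    # Sum the running statistics over all processes, then average them so that
    # every replica evaluates with identical BatchNorm buffers.
    for m in model.modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm):
            dist.all_reduce(m.running_mean, op=dist.ReduceOp.SUM)
            dist.all_reduce(m.running_var, op=dist.ReduceOp.SUM)
            m.running_mean /= world_size
            m.running_var /= world_size
```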
To use a single GPU:

python train.py --rank 0 --gpu [0/1/...] @@@other args@@@
To use multiple GPUs on a single node (DataParallel):

python train.py --world-size 1 --rank 0 @@@other args@@@
With V100x4 GPUs, CIFAR10 training takes about 16 hours (0.7 days), and CIFAR100 training takes about 62 hours (2.6 days).
For distributed training:

single node
python train.py --world-size 1 --rank 0 --multiprocessing-distributed @@@other args@@@
multiple nodes (assuming two nodes)
# at node 0
python train.py --world-size 2 --rank 0 --dist_url [rank 0's url] --multiprocessing-distributed @@@@other args@@@@
# at node 1
python train.py --world-size 2 --rank 1 --dist_url [rank 0's url] --multiprocessing-distributed @@@@other args@@@@
Train example (CIFAR10 with 4000 labels):

python train.py --world-size 1 --rank 0 --multiprocessing-distributed --num_labels 4000 --save_name cifar10_4000 --dataset cifar10 --num_classes 10
Train example (CIFAR100 with 10000 labels):

python train.py --world-size 1 --rank 0 --multiprocessing-distributed --num_labels 10000 --save_name cifar100_10000 --dataset cifar100 --num_classes 100 --widen_factor 8 --weight_decay 0.001
To reproduce the results on CIFAR100, the widen factor has to be increased with `--widen_factor=8` (see this issue in the official repo), and `--weight_decay=0.001` has to be used.
In this repo, we use WideResNet with LeakyReLU activations, implemented in models/net/wrn.py.
When you use the WideResNet, you can change `widen_factor`, `leaky_slope`, and `dropRate` through the corresponding arguments.
For example, if you want to use ReLU, just set `--leaky_slope 0.0` in the arguments.
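To illustrate how these options map onto the network, here is a simplified block in the spirit of models/net/wrn.py (the actual block layout in the repo may differ):

```python
import torch.nn as nn

def conv_block(in_ch: int, out_ch: int, leaky_slope: float = 0.1, drop_rate: float = 0.0) -> nn.Sequential:
    # leaky_slope comes from --leaky_slope; setting it to 0.0 makes the
    # activation behave like a plain ReLU. drop_rate corresponds to dropRate.
    return nn.Sequential(
        nn.BatchNorm2d(in_ch),
        nn.LeakyReLU(negative_slope=leaky_slope, inplace=True),
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
        nn.Dropout(drop_rate),
    )
```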
We also support various backbone networks from `torchvision.models`.
If you want to use another torchvision backbone, set the arguments `--net [MODEL's NAME in torchvision] --net_from_name True`.
When `--net_from_name True` is set, all other model arguments except `--net` are ignored.
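A hedged sketch of how a backbone can be resolved by name from torchvision (the repo's actual builder may differ):

```python
import torchvision.models as models

def build_backbone(net: str = "resnet18", num_classes: int = 10):
    # Look the constructor up by name, e.g. models.resnet18(num_classes=10).
    return models.__dict__[net](num_classes=num_classes)
```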
If you want to use mixed-precision training for speed-up, add `--amp` to the arguments.
We observed that the training time of each iteration is reduced by about 20-30%.
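What `--amp` enables corresponds to the standard torch.cuda.amp pattern, sketched below (the repo's training loop is structured differently):

```python
import torch

scaler = torch.cuda.amp.GradScaler()

def train_step(model, optimizer, x, y, criterion):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():   # run the forward pass in mixed precision
        loss = criterion(model(x), y)
    scaler.scale(loss).backward()     # scale the loss to avoid fp16 underflow
    scaler.step(optimizer)
    scaler.update()
    return loss.item()
```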
We track various metrics, including training accuracy, prefetch & run times, the mask ratio of unlabeled data, and learning rates. See the details here. You can view the metrics in TensorBoard:
tensorboard --logdir=[SAVE PATH] --port=[YOUR PORT]
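For reference, such scalars are typically written with torch.utils.tensorboard; a minimal sketch (the tag names and values below are only examples, not the repo's exact logging keys):

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="saved_models/cifar10_4000")   # same directory passed to --logdir
writer.add_scalar("train/top-1-acc", 0.95, global_step=1000)  # example tag and value
writer.add_scalar("train/mask_ratio", 0.60, global_step=1000)
writer.close()
```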