LP-BNN: official PyTorch implementation for CIFAR-10 and CIFAR-100.
BatchEnsemble: unofficial PyTorch implementation for CIFAR-10 and CIFAR-100.
If you use this code, please cite the following papers:
See the requirements of the CIFAR code. In addition to these requirements, our code needs a GPU with enough memory to hold a large batch. It was implemented and tested on a Tesla V100 GPU, thanks to the Jean Zay cluster. Please install CUDA and PyTorch as explained on the PyTorch web page before using our code.
After cloning the repository, you can train LP-BNN on either cifar10 or cifar100 by running the commands below. For more reliable results, we advise performing several training runs (at least 3), each with its own output directory.
python main_LPBNN.py --dataset [cifar10/cifar100] --dirsave_out LPBNN_C10_T0
python main_LPBNN.py --dataset [cifar10/cifar100] --dirsave_out LPBNN_C10_T1
python main_LPBNN.py --dataset [cifar10/cifar100] --dirsave_out LPBNN_C10_T2
Likewise, you can train BatchEnsemble on either cifar10 or cifar100 by running the commands below. Here too we advise performing several training runs (at least 3), each with its own output directory.
python main_BatchEnsemble.py --dataset [cifar10/cifar100] --dirsave_out BE_C10_T0
python main_BatchEnsemble.py --dataset [cifar10/cifar100] --dirsave_out BE_C10_T1
python main_BatchEnsemble.py --dataset [cifar10/cifar100] --dirsave_out BE_C10_T2
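The repeated training commands above differ only in the run index appended to the output directory, so they can be generated rather than typed by hand. The following is a minimal sketch, not part of the repository: the helper name `build_train_commands` is hypothetical, and it assumes only the `--dataset` and `--dirsave_out` flags shown above.

```python
import shlex


def build_train_commands(script, dataset, prefix, n_runs=3):
    """Return one training command (as an argv list) per run.

    `script` is main_LPBNN.py or main_BatchEnsemble.py, and `prefix` is the
    output directory name without the run index (for example "LPBNN_C10_T").
    This helper is illustrative only; it is not shipped with the code.
    """
    if dataset not in ("cifar10", "cifar100"):
        raise ValueError(f"unknown dataset: {dataset!r}")
    return [
        ["python", script, "--dataset", dataset, "--dirsave_out", f"{prefix}{run}"]
        for run in range(n_runs)
    ]


if __name__ == "__main__":
    for cmd in build_train_commands("main_LPBNN.py", "cifar10", "LPBNN_C10_T"):
        # Print the command; pass `cmd` to subprocess.run(cmd, check=True)
        # instead if you want to launch the runs one after another.
        print(shlex.join(cmd))
```

Printing the commands first makes it easy to check the directory names before committing GPU time to three full runs.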
Here are the command lines to evaluate on CIFAR-10. For CIFAR-100, please adapt the commands accordingly.
python evaluate_uncertainty.py --algo 'BE' --dataset cifar10 --dirsave_out './checkpoint/cifar10/BE_C10_T'
python evaluate_uncertainty.py --algo 'LPBNN' --dataset cifar10 --dirsave_out './checkpoint/cifar10/LPBNN_C10_T'
Hyper-parameter | CIFAR-10 | CIFAR-100 |
---|---|---|
Ensemble size J | 4 | 4 |
initial learning rate | 0.1 | 0.1 |
batch size | 128 | 128 |
lr decay ratio | 0.1 | 0.1 |
lr decay epochs | 80, 160, 200 | 80, 160, 200 |
cutout | True | True |
SyncEnsemble BN | False | False |
Size of the latent space | 32 | 32 |
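To make the learning-rate rows of the table concrete, here is a small sketch of a step schedule with the listed values (initial rate 0.1, decay ratio 0.1, decay epochs 80, 160 and 200). It assumes the rate is multiplied by the decay ratio once at each listed epoch, which is the behaviour of `torch.optim.lr_scheduler.MultiStepLR`; whether the training scripts use exactly this scheduler is an assumption, and the function name `learning_rate` is hypothetical.

```python
def learning_rate(epoch, base_lr=0.1, milestones=(80, 160, 200), gamma=0.1):
    """Learning rate in effect at `epoch` (0-indexed) under a step schedule.

    The rate is multiplied by `gamma` once for every milestone that has
    already been reached. Default values are taken from the table above.
    """
    reached = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** reached


if __name__ == "__main__":
    for epoch in (0, 79, 80, 160, 200):
        print(epoch, learning_rate(epoch))
```

With these defaults the rate stays at 0.1 until epoch 79, then drops by a factor of ten at each of the three milestones.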
The table below reports test-set results for models trained on the CIFAR-10 dataset.
Accuracy is averaged over 3 runs.
Network | Accuracy (%) | AUC | AUPR | FPR-95-TPR | ECE (%) | cA (%) | cE (%) |
---|---|---|---|---|---|---|---|
BatchEnsemble | 96.48 | 0.9540 | 0.9731 | 0.132 | 0.0167 | 47.44 | 0.2909 |
LP-BNN | 94.76 | 0.9670 | 0.9812 | 0.104 | 0.0148 | 69.92 | 0.2421 |
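For readers unfamiliar with the ECE column, the sketch below shows one common way to compute the expected calibration error: predictions are grouped into equal-width confidence bins, and the gap between mean confidence and accuracy in each bin is averaged, weighted by bin size. This is a generic illustration, not the code in `evaluate_uncertainty.py`; the function name, the choice of 15 bins and the half-open binning convention are all assumptions.

```python
def expected_calibration_error(confidences, correct, n_bins=15):
    """Expected calibration error with equal-width confidence bins.

    `confidences` holds the predicted probability of the chosen class for
    each sample (values in [0, 1]); `correct` holds 1 or 0 (or True/False)
    for whether that prediction was right. Returns a value in [0, 1].
    """
    n = len(confidences)
    if n == 0:
        raise ValueError("need at least one prediction")
    bins = [[] for _ in range(n_bins)]
    for conf, ok in zip(confidences, correct):
        # Bin i covers [i / n_bins, (i + 1) / n_bins); conf == 1.0 goes in the last bin.
        idx = min(int(conf * n_bins), n_bins - 1)
        bins[idx].append((conf, ok))
    ece = 0.0
    for members in bins:
        if not members:
            continue
        mean_conf = sum(c for c, _ in members) / len(members)
        accuracy = sum(1 for _, ok in members if ok) / len(members)
        ece += len(members) / n * abs(mean_conf - accuracy)
    return ece


if __name__ == "__main__":
    # Four predictions at 90% confidence, three of them right.
    print(expected_calibration_error([0.9, 0.9, 0.9, 0.9], [1, 1, 1, 0]))
```

A lower value means the predicted confidences track the observed accuracy more closely.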
If you are interested in the corrupted accuracy (cA) and the corrupted expected calibration error (cE), please download the dataset from CIFAR-10-C.
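As a rough guide to working with the corrupted data, the sketch below assumes the usual CIFAR-10-C release layout: one array per corruption type holding 50,000 images, stored as five consecutive blocks of 10,000 test images for severities 1 to 5. It also assumes that the corrupted accuracy is a plain mean over corruption types. Both helper names are hypothetical, and the evaluation script in this repository may aggregate differently.

```python
IMAGES_PER_SEVERITY = 10_000  # assumed size of each severity block


def severity_slice(severity):
    """Slice selecting the images of one severity level (1 to 5)."""
    if not 1 <= severity <= 5:
        raise ValueError("severity must be between 1 and 5")
    start = (severity - 1) * IMAGES_PER_SEVERITY
    return slice(start, start + IMAGES_PER_SEVERITY)


def mean_corrupted_accuracy(accuracy_by_corruption):
    """Average a {corruption name: accuracy} mapping into a single score."""
    if not accuracy_by_corruption:
        raise ValueError("no corruption results given")
    return sum(accuracy_by_corruption.values()) / len(accuracy_by_corruption)


if __name__ == "__main__":
    # With an array `images` loaded for one corruption type, the images at
    # severity 3 would be images[severity_slice(3)].
    print(severity_slice(3))
    # Made-up numbers, used only to show the call.
    print(mean_corrupted_accuracy({"corruption_a": 50.0, "corruption_b": 70.0}))
```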