kumar-shridhar / PyTorch-BayesianCNN

Bayesian Convolutional Neural Network with Variational Inference based on Bayes by Backprop in PyTorch.
MIT License
1.42k stars · 323 forks

Replicate paper results #43

Open maxstrobel opened 4 years ago

maxstrobel commented 4 years ago

Hi,

thanks for this nice work, I really appreciate it! I tried to replicate the results from your paper with the repository, but I have not succeeded.

First, I downloaded your repo and the datasets. Then I adapted the configuration for the Bayesian Networks:

############### Configuration file for Bayesian ###############
n_epochs = 100
lr_start = 0.001
num_workers = 4
valid_size = 0.2
batch_size = 256
train_ens = 10
valid_ens = 10

Finally, I ran the training script with main_bayesian.py --net_type alexnet --dataset CIFAR10, but the network could not exceed a validation accuracy of around 58%:

Epoch: 20   Training Loss: 2502238.8933     Training Accuracy: 0.5926   Validation Loss: 23935792.3000  Validation Accuracy: 0.5635     train_kl_div: 2445767.9315
Validation loss decreased (25103334.600000 --> 23935792.300000).  Saving model ...
Epoch: 21   Training Loss: 2387042.4522     Training Accuracy: 0.5968   Validation Loss: 22808838.9500  Validation Accuracy: 0.5584     train_kl_div: 2330986.6338
Validation loss decreased (23935792.300000 --> 22808838.950000).  Saving model ...
Epoch: 22   Training Loss: 2274617.4682     Training Accuracy: 0.6079   Validation Loss: 21713194.5000  Validation Accuracy: 0.5725     train_kl_div: 2219840.2803
Validation loss decreased (22808838.950000 --> 21713194.500000).  Saving model ...
Epoch: 23   Training Loss: 2166076.8439     Training Accuracy: 0.6137   Validation Loss: 20656232.4500  Validation Accuracy: 0.5872     train_kl_div: 2112406.7866
Validation loss decreased (21713194.500000 --> 20656232.450000).  Saving model ...
Epoch: 24   Training Loss: 2061628.5510     Training Accuracy: 0.6183   Validation Loss: 19644658.9000  Validation Accuracy: 0.5701     train_kl_div: 2008751.9745
Validation loss decreased (20656232.450000 --> 19644658.900000).  Saving model ...
Epoch: 25   Training Loss: 1961447.8232     Training Accuracy: 0.6230   Validation Loss: 18666660.6500  Validation Accuracy: 0.5711     train_kl_div: 1908918.2803
Validation loss decreased (19644658.900000 --> 18666660.650000).  Saving model ...
Epoch: 26   Training Loss: 1864639.9626     Training Accuracy: 0.6289   Validation Loss: 17726859.6500  Validation Accuracy: 0.5758     train_kl_div: 1812952.8240
Validation loss decreased (18666660.650000 --> 17726859.650000).  Saving model ...
Epoch: 27   Training Loss: 1771119.3846     Training Accuracy: 0.6386   Validation Loss: 16825135.8500  Validation Accuracy: 0.5862     train_kl_div: 1720880.1863
Validation loss decreased (17726859.650000 --> 16825135.850000).  Saving model ...
Epoch: 28   Training Loss: 1682560.0645     Training Accuracy: 0.6406   Validation Loss: 15963687.3750  Validation Accuracy: 0.5892     train_kl_div: 1632709.2596
Validation loss decreased (16825135.850000 --> 15963687.375000).  Saving model ...
Epoch: 29   Training Loss: 1597318.2373     Training Accuracy: 0.6459   Validation Loss: 15150667.4000  Validation Accuracy: 0.5615     train_kl_div: 1548427.4435
Validation loss decreased (15963687.375000 --> 15150667.400000).  Saving model ...
Epoch: 30   Training Loss: 1516623.9817     Training Accuracy: 0.6498   Validation Loss: 14361168.3500  Validation Accuracy: 0.5829     train_kl_div: 1467998.1879
Validation loss decreased (15150667.400000 --> 14361168.350000).  Saving model ...
Epoch: 31   Training Loss: 1439714.2970     Training Accuracy: 0.6520   Validation Loss: 13613963.9500  Validation Accuracy: 0.5829     train_kl_div: 1391386.5470
Validation loss decreased (14361168.350000 --> 13613963.950000).  Saving model ...
Epoch: 32   Training Loss: 1366105.2030     Training Accuracy: 0.6600   Validation Loss: 12909336.8000  Validation Accuracy: 0.5755     train_kl_div: 1318524.9443
Validation loss decreased (13613963.950000 --> 12909336.800000).  Saving model ...
Epoch: 33   Training Loss: 1296600.1863     Training Accuracy: 0.6617   Validation Loss: 12236651.6000  Validation Accuracy: 0.5815     train_kl_div: 1249338.7006
Validation loss decreased (12909336.800000 --> 12236651.600000).  Saving model ...
Epoch: 34   Training Loss: 1230397.9889     Training Accuracy: 0.6638   Validation Loss: 11600143.9500  Validation Accuracy: 0.5893     train_kl_div: 1183742.5000
Validation loss decreased (12236651.600000 --> 11600143.950000).  Saving model ...
Epoch: 35   Training Loss: 1168005.8073     Training Accuracy: 0.6705   Validation Loss: 11004782.7250  Validation Accuracy: 0.5683     train_kl_div: 1121634.4037
Validation loss decreased (11600143.950000 --> 11004782.725000).  Saving model ...
Epoch: 36   Training Loss: 1109223.4610     Training Accuracy: 0.6687   Validation Loss: 10435876.3750  Validation Accuracy: 0.5749     train_kl_div: 1062898.7377
Validation loss decreased (11004782.725000 --> 10435876.375000).  Saving model ...
Epoch: 37   Training Loss: 1053834.6206     Training Accuracy: 0.6691   Validation Loss: 9895180.6000   Validation Accuracy: 0.5803     train_kl_div: 1007417.1760
Validation loss decreased (10435876.375000 --> 9895180.600000).  Saving model ...
Epoch: 38   Training Loss: 1001452.7830     Training Accuracy: 0.6708   Validation Loss: 9391186.8750   Validation Accuracy: 0.5642     train_kl_div: 955054.9248
Validation loss decreased (9895180.600000 --> 9391186.875000).  Saving model ...
Epoch: 39   Training Loss: 951858.5939  Training Accuracy: 0.6717   Validation Loss: 8913133.8750   Validation Accuracy: 0.5767     train_kl_div: 905697.4590
Validation loss decreased (9391186.875000 --> 8913133.875000).  Saving model ...
Epoch: 40   Training Loss: 905384.9124  Training Accuracy: 0.6734   Validation Loss: 8459427.5000   Validation Accuracy: 0.5760     train_kl_div: 859194.2727
Validation loss decreased (8913133.875000 --> 8459427.500000).  Saving model ...
Epoch: 41   Training Loss: 861708.2651  Training Accuracy: 0.6720   Validation Loss: 8040532.7500   Validation Accuracy: 0.5767     train_kl_div: 815417.2926
Validation loss decreased (8459427.500000 --> 8040532.750000).  Saving model ...
Epoch: 42   Training Loss: 820970.3232  Training Accuracy: 0.6684   Validation Loss: 7639982.5250   Validation Accuracy: 0.5765     train_kl_div: 774222.8085
Validation loss decreased (8040532.750000 --> 7639982.525000).  Saving model ...
Epoch: 43   Training Loss: 782426.8826  Training Accuracy: 0.6698   Validation Loss: 7267536.7375   Validation Accuracy: 0.5746     train_kl_div: 735472.4554

Can you explain how to replicate the results from the paper?

maxstrobel commented 4 years ago

push

dhuruvapriyan commented 4 years ago

I am also facing the same issue.

kumar-shridhar commented 4 years ago

Hi, adding a β weighting term to the KL-divergence part of the loss solves this issue. The current objective is

Loss = NLL + KL

Multiplying the KL term by a value β solves the convergence issue:

Loss = NLL + β * KL

where β is a hyperparameter.
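A minimal sketch of that weighted objective (the function and variable names here are illustrative, not the repo's actual API):

```python
import torch
import torch.nn.functional as F

def elbo_loss(outputs, targets, kl, beta):
    """Negative log-likelihood plus a beta-weighted KL-divergence term."""
    nll = F.cross_entropy(outputs, targets)
    return nll + beta * kl

# toy usage: random logits for a batch of 4 samples over 10 classes
outputs = torch.randn(4, 10)
targets = torch.tensor([0, 1, 2, 3])
kl = torch.tensor(2.5)           # would come from the Bayesian layers
loss = elbo_loss(outputs, targets, kl, beta=1e-6)
```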

We will update the repo soon with a good way to set β.
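One published scheme for setting the KL weight per minibatch, not necessarily the one the repo later adopted, is the weighting from Blundell et al. (2015), "Weight Uncertainty in Neural Networks", where earlier minibatches in an epoch carry more of the KL cost:

```python
def blundell_beta(batch_idx, num_batches):
    """KL weight for 1-based minibatch batch_idx out of num_batches;
    the weights are geometric and sum to 1 over one epoch."""
    return 2 ** (num_batches - batch_idx) / (2 ** num_batches - 1)

M = 10
total = sum(blundell_beta(i, M) for i in range(1, M + 1))  # sums to 1.0
```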

maxstrobel commented 4 years ago

Do you have a rough estimate of when you will update the repo, e.g. days / weeks / months?

Thanks!

kumar-shridhar commented 4 years ago

In a week.

ccpocker commented 4 years ago

Could you explain the parameters train_ens and valid_ens? Do they mean the number of samples?

kumar-shridhar commented 4 years ago

Yes, that's correct.
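In other words, train_ens / valid_ens set how many stochastic forward passes are averaged per batch. A minimal sketch of that Monte Carlo ensembling (ensemble_predict and the toy noisy model are illustrative, not the repo's API):

```python
import torch

def ensemble_predict(model, inputs, num_samples):
    """Average class probabilities over num_samples stochastic forward
    passes; each call to model is assumed to sample fresh weights."""
    probs = torch.stack([torch.softmax(model(inputs), dim=1)
                         for _ in range(num_samples)])
    return probs.mean(dim=0)

# toy stand-in for a Bayesian layer: a linear map plus weight noise
torch.manual_seed(0)
base = torch.nn.Linear(8, 10)
noisy = lambda x: base(x) + 0.1 * torch.randn(x.size(0), 10)

avg = ensemble_predict(noisy, torch.randn(4, 8), num_samples=10)
# each row of avg is still a probability distribution (sums to 1)
```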

nikdn commented 2 years ago

> In a week.

Hello, with the latest version of the code, I still cannot replicate the validation accuracy of over 80 percent on CIFAR-10 stated in the paper (Figure 3). Could you guide me a bit on the configuration and initialization of the parameters?

StevenLauHKHK commented 2 years ago

@kumar-shridhar How can we reproduce the validation accuracy stated in the paper on the CIFAR-10 dataset? The network cannot break through 64% validation accuracy when I use the same settings from your configuration file. Thank you.