wjmaddox / swa_gaussian

Code repo for "A Simple Baseline for Bayesian Uncertainty in Deep Learning"
BSD 2-Clause "Simplified" License

Cannot find key 'n_models' #18

Open LukasMosser opened 4 years ago

LukasMosser commented 4 years ago

Hi @wjmaddox!

I've been trying to reproduce the results for the segmentation experiment and have hit an error I can't fix. I'm using the commands in the README to train a SWAG model and then evaluate it, but evaluation fails with the error below. Any idea what the cause could be?

python eval_ensemble.py --data_path /home/ec2-user/CamVid/ --batch_size 4 --method SWAG --scale=0.5 --loss cross_entropy --N 50 --file ./experiment_swag/checkpoint-1000.pt --save_path ./experiment_swag/output.npz

/home/ec2-user/CamVid/
Preparing model
Loading model ./experiment_swag/checkpoint-1000.pt

Traceback (most recent call last):
  File "eval_ensemble.py", line 146, in <module>
    model.load_state_dict(checkpoint["state_dict"])
  File "/home/ec2-user/swa_gaussian/swag/posteriors/swag.py", line 182, in load_state_dict
    n_models = state_dict["n_models"].item()

KeyError: 'n_models'

wjmaddox commented 4 years ago

Hi,

Just a word of caution here -- we never could really get the segmentation code to reproduce the results in the original Tiramisu paper so I don't know what you'll find there :(...

That being said, it looks like the issue is that you're trying to load a model that does not have an "n_models" buffer in the state dict. How did you train it?

If you're confident that you indeed trained and are attempting to reload a SWAG model, make sure that the n_models buffer in the script is set to what you trained the model with, and add the strict=False flag as in: model.load_state_dict(checkpoint["state_dict"], strict=False)
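One way to check which case applies is to look at the checkpoint's keys before loading it. The snippet below is only a minimal sketch of that check, not code from the repository: the helper name `missing_swag_keys` and the stand-in dictionaries (including the key `conv.weight`) are made up for illustration. The only things taken from the traceback are the `"state_dict"` entry and the `"n_models"` key.

```python
def missing_swag_keys(checkpoint, required=("n_models",)):
    """Return the required state-dict keys that the checkpoint lacks.

    Hypothetical helper for illustration; it is not part of swa_gaussian.
    """
    # eval_ensemble.py reads checkpoint["state_dict"]; fall back to the
    # checkpoint itself in case the weights were saved without that wrapper.
    state_dict = checkpoint.get("state_dict", checkpoint)
    return [key for key in required if key not in state_dict]


# Stand-in dictionaries with made-up contents. With PyTorch installed, a real
# checkpoint could be read with:
#   checkpoint = torch.load("./experiment_swag/checkpoint-1000.pt", map_location="cpu")
plain_checkpoint = {"state_dict": {"conv.weight": [0.0]}}
swag_checkpoint = {"state_dict": {"n_models": 20, "conv.weight": [0.0]}}

print(missing_swag_keys(plain_checkpoint))  # ['n_models']
print(missing_swag_keys(swag_checkpoint))   # []
```

If `n_models` shows up as missing for the real checkpoint, the file most likely was not saved from a SWAG model, and the training command is the place to look.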