TyXe-BDL / TyXe


Errors when trying to load a VariationalBNN #15

Open francescofolino opened 2 years ago

francescofolino commented 2 years ago

Hi all, I'm new to TyXe, and I'm running into an issue when trying to load a previously trained model from disk.

To be more precise, the error is as follows:

raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for VariationalBNN:
Unexpected key(s) in state_dict: net_guide.rnn.weight_ih_l0.locunconstrained etc.

To save the model, I use code like this:

pyro.get_param_store().save(os.path.join(output_dir, "param_store.pt"))
torch.save(model.state_dict(), os.path.join(output_dir, "best_model.pt"))

To load the model (defined as tyxe.VariationalBNN(net, prior, likelihood, guide)), I use:

pyro.clear_param_store()
model.load_state_dict(torch.load(os.path.join(save_model_path, "best_model.pt")))
pyro.get_param_store().load(os.path.join(save_model_path, "param_store.pt"))

Where is the error?

Thank you so much.

hpplyt commented 2 years ago

Hi,

Apologies for the slow response! I think this is just due to the variational parameter attributes being initialized lazily by Pyro. If your (deterministic) network doesn't have any buffers, i.e. only parameters, you shouldn't need to save/load the state dict and the param store should contain everything you need. Otherwise, if you do need to load the state dict, just run a forward pass through your BNN by calling guide_forward with some valid input data to initialize the parameter attributes before loading.

Let me know if neither option resolves the error; in that case I'd need to take a closer look at what's going on :)
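To make the lazy-initialization explanation above concrete, here is a minimal sketch in plain PyTorch. It does not use TyXe or Pyro: `LazyGuide` and its `loc` parameter are hypothetical stand-ins that only mimic a module whose parameters are registered on the first forward pass. It shows why a strict `load_state_dict` reports unexpected keys on a fresh instance, and why running one forward pass first avoids that.

```python
import torch
from torch import nn


class LazyGuide(nn.Module):
    """Hypothetical toy module (not TyXe code): registers its parameter on first call."""

    def __init__(self):
        super().__init__()
        self.initialized = False

    def forward(self, x):
        if not self.initialized:
            # The parameter only exists after the first forward pass.
            self.loc = nn.Parameter(torch.zeros(x.shape[-1]))
            self.initialized = True
        return x + self.loc


# "Trained" instance: the forward pass has run, so state_dict contains "loc".
trained = LazyGuide()
trained(torch.ones(3))
saved = trained.state_dict()

# Fresh instance: no parameters registered yet, so strict loading fails.
fresh = LazyGuide()
try:
    fresh.load_state_dict(saved)
    error_message = ""
except RuntimeError as err:
    error_message = str(err)

# One forward pass creates the attribute; loading then succeeds.
fresh(torch.ones(3))
result = fresh.load_state_dict(saved)
```

In this toy, `error_message` contains the same "Unexpected key(s) in state_dict" wording as the error reported above, and `result` lists no missing or unexpected keys once the forward pass has run.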

Cam-B04 commented 1 year ago

Hi,

I have encountered the same error and I might have found a solution. You need to load the state dict using the .net attribute of your model:

pyro.clear_param_store()
model.net.load_state_dict(torch.load(os.path.join(save_model_path, "best_model.pt")))
pyro.get_param_store().load(os.path.join(save_model_path, "param_store.pt"))

Hope this helps!

francescofolino commented 1 year ago

Thanks Camille, I will try your solution 😉

Cheers, Francesco

Cam-B04 commented 1 year ago

I forgot to mention in my answer that the model also has to be saved through its .net attribute, as follows:

torch.save(model.net.state_dict(), os.path.join(output_dir, "best_model.pt"))
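As a rough illustration of why going through the wrapped network sidesteps the error, the sketch below uses a hypothetical `ToyBNN` wrapper in plain PyTorch (not TyXe's actual class). Its `net` is an ordinary module, while its `net_guide` gains a parameter only after a forward pass. Saving and loading only `net`'s state dict never touches the lazily created guide keys. In the real workflow described in this thread, the variational parameters are restored separately through Pyro's param store, which this toy leaves out.

```python
import os
import tempfile

import torch
from torch import nn


class ToyBNN(nn.Module):
    """Hypothetical wrapper: a plain `net` plus a guide whose parameter appears lazily."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(3, 1)
        self.net_guide = nn.Module()  # starts with no parameters

    def forward(self, x):
        if not hasattr(self.net_guide, "loc"):
            # Registered only on the first forward pass.
            self.net_guide.loc = nn.Parameter(torch.zeros(1))
        return self.net(x) + self.net_guide.loc


trained = ToyBNN()
trained(torch.ones(2, 3))

# The full state dict mixes network keys with lazily created guide keys.
full_keys = list(trained.state_dict().keys())
# The wrapped network's state dict holds only its own keys.
net_keys = list(trained.net.state_dict().keys())

with tempfile.TemporaryDirectory() as output_dir:
    path = os.path.join(output_dir, "best_model.pt")
    torch.save(trained.net.state_dict(), path)

    # A fresh wrapper can load the network weights without any forward pass.
    fresh = ToyBNN()
    fresh.net.load_state_dict(torch.load(path))

weights_match = torch.equal(fresh.net.weight, trained.net.weight)
```

Here `full_keys` includes `net_guide.loc` while `net_keys` does not, which is the reason the network-only file loads cleanly into a freshly constructed wrapper.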

freakontrol commented 1 year ago

Hi, I had the same issue and the solution of @Cam-B04 worked correctly, thank you.