pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

Loaded model cannot reproduce the predictions of the saved model #3137

Open kaltinel opened 2 years ago

kaltinel commented 2 years ago

Hi, I am trying to save and load my model, and I believe there is a problem in my implementation. I train my model and test it on a dataset, then retrieve some diagnostics for its performance (precision, recall, etc.). Then I save the model. Afterwards, I load the model and test it on the exact same test dataset to check whether my saving/loading is correct. I believe it somehow is not, as the diagnostics do not match.

For example, the accuracy metrics for the trained model for three epochs are:

| Precision | Recall | Specificity | FPR  | F1_Score | Accuracy |
|-----------|--------|-------------|------|----------|----------|
| 0.70      | 0.63   | 0.40        | 0.59 | 0.66     | 0.56     |
| 0.68      | 0.65   | 0.31        | 0.68 | 0.66     | 0.55     |
| 0.68      | 0.48   | 0.49        | 0.50 | 0.56     | 0.48 *   |

After I save this model, I load the model, guide, and optimizer in a different script and run Predictive on the same test data. The results from the loaded model are:

| Precision | Recall | Specificity | FPR  | F1_Score | Accuracy |
|-----------|--------|-------------|------|----------|----------|
| 0.70      | 0.51   | 0.51        | 0.48 | 0.59     | 0.51     |

I would expect the last row of the trained model's metrics (marked with *) to match the loaded model's results, since both use the same seed and the same test data. And of course, the loaded model gets no further training; Predictive is run on it only once.

I save my model as:

torch.save({"model" : mymodels.state_dict(), "guide" : mymodels.guide}, path_to_model) # save the model and guide params
pyro.get_param_store().save(path_to_parameters) # save parameters
adam.save(path_to_optim) # save the optimizer

And I load it as:

```python
pyro.clear_param_store()  # clear

# load
saved_model_dict = torch.load(path_to_model)
mymodels.load_state_dict(saved_model_dict["model"])
mymodels.guide = saved_model_dict["guide"]

pyro.get_param_store().load(path_to_parameters)
adam.load(path_to_optim)
svi = SVI(
    mymodels.model,
    config_enumerate(mymodels.guide, "parallel", expand=True),
    adam,
    TraceEnum_ELBO(max_plate_nesting=2, strict_enumeration_warning=False),
)
```

And then I call Predictive:

```python
predictive = Predictive(svi.model, guide=svi.guide, num_samples=args.numsampling)
```

I cannot understand why I see this behaviour and think it may be a bug. I am glad to have your insights. Thank you.

fritzo commented 2 years ago

Hi @kaltinel, just to clarify: can you confirm that you are calling pyro.set_rng_seed(...my_seed...) before running predictive(...) in each case?
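That is, something along these lines in both scripts (with `my_seed` and `test_data` standing in for your values):

```python
pyro.set_rng_seed(my_seed)       # identical seed in the training and testing scripts
samples = predictive(test_data)  # only then draw posterior predictive samples
```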

Also, is there a reason you're loading adam and creating an SVI instance, rather than directly constructing the predictive?

```python
predictive = Predictive(model, guide, num_samples=args.num_samples)
```
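and then calling it directly on your test inputs, e.g. (with `x_test` standing in for your test data):

```python
samples = predictive(x_test)  # dict mapping sample site names to tensors
```
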
kaltinel commented 2 years ago

Hi, thank you for your answer. Yes, the seed was the first possible culprit I thought of, so I can confirm that pyro.set_rng_seed(#) is the same in each case and is run before Predictive().

As for the predictive construction, there is no specific reason for doing it that way. I thought I would save my model, guide, parameters, and optimizer so that I could load everything required. Would you suggest saving like this:

torch.save({"model" : mymodels.state_dict(), "guide" : mymodels.guide}, path_to_model) # save the model and guide params
pyro.get_param_store().save(path_to_parameters) # save parameters

And loading like this:

```python
saved_model_dict = torch.load(path_to_model)
mymodels.load_state_dict(saved_model_dict["model"])
mymodels.guide = saved_model_dict["guide"]
pyro.get_param_store().load(path_to_parameters)
predictive = Predictive(mymodels.model, guide=mymodels.guide, num_samples=args.numsampling)
```

If so, can you explain why? (Is it because the optimizer is not needed for prediction, so there is no need to load it?)

fritzo commented 2 years ago

> Would you suggest

Whatever works for you; I was only looking for a minimal reproducible example and would have expected something like:

```python
torch.save(model, path_to_model)
torch.save(guide, path_to_guide)
pyro.get_param_store().save(path_to_parameters)

pyro.clear_param_store()

model = torch.load(path_to_model)
guide = torch.load(path_to_guide)
pyro.get_param_store().load(path_to_parameters)

predictive = Predictive(model, guide=guide, num_samples=args.numsampling)
```

But again, whatever works for you.
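
(Regarding your parenthetical question: yes, the optimizer state only matters if you want to resume training; Predictive reads only the model, guide, and param store. A rough sketch of the two use cases, reusing names from your snippets:)

```python
# resuming training: the optimizer state matters here
adam.load(path_to_optim)
svi = SVI(mymodels.model, mymodels.guide, adam, TraceEnum_ELBO(max_plate_nesting=2))
svi.step(train_data)  # continues from the saved optimizer state

# prediction only: no optimizer involved
predictive = Predictive(mymodels.model, guide=mymodels.guide, num_samples=1000)
```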

Back to your main problem 🙂 Could you try cranking up the number of samples? That should help distinguish whether the problem is due to random noise or to actually bad parameters. If your results still differ with large num_samples, then we could try to diff the param store before and after, e.g.:

```python
def get_snapshot():
    return {k: v.detach().clone() for k, v in pyro.get_param_store().items()}

snapshot1 = get_snapshot()
# ...save however you like...
pyro.clear_param_store()
del model, guide, svi, predictive  # ensure a fresh environment
# ...load however you like...
snapshot2 = get_snapshot()

# check identity
assert set(snapshot1) == set(snapshot2), "keys differ"
for k, v1 in snapshot1.items():
    v2 = snapshot2[k]
    assert torch.allclose(v1, v2), f"mismatch at key {repr(k)}"
```

If that fails, we could try something similar with snapshots of model.named_parameters() and guide.named_parameters().
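E.g., something like:

```python
def module_snapshot(module):
    # assumes model and guide are torch.nn.Modules
    return {k: v.detach().clone() for k, v in module.named_parameters()}

before = {"model": module_snapshot(model), "guide": module_snapshot(guide)}
# ...save, clear, and load as above...
after = {"model": module_snapshot(model), "guide": module_snapshot(guide)}
for name in before:
    assert set(before[name]) == set(after[name]), f"{name} keys differ"
    for k, v in before[name].items():
        assert torch.allclose(v, after[name][k]), f"{name} mismatch at {repr(k)}"
```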

Thanks for diving into the debugging!

kaltinel commented 2 years ago

Hi, thank you for your detailed answer. I tried increasing the number of samples and made sure the test dataset is, for sure, the same, by making an external variable that parses the test file (apologies for not being able to increase the number of samples further at this time), and the diagnostics are still different from each other...

Then I tried your suggestion of comparing the parameters in the param_store, and the snapshots are the same, which I believe is great! However, it has left me even more confused, as now I really cannot comprehend why I get different Predictive results from the same dataset, the same model and guide, and the same parameters...

I look forward to your feedback, thank you!

fritzo commented 2 years ago

Hmm, I'm not sure what else to try. Is there any way you could make a reproducible example we could look at?

kaltinel commented 2 years ago

Hi, I am sorry for the time lapse between replies; I was trying to solve the issue. (Unfortunately I am unable to share my code due to privacy regulations on the matter.)

However, I may have found the culprit: I saved the model, guide, and parameters in the 'testing' script and compared them with the ones loaded from the training script. The guide and parameters match each other, but the model loaded into the testing script (the one generated during training) is not the same as the one saved from the testing script.

I am puzzled: how can I save/load the guide and parameters exactly, but not the model? I can't see how the model changes during testing.

As I mentioned in my first comment, I use mymodels.state_dict() to save and mymodels.load_state_dict(saved_model_dict['model']) to load.
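
Roughly, the comparison that fails looks like this (a sketch; the paths are from my snippets above):

```python
sd_train = torch.load(path_to_model)["model"]  # state dict saved by the training script
sd_test = mymodels.state_dict()                # state dict re-saved in the testing script

mismatched = [k for k in sd_train
              if k not in sd_test or not torch.equal(sd_train[k], sd_test[k])]
print(mismatched)  # non-empty for the model, while the guide and params match
```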

I look forward to your insights on the matter.

Thank you.

fritzo commented 2 years ago

What are the differences between your original model and the saved-then-loaded model? Do you save any randomly-generated tensors in the model? Have you considered using torch.save() and torch.load() for the whole mymodel (parameters and guide), as in:

```python
torch.save(mymodel, "mymodel.pt")
mymodel2 = torch.load("mymodel.pt")
```
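
and then constructing the predictive from the loaded object, e.g. (assuming `mymodel2` exposes the same model and guide attributes as your original object):

```python
predictive = Predictive(mymodel2.model, guide=mymodel2.guide, num_samples=1000)
```
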
kaltinel commented 2 years ago

Thank you for your reply.

I did save my model with torch.save(), and the issue was, at the end of the day, the seed and the position of the SVI instance.
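For other readers, the loading order that works for me is roughly (a sketch; the paths and the seed value are placeholders):

```python
pyro.clear_param_store()
pyro.set_rng_seed(0)  # set the seed before constructing the SVI / Predictive objects

mymodel = torch.load("mymodel.pt")
pyro.get_param_store().load(path_to_parameters)

predictive = Predictive(mymodel.model, guide=mymodel.guide, num_samples=1000)
```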

Maybe you could update your documentation regarding this? The Pyro documentation on model loading/saving is, to my knowledge, not easy to find, and I learned this from colleagues. I would appreciate more detailed documentation, and I am sure other users would, too :)