jyaacoub / MutDTA

Improving the precision oncology pipeline by providing binding affinity perturbation predictions on a priori identified cancer driver genes.
https://drive.google.com/drive/folders/1mdiA1gf1IjPZNhk79I2cYUu6pwcH0OTD

`CheckpointSaver.best_model_dict` being overwritten! #38

Closed jyaacoub closed 1 year ago

jyaacoub commented 1 year ago

I became skeptical after observing that an already trained model, when retrained, started off with the same val loss as the final output line of the previous training log. After some digging around in the code, I discovered that I had not done a deepcopy for the checkpoints.


See note: https://pytorch.org/tutorials/beginner/saving_loading_models.html#save-load-entire-model

> If you only plan to keep the best-performing model (according to the acquired validation loss), don’t forget that best_model_state = model.state_dict() returns a reference to the state and not its copy! You must serialize best_model_state or use best_model_state = deepcopy(model.state_dict())...

This is a big issue, since the saved checkpoints are effectively being overwritten by the final model's weights, which are likely overfit.
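A minimal sketch of the bug and fix (the `update` method and its signature are assumptions for illustration, not the repo's actual `CheckpointSaver` API): `model.state_dict()` returns references to the live parameter tensors, so a saver that stores it directly will see its "best" snapshot silently mutate as training continues. Deep-copying freezes the snapshot at the moment it was saved.

```python
from copy import deepcopy

import torch

class CheckpointSaver:
    """Keeps the state dict of the best model seen so far (illustrative sketch)."""
    def __init__(self):
        self.best_loss = float('inf')
        self.best_model_dict = None

    def update(self, model, val_loss):
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            # BUG: storing model.state_dict() directly keeps references to the
            # live tensors, so later optimizer steps overwrite this "checkpoint":
            #   self.best_model_dict = model.state_dict()
            # FIX: deep-copy so the snapshot is frozen at this point in training.
            self.best_model_dict = deepcopy(model.state_dict())

model = torch.nn.Linear(2, 1)
saver = CheckpointSaver()
saver.update(model, val_loss=0.5)

# Simulate further training mutating the weights in place.
with torch.no_grad():
    model.weight.add_(1.0)

# With deepcopy, the checkpoint is unchanged; without it, this would fail.
assert not torch.equal(saver.best_model_dict['weight'],
                       model.state_dict()['weight'])
```

Serializing immediately with `torch.save(model.state_dict(), path)` is an equivalent fix, since writing to disk also breaks the reference.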