IntelLabs / matsciml

Open MatSci ML Toolkit is a framework for prototyping and scaling out deep learning models for materials discovery. It supports widely used materials science datasets and is built on top of PyTorch Lightning, the Deep Graph Library, and PyTorch Geometric.
MIT License

Fixing checkpoint mechanism for exponential moving average #259

Closed: laserkelvin closed this 2 months ago

laserkelvin commented 2 months ago

This PR/commit fixes checkpoint saving so that a checkpoint survives the save-and-reload round trip.

The problem is that `ExponentialMovingAverageCallback` creates an `ema_module` attribute that is not normally part of a task. As a result, loading a checkpoint required a lot of manual wrangling and remapping of the `state_dict`.
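To make the mismatch concrete, here is a minimal sketch of the kind of key remapping that was needed. It assumes (this is not confirmed by the PR) that the EMA wrapper stores its averaged copy under `ema_module.module.` and keeps an `ema_module.n_averaged` counter, as `torch.optim.swa_utils.AveragedModel` does. Plain floats stand in for tensors, and `check_keys` and `remap_ema_state` are hypothetical helpers, not matsciml functions.

```python
from typing import Dict, List, Set, Tuple

# Assumed key layout of the EMA wrapper (hypothetical, modeled on AveragedModel).
EMA_PREFIX = "ema_module.module."


def check_keys(state: Dict[str, float], expected: Set[str]) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) keys, similar to what load_state_dict(strict=False) reports."""
    saved_keys = set(state)
    missing = sorted(expected - saved_keys)
    unexpected = sorted(saved_keys - expected)
    return missing, unexpected


def remap_ema_state(state: Dict[str, float], prefix: str = EMA_PREFIX) -> Dict[str, float]:
    """Manual wrangling: promote averaged weights to top-level keys and drop EMA entries."""
    # Keep only the task's own entries first.
    remapped = {k: v for k, v in state.items() if not k.startswith("ema_module.")}
    # Then overwrite them with the averaged copies.
    for key, value in state.items():
        if key.startswith(prefix):
            remapped[key[len(prefix):]] = value
    return remapped


# Hypothetical checkpoint contents after the EMA callback attached `ema_module`.
saved = {
    "encoder.weight": 0.50,
    "head.bias": 0.10,
    "ema_module.module.encoder.weight": 0.45,
    "ema_module.module.head.bias": 0.08,
    "ema_module.n_averaged": 100.0,
}
task_keys = {"encoder.weight", "head.bias"}

print(check_keys(saved, task_keys))
# → ([], ['ema_module.module.encoder.weight', 'ema_module.module.head.bias', 'ema_module.n_averaged'])
print(remap_ema_state(saved))
# → {'encoder.weight': 0.45, 'head.bias': 0.08}
```

A freshly constructed task has no `ema_module`, so every EMA entry shows up as an unexpected key, and the averaged weights only reach the task if someone remaps them by hand.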

This PR addresses the issue by overriding `on_save_checkpoint` in `BaseTaskModule` and `MultiTaskLitModule`: if an `ema_module` exists when the checkpoint is saved, the corresponding `state_dict` tensors are overwritten with the averaged weights.
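The override can be sketched roughly as follows. This is not the PR's code: `TaskSketch` is a hypothetical stand-in for a task module, it only mirrors the Lightning hook signature `on_save_checkpoint(self, checkpoint)` (where `checkpoint["state_dict"]` holds the weights), and it assumes the same `ema_module.module.` key layout as above. Floats stand in for tensors.

```python
from typing import Any, Dict, Optional


class TaskSketch:
    """Hypothetical stand-in for BaseTaskModule / MultiTaskLitModule (not matsciml code)."""

    # Assumed prefix under which the EMA wrapper stores its averaged copy.
    ema_prefix = "ema_module.module."

    def __init__(self, ema_module: Optional[object] = None) -> None:
        # The EMA callback would attach this attribute; None means no EMA is in use.
        self.ema_module = ema_module

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        if getattr(self, "ema_module", None) is None:
            return  # no EMA attached: save the checkpoint unchanged
        state = checkpoint["state_dict"]
        for key in list(state):
            ema_key = self.ema_prefix + key
            # Overwrite each regular task tensor with its averaged counterpart.
            if not key.startswith("ema_module.") and ema_key in state:
                state[key] = state[ema_key]


ckpt = {
    "state_dict": {
        "encoder.weight": 0.50,
        "ema_module.module.encoder.weight": 0.45,
        "ema_module.n_averaged": 100.0,
    }
}
TaskSketch(ema_module=object()).on_save_checkpoint(ckpt)
print(ckpt["state_dict"]["encoder.weight"])  # → 0.45
```

Because the averaged values now sit under the task's own keys, a plain task can pick them up directly at load time. The sketch leaves the `ema_module.*` entries in place; dropping them as well would be a possible variation that the PR description does not specify.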

This will affect training restarts: the saved weights are the averaged ones, not the weights the optimizer was last updating, so resuming from such a checkpoint is discontinuous. The main intention, however, is to make reloading for inference easy.
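A toy calculation shows why a restart jumps. It assumes the standard EMA update `ema <- decay * ema + (1 - decay) * w`, seeded with the initial weights; the actual decay and update used by `ExponentialMovingAverageCallback` are not stated here, and `ema_trajectory` is a hypothetical helper.

```python
from typing import List


def ema_trajectory(weights: List[float], decay: float) -> List[float]:
    """Apply ema <- decay * ema + (1 - decay) * w over a sequence of raw weights."""
    ema = weights[0]  # seed the average with the initial weight
    history = [ema]
    for w in weights[1:]:
        ema = decay * ema + (1 - decay) * w
        history.append(ema)
    return history


raw = [1.0, 0.0, 0.0]  # raw weight after each optimizer step
averaged = ema_trajectory(raw, decay=0.5)
print(averaged)  # → [1.0, 0.5, 0.25]
```

The optimizer last stepped from a raw weight of 0.0, but the checkpoint would now store 0.25, so a resumed run starts from a different point than where training stopped. For inference that is the desired behavior.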