erwallace closed this issue 1 month ago
Hey! Thanks for reaching out.
Yes, this is possible with Lightning callbacks. NVIDIA has implemented an EMA callback here. To use it in physicsml, all you need to do is:
1. Replace the pytorch_lightning imports in that file with the new lightning.pytorch ones (they're using an older version of Lightning).
2. Register the callback with molflux by doing
from molflux.modelzoo.models.lightning.trainer.callbacks.stock_callbacks import AVAILABLE_CALLBACKS
AVAILABLE_CALLBACKS["ExponentialMovingAverage"] = EMA
You should then be able to define it in the model's callback config. I haven't tested how this works with multiple GPUs, but it looks like it should work fine.
We can make this part of the physicsml package by default, but we need to look at the licensing.
Hope this helps, let me know if you have any more questions!
Best, Ward
Hello,
Is it possible to add the option of using an exponential moving average (EMA) when updating weights during training? This is a feature in the MACE code and seems to be commonly used for NequIP and MACE. When training MACE models, I have found it's not uncommon for the errors in both training and validation to jump and take a long time to recover. I believe EMA would help fix this.
Thanks, Ewan
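For readers unfamiliar with the technique, the idea behind EMA of weights can be sketched in a few lines of plain Python. This is only an illustration of the update rule, not the NVIDIA callback or the MACE implementation; the function name ema_update, the decay value of 0.9, and the toy weight sequence are all assumptions made up for this example.

```python
def ema_update(shadow, params, decay=0.9):
    """Blend the current weights into a running (shadow) average.

    shadow: list of averaged weights from the previous step
    params: list of raw weights after the latest optimizer step
    decay:  fraction of the old average to keep (hypothetical value)
    """
    return [decay * s + (1.0 - decay) * p for s, p in zip(shadow, params)]


# Toy run: a single weight that sits at 1.0 but spikes to 5.0 for one step.
raw_weights = [1.0, 1.0, 5.0, 1.0, 1.0]

shadow = [raw_weights[0]]  # initialise the average from the first weights
history = []
for w in raw_weights[1:]:
    shadow = ema_update(shadow, [w], decay=0.9)
    history.append(shadow[0])

# The averaged weight moves far less than the raw weight does.
print(history)
```

In this toy run the raw weight jumps by 4.0, while the averaged weight stays much closer to its previous value. That damping is why evaluating with EMA weights is expected to smooth out sudden jumps in the validation error.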