fadel / pytorch_ema

Tiny PyTorch library for maintaining a moving average of a collection of parameters.
MIT License

State dict support #6

Closed. Linux-cpp-lisp closed this 3 years ago.

Linux-cpp-lisp commented 3 years ago

This PR adds state_dict()/load_state_dict() methods so that the state of an EMA object can be saved and later restored. This is useful, for example, when restarting training: especially at high decay, carrying the shadow weights through a restart is important both to avoid artifacts in the validation loss and to obtain the best final model.
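To illustrate what saving and restoring EMA state buys you, here is a minimal sketch. It is not the PR's code. The class `SimpleEMA`, its dictionary keys (`decay`, `num_updates`, `shadow_params`), and the `p.add_(0.1)` stand-in for an optimizer step are all hypothetical, chosen only to show a checkpoint round trip across a simulated restart. Check the library's own EMA class for its actual keys.

```python
import io

import torch


class SimpleEMA:
    """Hypothetical minimal EMA over a list of parameters (not the library's class)."""

    def __init__(self, parameters, decay):
        self.decay = decay
        self.num_updates = 0
        # Shadow copies are detached so they never take part in autograd.
        self.shadow_params = [p.detach().clone() for p in parameters]

    @torch.no_grad()
    def update(self, parameters):
        self.num_updates += 1
        for s, p in zip(self.shadow_params, parameters):
            # In place: s <- decay * s + (1 - decay) * p
            s.sub_((1.0 - self.decay) * (s - p))

    def state_dict(self):
        # Hypothetical keys; shadow tensors are stored in parameter order.
        return {
            "decay": self.decay,
            "num_updates": self.num_updates,
            "shadow_params": [s.clone() for s in self.shadow_params],
        }

    def load_state_dict(self, state_dict):
        shadow = state_dict["shadow_params"]
        if len(shadow) != len(self.shadow_params):
            raise ValueError("state dict has a different number of parameters")
        self.decay = state_dict["decay"]
        self.num_updates = state_dict["num_updates"]
        self.shadow_params = [s.clone() for s in shadow]


torch.manual_seed(0)
model = torch.nn.Linear(3, 1)
ema = SimpleEMA(model.parameters(), decay=0.999)
for _ in range(5):
    with torch.no_grad():
        for p in model.parameters():
            p.add_(0.1)  # hypothetical stand-in for an optimizer step
    ema.update(model.parameters())

# Save model and EMA state together in one checkpoint.
buffer = io.BytesIO()
torch.save({"model": model.state_dict(), "ema": ema.state_dict()}, buffer)

# Simulated restart: build fresh objects, then restore both from the checkpoint.
buffer.seek(0)
checkpoint = torch.load(buffer)
model2 = torch.nn.Linear(3, 1)
model2.load_state_dict(checkpoint["model"])
ema2 = SimpleEMA(model2.parameters(), decay=0.999)
ema2.load_state_dict(checkpoint["ema"])
```

Without the `load_state_dict` call, `ema2` would restart its average from the current weights of `model2`, discarding the accumulated history. At a decay like 0.999 that history spans roughly a thousand updates, which is why the restart would otherwise show up in the EMA model's validation loss.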

The implementation is loosely based on state_dict()/load_state_dict() in torch.optim.Optimizer: https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer
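For reference, the `torch.optim.Optimizer` convention cited above refers to parameters by their integer position in the parameter groups, not by tensor identity. That is what lets saved state be loaded into a freshly constructed optimizer over new parameter objects after a restart. The sketch below uses `torch.optim.SGD` with momentum only because a step then creates per-parameter state. The toy parameters `w` and `b` and the loss are illustrative.

```python
import torch

# Two parameters and SGD with momentum, so one step creates a momentum buffer per parameter.
w = torch.nn.Parameter(torch.ones(2))
b = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([w, b], lr=0.1, momentum=0.9)

loss = (w.sum() + b.sum()) ** 2
loss.backward()
opt.step()

sd = opt.state_dict()
# Parameters are referred to by integer position, not by tensor identity.
print(sorted(sd.keys()))                # ['param_groups', 'state']
print(sd["param_groups"][0]["params"])  # [0, 1]
print(sorted(sd["state"].keys()))       # [0, 1]

# The same state can therefore be loaded into a new optimizer over new parameter objects.
w2 = torch.nn.Parameter(torch.ones(2))
b2 = torch.nn.Parameter(torch.zeros(1))
opt2 = torch.optim.SGD([w2, b2], lr=0.1, momentum=0.9)
opt2.load_state_dict(sd)
```

An EMA state dict following the same idea would store shadow tensors in parameter order. Restoring it then only requires passing the same parameters in the same order, not the same tensor objects.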

fadel commented 3 years ago

Thanks again!