fadel / pytorch_ema

Tiny PyTorch library for maintaining a moving average of a collection of parameters.
MIT License
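The kind of update such a library maintains can be sketched as follows. This is a minimal illustration under stated assumptions, not the library's code. The class name `SimpleEMA` and the `decay` parameter are hypothetical, and plain Python lists stand in for PyTorch tensors so the sketch runs without dependencies. Each call to `update()` applies shadow = decay * shadow + (1 - decay) * param element-wise.

```python
class SimpleEMA:
    """Hypothetical minimal EMA tracker; plain lists stand in for tensors."""

    def __init__(self, params, decay=0.99):
        if not 0.0 <= decay <= 1.0:
            raise ValueError("decay must be in [0, 1]")
        self.decay = decay
        # Shadow copy of each parameter, initialised from the current values.
        self.shadow = [list(p) for p in params]

    def update(self, params):
        # shadow <- decay * shadow + (1 - decay) * param, element by element.
        for s, p in zip(self.shadow, params):
            for i, value in enumerate(p):
                s[i] = self.decay * s[i] + (1.0 - self.decay) * value
```

A higher `decay` makes the average move more slowly, so it smooths out more of the noise from individual training steps.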

Add Feature: restore #2

Closed: Zehui-Lin closed this 3 years ago

fadel commented 3 years ago

Hi Zehui-Lin, thanks for your pull request. I will have time to evaluate it next week. For now, here is what I have in mind.

It looks good to me as is, but I am wondering whether we could achieve a similar feature without changing the behavior of copy_to(). My reasoning is that, with large models, users might be concerned about unintentional copies of parameters.

Maybe a store() method that saves the parameters so they can be restored later with restore()?
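The suggested split can be sketched as below. This is a hypothetical illustration, not the code that was merged. `SimpleEMA`, its `backup` attribute, and the exact method signatures are assumptions, and plain Python lists stand in for tensors. The point it shows is that `copy_to()` stays a plain in-place overwrite with no hidden copy, while the extra memory for a backup is spent only when the caller explicitly asks for it via `store()`.

```python
class SimpleEMA:
    """Hypothetical EMA helper; plain lists stand in for parameter tensors."""

    def __init__(self, params, decay=0.99):
        self.decay = decay
        self.shadow = [list(p) for p in params]  # the moving averages
        self.backup = None  # filled only when store() is called explicitly

    def update(self, params):
        # shadow <- decay * shadow + (1 - decay) * param, element by element.
        for s, p in zip(self.shadow, params):
            for i, value in enumerate(p):
                s[i] = self.decay * s[i] + (1.0 - self.decay) * value

    def copy_to(self, params):
        # Overwrite params in place with the averages; nothing is backed up,
        # so copy_to() never allocates a second copy of the model.
        for s, p in zip(self.shadow, params):
            p[:] = s

    def store(self, params):
        # Opt-in snapshot of the current values for a later restore().
        self.backup = [list(p) for p in params]

    def restore(self, params):
        # Put the stored values back and release the snapshot.
        if self.backup is None:
            raise RuntimeError("restore() called without a prior store()")
        for b, p in zip(self.backup, params):
            p[:] = b
        self.backup = None


params = [[1.0, 2.0]]
ema = SimpleEMA(params, decay=0.5)
params[0][:] = [3.0, 4.0]  # pretend a training step changed the weights
ema.update(params)         # shadow is now [[2.0, 3.0]]
ema.store(params)          # keep the trained weights
ema.copy_to(params)        # evaluate with the averaged weights
print(params)              # [[2.0, 3.0]]
ema.restore(params)        # return to the trained weights
print(params)              # [[3.0, 4.0]]
```

Users who never call `store()` pay nothing extra, which addresses the memory concern for large models.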

Zehui-Lin commented 3 years ago

Good suggestion! 👍