harubaru / waifu-diffusion

stable diffusion finetuned on weeb stuff
GNU Affero General Public License v3.0

Restore non-EMA weights after saving checkpoint #31

Closed: john-sungjin closed this 1 year ago

john-sungjin commented 1 year ago

Right now, whenever `save_checkpoint()` is called, the EMA parameters are copied into the UNet, but the original parameters are never restored. Training therefore continues from the EMA parameters. This is undesirable because the weights being trained then depend on how often a checkpoint is saved (`args.save_steps`).
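A toy reproduction in plain Python makes the failure mode concrete. It is an illustrative assumption, not the repository's training script. It trains a single scalar weight `w` with gradient descent on `(w - target)**2` and keeps an EMA of `w`. The hypothetical `train()` loop's save step copies the EMA value into `w` and never swaps it back.

```python
def train(num_steps, save_steps, lr=0.1, decay=0.9, target=5.0):
    """Toy loop (hypothetical): one scalar weight, SGD on (w - target)**2,
    an EMA of w, and a buggy checkpoint step."""
    w = 0.0
    ema_w = w
    for step in range(1, num_steps + 1):
        grad = 2 * (w - target)      # d/dw of (w - target)**2
        w -= lr * grad               # ordinary training update
        ema_w = decay * ema_w + (1 - decay) * w
        if step % save_steps == 0:
            # Buggy save: the EMA weight is copied into the live weight
            # for the checkpoint and is never swapped back afterwards.
            w = ema_w
    return w


no_saves = train(20, save_steps=1000)   # no save fires within 20 steps
frequent_saves = train(20, save_steps=5)
print(no_saves, frequent_saves)
```

With `save_steps=1000` no save fires within 20 steps, so the two calls differ only in how often the buggy save runs. The run that saves more often ends with a smaller `w`, because each save drags the live weight back to the lagging EMA value.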

Added `store()` and `restore()` methods to `EMAModel` so the original parameters can be saved before the EMA weights are copied in, then put back once the checkpoint is written. These methods are taken from the original CompVis code.
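The store/copy/restore pattern can be sketched as follows. The `EMAModel` here is a hypothetical, dependency-free toy that keeps parameters as floats in a dict, whereas the CompVis implementation works on PyTorch tensors and clones them when storing. The `save_checkpoint()` and `train()` functions are likewise illustrative, not the repository's code.

```python
class EMAModel:
    """Toy EMA tracker over a dict of named float parameters.

    Hypothetical stand-in: real implementations hold torch tensors and
    clone them in store(); plain floats keep this sketch runnable.
    """

    def __init__(self, params, decay=0.9):
        self.decay = decay
        self.shadow = dict(params)      # EMA copy of each parameter
        self.collected_params = None    # backup filled by store()

    def step(self, params):
        # Blend the current weights into the running average.
        for name, value in params.items():
            self.shadow[name] = self.decay * self.shadow[name] + (1 - self.decay) * value

    def copy_to(self, params):
        # Overwrite the live weights with the EMA weights.
        params.update(self.shadow)

    def store(self, params):
        # Back up the live (non-EMA) weights.
        self.collected_params = dict(params)

    def restore(self, params):
        # Put the backed-up weights back in place.
        if self.collected_params is None:
            raise RuntimeError("restore() called without a matching store()")
        params.update(self.collected_params)
        self.collected_params = None


def save_checkpoint(params, ema, saved):
    """Hypothetical save routine: write EMA weights, then restore raw weights."""
    ema.store(params)            # back up the raw training weights
    ema.copy_to(params)          # swap in the EMA weights for the checkpoint
    saved.append(dict(params))   # stand-in for writing the checkpoint to disk
    ema.restore(params)          # training resumes from the raw weights


def train(num_steps, save_steps, lr=0.1, target=5.0):
    """Toy loop: SGD on (w - target)**2 with EMA tracking and periodic saves."""
    params = {"w": 0.0}
    ema = EMAModel(params)
    saved = []
    for step in range(1, num_steps + 1):
        params["w"] -= lr * 2 * (params["w"] - target)
        ema.step(params)
        if step % save_steps == 0:
            save_checkpoint(params, ema, saved)
    return params["w"], saved
```

Because `restore()` puts back exactly what `store()` captured, `train(20, 5)` and `train(20, 1000)` end with the same `w`; only the saved checkpoints contain the EMA values.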