ShivamShrirao / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch
https://huggingface.co/docs/diffusers
Apache License 2.0

Prevent error saving partial weights (mixed precision) #177

Closed: rafaelgm closed this 1 year ago

rafaelgm commented 1 year ago

When using mixed precision and saving weights every N steps, I got this error after the first save step:

RuntimeError: Input type (struct c10::Half) and bias type (float) should be the same

Passing `keep_fp32_wrapper=True` to the two `unwrap_model` calls in `save_weights` seems to fix the issue.
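A likely explanation: under native mixed precision, accelerate wraps the prepared model's `forward` in an autocast wrapper, and `unwrap_model` with the default `keep_fp32_wrapper=False` restores the original `forward` on the *same* module object that is still being trained. After the first save, training therefore runs without autocast, and half-precision inputs meet float32 weights. The sketch below is a toy, dependency-free model of that mechanism, not accelerate's or PyTorch's code: `autocast`, `ToyConv`, `prepare`, `unwrap` and `training_forward_after_save` are hypothetical stand-ins, and dtypes are plain strings.

```python
import functools

_autocast_enabled = False  # toy stand-in for torch's autocast state


def autocast(fn):
    """Toy autocast: enables the flag only while the wrapped call runs."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        global _autocast_enabled
        prev, _autocast_enabled = _autocast_enabled, True
        try:
            return fn(*args, **kwargs)
        finally:
            _autocast_enabled = prev
    return wrapper


class ToyConv:
    """Hypothetical layer with float32 weights and bias."""
    weight_dtype = "float"

    def forward(self, input_dtype):
        # Outside autocast, mismatched dtypes fail like the error in this issue.
        if not _autocast_enabled and input_dtype != self.weight_dtype:
            raise RuntimeError(
                f"Input type ({input_dtype}) and bias type "
                f"({self.weight_dtype}) should be the same"
            )
        return "ok"


def prepare(model):
    # Roughly what prepare does under native AMP: remember the original
    # forward, then replace it on the instance with an autocast-wrapped one.
    model._original_forward = model.forward
    model.forward = autocast(model.forward)
    return model


def unwrap(model, keep_fp32_wrapper=False):
    # Without keep_fp32_wrapper, the original forward is restored IN PLACE,
    # so the module that keeps training loses its autocast wrapper too.
    if not keep_fp32_wrapper:
        original = model.__dict__.pop("_original_forward", None)
        if original is not None:
            model.forward = original
    return model


def training_forward_after_save(keep_fp32_wrapper):
    model = prepare(ToyConv())
    model.forward("Half")  # training steps before the save work under autocast
    unwrap(model, keep_fp32_wrapper=keep_fp32_wrapper)  # the save_weights step
    try:
        return model.forward("Half")  # the next training step
    except RuntimeError as err:
        return str(err)


print(training_forward_after_save(keep_fp32_wrapper=False))
print(training_forward_after_save(keep_fp32_wrapper=True))
```

With `keep_fp32_wrapper=False` the toy reproduces the reported message; with `True` the next step still runs under autocast. In the training script the change corresponds to calls such as `accelerator.unwrap_model(unet, keep_fp32_wrapper=True)` for both unwrapped modules.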