Closed RossM closed 7 months ago
Describe your changes
This patch affects models trained with a weight_dtype of float16 or bfloat16 and with use_ema enabled.
With a typical EMA decay such as 0.9999, each update changes an EMA weight by only about 1e-4 of its value. That delta is below the precision of fp16/bf16, so in those dtypes the update is almost always rounded to zero and the EMA weights never change from their initial values.
The fix is to always keep the EMA model in float32, regardless of the training weight_dtype.
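The rounding failure can be reproduced with a minimal sketch of the standard EMA update rule, ema = decay * ema + (1 - decay) * weight (illustrative NumPy code, not the actual code in this repo). In fp16, decay = 0.9999 itself rounds to exactly 1.0 (fp16 resolves only ~3 decimal digits near 1), so the weight term vanishes and the EMA is frozen; in fp32 the same step moves the EMA as expected:

```python
import numpy as np

def ema_step(ema, weight, decay, dtype):
    """One EMA update carried out entirely in `dtype`."""
    d = dtype(decay)
    return dtype(d * ema + (dtype(1.0) - d) * weight)

decay = 0.9999  # typical EMA decay for diffusion training

# fp16: np.float16(0.9999) rounds to 1.0, so (1 - decay) becomes
# exactly 0 and the current weight never contributes to the EMA.
ema_fp16 = ema_step(np.float16(1.0), np.float16(2.0), decay, np.float16)

# fp32: the ~1e-4 per-step contribution is resolved and the EMA moves.
ema_fp32 = ema_step(np.float32(1.0), np.float32(2.0), decay, np.float32)

print(ema_fp16)  # 1.0 -- EMA weight is stuck
print(ema_fp32)  # ~1.0001 -- EMA weight updates correctly
```

This is why casting only the EMA copy to float32 is sufficient: the trained weights can stay in fp16/bf16, while the EMA accumulation happens at a precision that can represent the small per-step deltas.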
Issue ticket number and link (if applicable)
Checklist before requesting a review