Describe your changes
This affects models where weight_dtype is set to float16 or bfloat16 and use_ema is true.
fp16/bf16 do not have enough mantissa precision to represent a change to a weight on the order of 0.001 times its current value, so almost all EMA weight updates round away to nothing. As a result, the EMA model never actually gets updated from the running weights.
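For illustration, here is a standalone PyTorch repro of the effect (not code from this repository; the decay of 0.999 is an assumed typical value):

```python
import torch

decay = 0.999  # assumption: a typical EMA decay; the exact value doesn't change the outcome

for dtype in (torch.float16, torch.bfloat16, torch.float32):
    ema = torch.full((1000,), 1.0, dtype=dtype)       # stale EMA weights
    weight = torch.full((1000,), 1.001, dtype=dtype)  # running weights, drifted ~0.1%

    # standard EMA step: ema <- decay * ema + (1 - decay) * weight
    updated = decay * ema + (1 - decay) * weight

    changed = (updated != ema).float().mean().item()
    print(f"{dtype}: fraction of EMA entries that moved = {changed:.2f}")
    # fp16/bf16 print 0.00: the net update lands within half an ulp of the
    # old value, so round-to-nearest discards it; float32 prints 1.00
```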
The solution is to always keep the EMA model in float32, regardless of weight_dtype.
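A minimal sketch of that approach in PyTorch — the helper names (make_ema_model, ema_step) are hypothetical and not the actual functions touched by this PR:

```python
import copy
import torch

def make_ema_model(model: torch.nn.Module) -> torch.nn.Module:
    # the EMA copy always lives in float32, even if the
    # training model itself runs in fp16/bf16
    ema_model = copy.deepcopy(model).to(dtype=torch.float32)
    ema_model.requires_grad_(False)
    return ema_model

@torch.no_grad()
def ema_step(ema_model: torch.nn.Module, model: torch.nn.Module, decay: float = 0.999):
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        # upcast the running weight so the accumulation happens in float32;
        # lerp_(end, w) computes (1 - w) * self + w * end
        ema_p.lerp_(p.detach().to(torch.float32), 1.0 - decay)
```

Saving or sampling can still cast the EMA weights back down to weight_dtype if desired; only the accumulation itself needs the extra precision.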
Issue ticket number and link (if applicable)
Checklist before requesting a review
[ + ] This is based on the /dev branch (or a fork of it)
[ + ] This was created or at least validated using a proper IDE
[ + ] I have tested this code and validated any modified functions
[ N/A ] I have added the appropriate documentation and hint strings if adding or changing a user-facing feature