microsoft / tensorflow-directml-plugin

DirectML PluggableDevice plugin for TensorFlow 2
Apache License 2.0

Fix intermediate overflow in BatchNorm ops #321

Closed: PatriceVignola closed this 1 year ago

PatriceVignola commented 1 year ago

When BatchNorm ops are used with the mixed_float16 policy, intermediate values inside the DirectML kernels can overflow, which causes training to fail to converge or to produce NaNs. Until this is resolved, we work around it by always running BatchNorm in float32 precision and converting the result back to float16 when it's done.
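To show why an intermediate can overflow even when the float16 inputs and outputs are in range, here is a minimal, hypothetical sketch in plain Python. The PR does not describe how the DirectML kernels compute their statistics, so the one-pass variance formula E[x²] − E[x]² used here is an assumption chosen for illustration, and the names `to_f16`, `to_f32`, `batchnorm_naive` and `batchnorm_fp32_workaround` are made up for this example. float16 rounding is emulated with the standard `struct` `"e"` format, saturating to infinity on overflow. The workaround mirrors the one described above: upcast to float32, normalize, then downcast the result to float16.

```python
import math
import struct


def to_f16(x: float) -> float:
    """Round to IEEE half precision; overflow saturates to +/-inf (emulated)."""
    if math.isnan(x) or math.isinf(x):
        return x
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:  # |x| rounds past the float16 max of 65504
        return math.copysign(math.inf, x)


def to_f32(x: float) -> float:
    """Round to IEEE single precision (emulated)."""
    if math.isnan(x) or math.isinf(x):
        return x
    return struct.unpack("<f", struct.pack("<f", x))[0]


def batchnorm_naive(xs, rnd, eps=1e-3):
    """Hypothetical BatchNorm (no scale/offset) where every intermediate is
    rounded with `rnd`, using the one-pass variance E[x^2] - E[x]^2."""
    n = len(xs)
    s = 0.0
    sq = 0.0
    for x in xs:
        s = rnd(s + x)
        sq = rnd(sq + rnd(x * x))  # x*x can exceed 65504 in float16
    mean = rnd(s / n)
    var = rnd(rnd(sq / n) - rnd(mean * mean))  # inf - inf -> NaN
    if var < 0.0:  # clamp small negatives from cancellation; NaN passes through
        var = 0.0
    inv_std = rnd(1.0 / math.sqrt(rnd(var + eps)))
    return [rnd(rnd(x - mean) * inv_std) for x in xs]


def batchnorm_fp32_workaround(xs_f16, eps=1e-3):
    """Upcast float16 inputs to float32, normalize there, then downcast."""
    xs32 = [to_f32(x) for x in xs_f16]  # float16 -> float32 is exact
    ys32 = batchnorm_naive(xs32, to_f32, eps)
    return [to_f16(y) for y in ys32]  # normalized outputs fit in float16


if __name__ == "__main__":
    xs = [to_f16(v) for v in (250.0, 260.0, 270.0, 280.0)]
    print(batchnorm_naive(xs, to_f16))      # all NaN
    print(batchnorm_fp32_workaround(xs))    # finite, roughly [-1.34, -0.45, 0.45, 1.34]
```

Here the squares (for example 270² = 72900) exceed the float16 maximum of 65504, so both E[x²] and E[x]² become inf and their difference is NaN, while the same computation in float32 yields a variance of 125. In a TensorFlow graph the same idea conceptually amounts to a `tf.cast` to `tf.float32` before the op and back to `tf.float16` after it.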