google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Avoid using float32 in normalization for mean/var and scale/bias parameters when force_float32_reductions=False #4314

Closed: copybara-service[bot] closed this 4 weeks ago

copybara-service[bot] commented 1 month ago

Avoid using float32 in normalization for mean/var and scale/bias parameters when force_float32_reductions=False
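The function below is a minimal sketch of the behavior the title describes. It is not Flax's implementation. It assumes a hypothetical LayerNorm-style function, `layer_norm_sketch`, whose `force_float32_reductions` argument borrows its name from the PR title. When the flag is `True`, the mean/variance statistics and the scale/bias arithmetic run in float32. When it is `False`, they stay in the input dtype (for example bfloat16).

```python
import jax
import jax.numpy as jnp


def layer_norm_sketch(x, scale, bias, *, eps=1e-6, force_float32_reductions=True):
    """Hypothetical LayerNorm-style normalization over the last axis.

    Illustration only (not Flax code): shows how a force_float32_reductions-style
    flag can pick the dtype used for mean/var and for applying scale/bias.
    Returns the normalized output plus the statistics so their dtypes can be inspected.
    """
    # float32 when reductions are forced to full precision, else the input dtype.
    compute_dtype = jnp.float32 if force_float32_reductions else x.dtype
    xc = x.astype(compute_dtype)

    # Statistics are computed in compute_dtype.
    mean = jnp.mean(xc, axis=-1, keepdims=True)
    var = jnp.mean(jnp.square(xc - mean), axis=-1, keepdims=True)

    # Normalize, then apply scale/bias in the same compute_dtype.
    y = (xc - mean) * jax.lax.rsqrt(var + eps)
    y = y * scale.astype(compute_dtype) + bias.astype(compute_dtype)

    # The output is cast back to the input dtype in both modes.
    return y.astype(x.dtype), mean, var


x = jnp.arange(8, dtype=jnp.bfloat16).reshape(2, 4)
scale = jnp.ones((4,), dtype=jnp.bfloat16)
bias = jnp.zeros((4,), dtype=jnp.bfloat16)

y_f32, mean_f32, _ = layer_norm_sketch(x, scale, bias, force_float32_reductions=True)
y_bf16, mean_bf16, _ = layer_norm_sketch(x, scale, bias, force_float32_reductions=False)
print(mean_f32.dtype, mean_bf16.dtype)
```

The trade-off being sketched: float32 reductions are generally more numerically robust for low-precision inputs, while keeping everything in the input dtype avoids the upcast and the extra float32 intermediates.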