Flax is a neural network library for JAX that is designed for flexibility.
Avoid using float32 in normalization for mean/var and scale/bias parameters when force_float32_reductions=False #4314
Closed: copybara-service[bot] closed this pull request 4 weeks ago.
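The title describes a dtype policy: when `force_float32_reductions=False`, normalization should stop upcasting the mean/variance statistics and the scale/bias parameters to float32. The sketch below illustrates that policy in plain JAX, assuming a last-axis layer norm. `reduction_dtype` and `layer_norm` are hypothetical helpers written for illustration. They are not Flax's implementation or API, and they do not reproduce the actual change in #4314.

```python
import jax
import jax.numpy as jnp


def reduction_dtype(x_dtype, force_float32_reductions=True):
    """Hypothetical helper: dtype used for mean/var and scale/bias.

    With the flag on, low-precision inputs (e.g. bfloat16) are promoted to
    at least float32; with it off, the input dtype is kept as-is.
    """
    if force_float32_reductions:
        return jnp.promote_types(x_dtype, jnp.float32)
    return jnp.dtype(x_dtype)


def layer_norm(x, scale, bias, eps=1e-6, force_float32_reductions=True):
    """Hypothetical last-axis layer norm illustrating the dtype policy."""
    dtype = reduction_dtype(x.dtype, force_float32_reductions)
    xs = x.astype(dtype)
    # Statistics are computed in `dtype`, not unconditionally in float32.
    mean = jnp.mean(xs, axis=-1, keepdims=True)
    var = jnp.mean(jnp.square(xs - mean), axis=-1, keepdims=True)
    y = (xs - mean) * jax.lax.rsqrt(var + eps)
    # Scale/bias are cast to the same dtype before being applied.
    y = y * scale.astype(dtype) + bias.astype(dtype)
    # The output is returned in the caller's input dtype either way.
    return y.astype(x.dtype)


x = jnp.arange(8, dtype=jnp.bfloat16).reshape(2, 4)
y = layer_norm(x, jnp.ones(4), jnp.zeros(4), force_float32_reductions=False)
```

With the flag off, a bfloat16 input keeps every intermediate in bfloat16. This saves memory and conversions but gives up the extra precision that float32 reductions provide.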