google-deepmind / kfac-jax

Second Order Optimization and Curvature Estimation with K-FAC in JAX.
Apache License 2.0
250 stars, 23 forks

Making pmap axis names consistent in examples code to support things like cross-replica batch norm layers. #301

Status: Open. Opened by copybara-service[bot] 4 days ago.

copybara-service[bot] commented 4 days ago

Makes the pmap axis names consistent across the examples code, so that layers that communicate across replicas (such as cross-replica batch norm) are supported.
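
For context, here is a minimal sketch of why the axis name matters. It is not taken from this change. The constant `PMAP_AXIS_NAME` and the functions `cross_replica_batch_norm` and `forward` are hypothetical names chosen for illustration. The sketch assumes a cross-replica batch norm averages its moments over devices with `jax.lax.pmean`, which only resolves if its `axis_name` matches the one given to `jax.pmap`.

```python
import functools

import jax
import jax.numpy as jnp

# Hypothetical shared constant (an assumption, not from the PR): a single
# name used both by pmap and by every collective inside the mapped function.
PMAP_AXIS_NAME = "batch"


def cross_replica_batch_norm(x, axis_name, eps=1e-5):
    """Normalizes x using moments averaged over all devices on `axis_name`."""
    # Per-device first and second moments over the local batch dimension.
    mean = jnp.mean(x, axis=0)
    mean_sq = jnp.mean(jnp.square(x), axis=0)
    # Average the moments across devices; the name must be bound by pmap.
    mean = jax.lax.pmean(mean, axis_name=axis_name)
    mean_sq = jax.lax.pmean(mean_sq, axis_name=axis_name)
    var = mean_sq - jnp.square(mean)
    return (x - mean) * jax.lax.rsqrt(var + eps)


@functools.partial(jax.pmap, axis_name=PMAP_AXIS_NAME)
def forward(x):
    # The same constant is passed down, so the collective finds its axis.
    return cross_replica_batch_norm(x, axis_name=PMAP_AXIS_NAME)


n_dev = jax.local_device_count()
# Leading dimension is the device axis; each device sees a (8, 4) batch.
x = jax.random.normal(jax.random.PRNGKey(0), (n_dev, 8, 4))
y = forward(x)
```

If the layer were given a different name from the one passed to `jax.pmap`, the collective would refer to an axis that is not bound, and tracing would typically fail. Defining the name once and passing it to both places avoids that mismatch.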