Second Order Optimization and Curvature Estimation with K-FAC in JAX.
Making pmap axis names consistent in examples code to support things like cross-replica batch norm layers. #301
Open
copybara-service[bot] opened 4 days ago
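Why the axis names matter: a cross-replica batch norm layer averages its statistics across devices with a collective such as `jax.lax.pmean`. That collective only resolves if its `axis_name` matches the `axis_name` passed to the enclosing `jax.pmap`. If the two names differ, tracing fails with an unbound-axis-name error. The block below is a minimal sketch of the pattern, not the code from this PR. The axis name `"devices"` and the helper `cross_replica_batch_norm` are hypothetical. The sketch also assumes every device holds the same number of examples, so that averaging the per-device means gives the global mean.

```python
import jax
import jax.numpy as jnp

# Hypothetical shared axis name. The actual name chosen in the PR is not shown here.
AXIS_NAME = "devices"


def cross_replica_batch_norm(x, eps=1e-5):
    """Normalizes x with batch statistics pooled across all pmap replicas.

    Assumes every replica has the same local batch size, so the mean of the
    per-replica means equals the global mean.
    """
    # Per-replica statistics over the local batch axis.
    local_mean = jnp.mean(x, axis=0)
    local_sq_mean = jnp.mean(jnp.square(x), axis=0)
    # Pool across replicas. This only works if the enclosing pmap binds AXIS_NAME.
    mean = jax.lax.pmean(local_mean, axis_name=AXIS_NAME)
    sq_mean = jax.lax.pmean(local_sq_mean, axis_name=AXIS_NAME)
    var = sq_mean - jnp.square(mean)
    return (x - mean) * jax.lax.rsqrt(var + eps)


# The pmap must use the same axis name as the collective inside the layer.
normalize = jax.pmap(cross_replica_batch_norm, axis_name=AXIS_NAME)

n = jax.local_device_count()
# Leading axis: one shard per device; then a local batch of 4 with 3 features.
x = jnp.arange(n * 4 * 3, dtype=jnp.float32).reshape(n, 4, 3)
y = normalize(x)
```

Keeping a single name constant, shared by every example's `pmap` call and every layer that issues collectives, is what lets such layers be dropped into any example without further changes.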