Pax is a Jax-based machine learning framework for training large-scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry-leading model flop utilization rates.
[NVIDIA] New collection for variables 'overwrite_with_gradient' #48
This PR supports the new variable collection name 'overwrite_with_gradient'. Such variables are stored in `params`, but their gradients do not go to the optimizers and have no corresponding optimizer states. Instead, their gradients are used as the new variable values in the next step.

There are four related PRs, which should be reviewed in this order:
(1) https://github.com/google/praxis/pull/29
(2) https://github.com/google/paxml/pull/48 (this PR)
(3) https://github.com/google/praxis/pull/28
(4) https://github.com/google/paxml/pull/49
cc. @pjannaty @reedwm @nluehr @lukaszlew