google / paxml

Pax is a JAX-based machine learning framework for training large-scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry-leading model FLOP utilization rates.
Apache License 2.0

[NVIDIA] New collection for variables 'overwrite_with_gradient' #48

Closed: kaixih closed this 1 year ago

kaixih commented 1 year ago

This PR adds support for a new variable collection name: 'overwrite_with_gradient'. Variables in this collection are stored in params, but their gradients are not passed to the optimizers and they have no corresponding optimizer states. Instead, their gradients are used directly as the new variable values in the next step.
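To illustrate the update rule described above, here is a minimal JAX sketch. It is not the paxml or praxis implementation: the names `train_step`, `loss_fn`, and `overwrite_mask`, the boolean-mask representation of the collection, and the use of plain SGD as a stand-in for the optimizer are all assumptions made for this example.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, x):
    # Toy loss (hypothetical): a linear layer followed by a scalar scale.
    return jnp.sum((x @ params["w"]) * params["scale"])


def train_step(params, overwrite_mask, x, lr=0.1):
    """One step; overwrite_mask marks 'overwrite_with_gradient' variables.

    overwrite_mask is a pytree of Python bools with the same structure as
    params (an assumption of this sketch, not the real collection mechanism).
    """
    grads = jax.grad(loss_fn)(params, x)

    def update(p, g, overwrite):
        if overwrite:
            # The gradient bypasses the optimizer and becomes the new value.
            return g
        # Regular variable: plain SGD stands in for the real optimizer.
        return p - lr * g

    return jax.tree_util.tree_map(update, params, grads, overwrite_mask)


# Usage: "scale" is treated as an overwrite-with-gradient variable.
params = {"w": jnp.array([[1.0], [2.0]]), "scale": jnp.array(2.0)}
overwrite_mask = {"w": False, "scale": True}
x = jnp.array([[1.0, 1.0]])
new_params = train_step(params, overwrite_mask, x)
```

In this sketch the overwritten variable never enters the optimizer path, so a stateful optimizer would only need to allocate state for the leaves where the mask is False.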

There are four related PRs, which should be reviewed in this order: (1) https://github.com/google/praxis/pull/29 (2) https://github.com/google/paxml/pull/48 (this PR) (3) https://github.com/google/praxis/pull/28 (4) https://github.com/google/paxml/pull/49

cc. @pjannaty @reedwm @nluehr @lukaszlew