Pax is a JAX-based machine learning framework for training large-scale models. Pax allows for advanced, fully configurable experimentation and parallelization, and has demonstrated industry-leading model FLOPs utilization (MFU) rates.
Apache License 2.0
[NVIDIA] Simplify the unit test for overwrite_with_gradient #55
This pull request simplifies the unit test for the overwrite_with_gradient collection so that it checks a single behavior: variables marked in this collection are replaced with their respective gradients. To that end, we removed the parts of the test that were unrelated to this check.
cc @zhangqiaorjc @lukaszlew
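A minimal sketch of the behavior the simplified test targets, written in plain JAX rather than Pax's actual training loop. The collections are modeled as nested dicts, and the names PARAMS, OWG, loss_fn, and train_step, the toy loss, and the SGD update are all hypothetical. Only the rule under test follows the description above: values in overwrite_with_gradient are set to their gradients instead of being updated by the optimizer.

```python
import jax
import jax.numpy as jnp

# Hypothetical collection keys for this sketch (not Pax's API).
PARAMS = "params"
OWG = "overwrite_with_gradient"


def loss_fn(variables, x):
    # Toy loss that depends on both collections, so both receive gradients.
    w = variables[PARAMS]["w"]
    s = variables[OWG]["s"]
    return jnp.sum(x * w * s)


def train_step(variables, x, lr=0.1):
    # Differentiate with respect to the whole variables pytree.
    grads = jax.grad(loss_fn)(variables, x)
    # Ordinary parameters: a plain SGD step.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - lr * g, variables[PARAMS], grads[PARAMS]
    )
    # overwrite_with_gradient: the new value is the gradient itself.
    new_owg = grads[OWG]
    return {PARAMS: new_params, OWG: new_owg}, grads


variables = {
    PARAMS: {"w": jnp.array([1.0, 2.0])},
    OWG: {"s": jnp.array([3.0, 4.0])},
}
x = jnp.array([0.5, 0.25])
new_vars, grads = train_step(variables, x)
```

A focused test then needs only one assertion: each entry of new_vars[OWG] equals the matching entry of grads[OWG], independent of what the optimizer does to the regular parameters.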