google / paxml

Pax is a JAX-based machine learning framework for training large-scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry-leading model FLOPs utilization rates.
Apache License 2.0
456 stars, 68 forks

[NVIDIA] Simplify the unit test for overwrite_with_gradient #55

Closed: kaixih closed this pull request 8 months ago.

kaixih commented 12 months ago

This pull request simplifies the unit test for the `overwrite_with_gradient` collection so that it checks one behavior: variables marked in this collection are replaced with their respective gradients rather than updated by the optimizer. To keep the test focused on that check, we removed the setup and assertions that were unrelated to it.

cc @zhangqiaorjc @lukaszlew
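The behavior the description narrows the test to can be illustrated with a minimal, self-contained JAX sketch. It is not the test from this PR and does not use Pax's learner or any Praxis API. The collection key `"overwrite_with_gradient"` mirrors the name in the PR, while `loss_fn`, `train_step`, the plain-SGD update, and the shapes are hypothetical stand-ins chosen so the expected values can be worked out by hand.

```python
import jax
import jax.numpy as jnp

# Collection name taken from the PR; everything else here is hypothetical.
OWG = "overwrite_with_gradient"


def loss_fn(variables, x):
    # Toy loss that depends on both a trainable weight and an OWG variable.
    w = variables["params"]["w"]
    s = variables[OWG]["s"]
    return jnp.sum((x @ w) * s)


def train_step(variables, x, lr=0.1):
    # Differentiate w.r.t. every collection; the grads mirror the input tree.
    grads = jax.grad(loss_fn)(variables, x)
    # Ordinary parameters get a plain SGD update (an illustrative choice).
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - lr * g, variables["params"], grads["params"]
    )
    # OWG variables are overwritten with their gradients, not optimized.
    return {"params": new_params, OWG: grads[OWG]}, grads


def test_overwrite_with_gradient():
    x = jnp.ones((2, 3))
    variables = {"params": {"w": jnp.ones((3, 4))}, OWG: {"s": jnp.full((4,), 2.0)}}
    new_vars, grads = train_step(variables, x)
    # Core check: each OWG leaf now equals its gradient.
    for new, g in zip(
        jax.tree_util.tree_leaves(new_vars[OWG]), jax.tree_util.tree_leaves(grads[OWG])
    ):
        assert jnp.allclose(new, g)
    # d(loss)/ds_j = sum_i (x @ w)[i, j] = 2 * 3 = 6 for every j.
    assert jnp.allclose(new_vars[OWG]["s"], 6.0)


if __name__ == "__main__":
    test_overwrite_with_gradient()
```

Keeping a single assertion on the overwritten collection, plus one hand-checkable value, is the kind of narrow scope the PR describes; the SGD update on `params` is only there so the two update paths can be told apart.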