keras-team / keras-core

A multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.
Apache License 2.0
1.27k stars, 115 forks

Update the eager_build logic for jax trainer. #845

Closed. qlzh727 closed this 1 year ago

qlzh727 commented 1 year ago

The current eager_build runs the model's forward pass on eager tensors to trigger variable creation. The forward pass itself consumes a non-trivial amount of memory, which can cause OOM when the model is large and designed for distributed computation.

This PR uses KerasTensor instead, which is symbolic and doesn't consume any memory.

Verified the memory consumption in a local Colab.
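
To illustrate the idea, here is a toy sketch in plain Python. It is not the keras-core implementation: `SymbolicTensor`, `Dense`, and `symbolic_build` are hypothetical stand-ins, and the layer only records the shape of the variable it would create. The point is that variable shapes can be derived from shape/dtype metadata alone, without materializing any activations.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SymbolicTensor:
    """Hypothetical stand-in for a KerasTensor: shape and dtype, no data."""
    shape: tuple
    dtype: str = "float32"


class Dense:
    """Toy layer that only tracks the shape of the kernel it would create."""

    def __init__(self, units):
        self.units = units
        self.kernel_shape = None  # filled in on the first call

    def __call__(self, x):
        if self.kernel_shape is None:
            # Building only needs the last input dimension, not real values.
            self.kernel_shape = (x.shape[-1], self.units)
        # Return the output spec; no activation buffer is allocated.
        return SymbolicTensor(x.shape[:-1] + (self.units,), x.dtype)


def symbolic_build(layers, input_spec):
    """Push a symbolic spec through the layers so each one gets built."""
    x = input_spec
    for layer in layers:
        x = layer(x)
    return x


layers = [Dense(16), Dense(4)]
output_spec = symbolic_build(layers, SymbolicTensor((None, 8)))
print(output_spec.shape, [layer.kernel_shape for layer in layers])
```

With an eager tensor in place of the spec, every intermediate activation would be allocated just to discover these shapes.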

codecov[bot] commented 1 year ago

Codecov Report

Patch coverage: 100.00% and project coverage change: +5.07% :tada:

Comparison is base (67722d7) 70.92% compared to head (d32e52e) 76.00%. Report is 2 commits behind head on main.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #845      +/-   ##
==========================================
+ Coverage   70.92%   76.00%   +5.07%
==========================================
  Files         344      328      -16
  Lines       33325    31103    -2222
  Branches     6417     6052     -365
==========================================
+ Hits        23635    23639       +4
+ Misses       8092     5866    -2226
  Partials     1598     1598
```

| [Flag](https://app.codecov.io/gh/keras-team/keras-core/pull/845/flags?src=pr&el=flags&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | Coverage Δ | |
|---|---|---|
| [keras_core](https://app.codecov.io/gh/keras-team/keras-core/pull/845/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `75.91% <100.00%> (?)` | |

Flags with carried forward coverage won't be shown. [Click here](https://docs.codecov.io/docs/carryforward-flags?utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#carryforward-flags-in-the-pull-request-comment) to find out more.

| [Files Changed](https://app.codecov.io/gh/keras-team/keras-core/pull/845?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | Coverage Δ | |
|---|---|---|
| [keras\_core/backend/jax/trainer.py](https://app.codecov.io/gh/keras-team/keras-core/pull/845?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#diff-a2VyYXNfY29yZS9iYWNrZW5kL2pheC90cmFpbmVyLnB5) | `96.01% <100.00%> (+0.04%)` | :arrow_up: |

... and [17 files with indirect coverage changes](https://app.codecov.io/gh/keras-team/keras-core/pull/845/indirect-changes?src=pr&el=tree-more&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team)

qlzh727 commented 1 year ago

The unit test failed after we changed to tree.map_structure(); it probably handles None differently from jax.tree_map(). Should I still prefer tree.map_structure()?

qlzh727 commented 1 year ago

Added extra None handling for this.
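
For readers following along, the kind of guard meant here can be sketched in plain Python. This is not the code from the PR: `map_leaves` and `skip_none` are hypothetical helpers, and `map_leaves` assumes a mapper that calls the function on None entries (which is how I understand tree.map_structure() to behave, whereas jax.tree_map() skips None because it counts as an empty subtree).

```python
def map_leaves(fn, structure):
    """Toy structure mapper that treats None as a leaf (assumed behavior)."""
    if isinstance(structure, dict):
        return {k: map_leaves(fn, v) for k, v in structure.items()}
    if isinstance(structure, (list, tuple)):
        return type(structure)(map_leaves(fn, v) for v in structure)
    return fn(structure)


def skip_none(fn):
    """Wrap fn so that None entries pass through untouched."""
    def wrapped(leaf):
        return None if leaf is None else fn(leaf)
    return wrapped


# Hypothetical batch where some entries (e.g. sample weights) are missing.
data = {"x": 3, "y": None, "sample_weight": [1, None]}

# Without the guard, fn would be called on None and fail.
result = map_leaves(skip_none(lambda v: v * 2), data)
print(result)
```

Wrapping the function rather than the mapper keeps the guard local to the call site that may see None.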