keras-team / keras-core

A multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.
Apache License 2.0
1.27k stars 115 forks

Fix JAX RNN backend issue. #924

Closed qlzh727 closed 10 months ago

qlzh727 commented 10 months ago

This PR addresses several issues:

  1. The existing RNN layer does not train properly because a fresh StatelessScope is used inside the jax.lax.scan loop. As a result, the trainable variables lose their mapping to the actual values in the training loop. The loop now uses the parent StatelessScope if one exists. This addresses the training issue https://github.com/keras-team/keras-core/issues/322

  2. RNN layers with dropout update the RNG seed inside the step function, which jax.lax.scan does not allow. We noticed this because the updated seed is traced as a non-trainable variable, and an error is raised when we try to apply a sharding constraint for distribution. Added a new method that pre-populates the dropout mask on the layer so that the inner loop is stateless.

  3. During unit testing, I noticed that StackedRNNCells doesn't work with the existing RNN cells because it unwraps the list holding the state. The call function now keeps the list if the input state is a list.

  4. Expose the SimpleRNN, GRU, and LSTM cells in `__init__.py`, since they are public API.
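
To illustrate item 2, here is a minimal JAX sketch of the general idea: sample the dropout mask once, outside the loop, so that the step function passed to jax.lax.scan stays pure. This is not the code from this PR. The names `make_dropout_mask` and `run_rnn`, the toy tanh cell, and the choice to reuse one mask across all time steps are assumptions made for illustration only.

```python
import jax
import jax.numpy as jnp


def make_dropout_mask(key, shape, rate):
    """Sample one inverted-dropout mask up front, outside the scan (hypothetical helper)."""
    keep = jax.random.bernoulli(key, p=1.0 - rate, shape=shape)
    # Scale kept units so the expected activation is unchanged.
    return keep.astype(jnp.float32) / (1.0 - rate)


def run_rnn(inputs, h0, w_x, w_h, mask):
    """Toy tanh RNN over time-major inputs of shape (time, batch, features)."""

    def step(h, x_t):
        # Pure step function: the mask is closed over as a constant,
        # so no RNG state is read or updated inside the scan body.
        h_new = jnp.tanh((x_t * mask) @ w_x + h @ w_h)
        return h_new, h_new

    # Returns (final carry, stacked per-step outputs).
    return jax.lax.scan(step, h0, inputs)


key = jax.random.PRNGKey(0)
k_mask, k_x, k_wx, k_wh = jax.random.split(key, 4)

time_steps, batch, features, units = 5, 2, 3, 4
inputs = jax.random.normal(k_x, (time_steps, batch, features))
w_x = jax.random.normal(k_wx, (features, units))
w_h = jax.random.normal(k_wh, (units, units))
h0 = jnp.zeros((batch, units))

# The only random draw for dropout happens here, before the loop starts.
mask = make_dropout_mask(k_mask, (batch, features), rate=0.25)

final_h, outputs = run_rnn(inputs, h0, w_x, w_h, mask)
print(outputs.shape, final_h.shape)
```

Because the mask is an ordinary array captured by `step`, calling `run_rnn` twice with the same arguments gives identical results, which is the property a scan body needs.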

codecov[bot] commented 10 months ago

Codecov Report

Patch coverage: 80.95% and project coverage change: +0.03% :tada:

Comparison is base (c64de55) 79.73% compared to head (a340b01) 79.76%.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##             main     #924      +/-   ##
==========================================
+ Coverage   79.73%   79.76%   +0.03%
==========================================
  Files         318      318
  Lines       28627    28645      +18
  Branches     5447     5451       +4
==========================================
+ Hits        22827    22850      +23
+ Misses       4333     4332       -1
+ Partials     1467     1463       -4
```

| [Flag](https://app.codecov.io/gh/keras-team/keras-core/pull/924/flags?src=pr&el=flags&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | Coverage Δ | |
|---|---|---|
| [keras_core](https://app.codecov.io/gh/keras-team/keras-core/pull/924/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `79.69% <80.95%> (+0.03%)` | :arrow_up: |
| [keras_core-numpy](https://app.codecov.io/gh/keras-team/keras-core/pull/924/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `60.40% <80.95%> (+0.01%)` | :arrow_up: |
| [keras_core-tensorflow](https://app.codecov.io/gh/keras-team/keras-core/pull/924/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `66.84% <71.42%> (+0.02%)` | :arrow_up: |
| [keras_core-torch](https://app.codecov.io/gh/keras-team/keras-core/pull/924/flags?src=pr&el=flag&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | `69.25% <71.42%> (+0.03%)` | :arrow_up: |

Flags with carried forward coverage won't be shown. [Click here](https://docs.codecov.io/docs/carryforward-flags?utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#carryforward-flags-in-the-pull-request-comment) to find out more.

| [Files Changed](https://app.codecov.io/gh/keras-team/keras-core/pull/924?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team) | Coverage Δ | |
|---|---|---|
| [keras\_core/backend/jax/rnn.py](https://app.codecov.io/gh/keras-team/keras-core/pull/924?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#diff-a2VyYXNfY29yZS9iYWNrZW5kL2pheC9ybm4ucHk=) | `9.09% <33.33%> (+0.54%)` | :arrow_up: |
| [keras\_core/layers/\_\_init\_\_.py](https://app.codecov.io/gh/keras-team/keras-core/pull/924?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#diff-a2VyYXNfY29yZS9sYXllcnMvX19pbml0X18ucHk=) | `96.00% <100.00%> (+0.09%)` | :arrow_up: |
| [keras\_core/layers/rnn/rnn.py](https://app.codecov.io/gh/keras-team/keras-core/pull/924?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#diff-a2VyYXNfY29yZS9sYXllcnMvcm5uL3Jubi5weQ==) | `85.98% <100.00%> (+0.95%)` | :arrow_up: |
| [keras\_core/layers/rnn/stacked\_rnn\_cells.py](https://app.codecov.io/gh/keras-team/keras-core/pull/924?src=pr&el=tree&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team#diff-a2VyYXNfY29yZS9sYXllcnMvcm5uL3N0YWNrZWRfcm5uX2NlbGxzLnB5) | `87.01% <100.00%> (+5.43%)` | :arrow_up: |

... and [1 file with indirect coverage changes](https://app.codecov.io/gh/keras-team/keras-core/pull/924/indirect-changes?src=pr&el=tree-more&utm_medium=referral&utm_source=github&utm_content=comment&utm_campaign=pr+comments&utm_term=keras-team)
