Closed — qlzh727 closed this PR 10 months ago.
Codecov report — patch coverage: 80.95%; project coverage change: +0.03% (base c64de55: 79.73%, head a340b01: 79.76%).
This PR addresses several issues:
The existing RNN layer does not train properly because a fresh StatelessScope is created inside the jax.lax.scan loop. This causes all the trainable variables to lose their mapping to the actual values in the training loop. The layers are updated to reuse the parent StatelessScope when one exists. This addresses the training issue https://github.com/keras-team/keras-core/issues/322
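The underlying constraint can be sketched in plain JAX (this is an illustrative example, not the keras-core implementation): a jax.lax.scan step function must be pure, so any variable values the cell reads have to be passed in explicitly (here via closure over `kernel`) rather than looked up through a scope created inside the step, which would not see the values mapped by the training loop.

```python
import jax
import jax.numpy as jnp

def rnn_forward(kernel, inputs, h0):
    def step(h, x):
        # Pure step: depends only on its arguments and the captured `kernel`,
        # so gradients flow back to `kernel` through the scan.
        h_new = jnp.tanh(x @ kernel + h)
        return h_new, h_new  # (new carry, per-step output)

    final_h, outputs = jax.lax.scan(step, h0, inputs)
    return final_h, outputs

kernel = jnp.ones((3, 4)) * 0.1
inputs = jnp.ones((5, 3))   # (timesteps, features)
h0 = jnp.zeros((4,))
final_h, outputs = rnn_forward(kernel, inputs, h0)
print(outputs.shape)  # (5, 4)
```

Because `kernel` is an explicit input to the traced function, differentiating `rnn_forward` with respect to it works as expected; a variable resolved through a scope constructed inside `step` would be invisible to that trace.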
The RNN layers with dropout update the RNG seed inside the step function, which jax.lax.scan does not allow. We noticed this because the updated seed is traced as a non-trainable variable and raises an error when we try to apply a sharding constraint for distribution. Added a new method that pre-populates the dropout masks on the layer, making the inner loop stateless.
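The workaround can be sketched as follows (names are illustrative, not the keras-core API): since the RNG key cannot be advanced inside the scan body, one dropout mask per timestep is sampled up front and fed to jax.lax.scan as a per-step input, so the step function itself stays stateless.

```python
import jax
import jax.numpy as jnp

def rnn_with_dropout(inputs, h0, key, rate=0.5):
    timesteps, features = inputs.shape
    # Pre-populate one dropout mask per timestep, outside the scan.
    masks = jax.random.bernoulli(key, 1.0 - rate, (timesteps, features))
    masks = masks / (1.0 - rate)  # inverted-dropout scaling

    def step(h, xs):
        x, mask = xs
        # Stateless step: the mask arrives as data; no RNG update happens here.
        h_new = jnp.tanh(x * mask + h)
        return h_new, h_new

    final_h, outputs = jax.lax.scan(step, h0, (inputs, masks))
    return final_h, outputs

key = jax.random.PRNGKey(0)
final_h, outputs = rnn_with_dropout(jnp.ones((5, 4)), jnp.zeros((4,)), key)
print(outputs.shape)  # (5, 4)
```

Passing the masks through the `xs` argument of scan also means the same masks are reused across the whole unrolled sequence deterministically, which is what variational RNN dropout expects.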
During unit testing, I noticed that StackedRNNCells doesn't work with the existing RNN cells, since it unwraps the state list; the call function is updated to keep the state as a list when the input state is a list.
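A minimal sketch of the structure-preserving fix (illustrative only, not the keras-core code): the cell's call returns its state with the same structure it received, so a stacked wrapper that passes a list of states gets a list back instead of a bare tensor.

```python
def cell_call(x, states):
    # Accept either a bare state value or a single-element list of states.
    is_list = isinstance(states, list)
    h = states[0] if is_list else states
    h_new = h + x
    # Preserve the input structure instead of unconditionally unwrapping.
    return h_new, [h_new] if is_list else h_new

out, new_states = cell_call(1.0, [2.0])
print(new_states)  # [3.0] -- list in, list out
```

With this convention, a stacked cell can zip its per-cell state lists back together without special-casing cells that return unwrapped states.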
Expose the SimpleRNN/GRU/LSTM cells in `__init__.py`, since they are public API.