google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

Clarification for LSTMCell Documentation #4124

Open corentinlger opened 2 months ago

corentinlger commented 2 months ago

Hello, I was trying to understand the LSTMCell of Flax. The documentation for the __call__ function says:

carry – the hidden state of the LSTM cell, initialized using LSTMCell.initialize_carry.

I thought it was odd that the cell state wasn't returned along with the hidden state. But in the source code, initialize_carry seems to return a tuple containing both the cell state and the hidden state:

return (c, h)

Additionally, the __call__ function also seems to return both the cell state and the hidden state in the carry:

return (new_c, new_h), new_h

Did I misunderstand something? If not, should the documentation be updated to clarify that the carry includes both the cell state and the hidden state?

Anyway, thanks for the great library!

cgarciae commented 2 months ago

I agree that the docs should be clearer about the structure of the LSTMCell's carry.

corentinlger commented 2 months ago

Do you want me to open a PR for it?