Hello, I was trying to understand Flax's `LSTMCell`. The documentation for its `__call__` method says:
carry – the hidden state of the LSTM cell, initialized using LSTMCell.initialize_carry.
I thought it was odd that the cell state isn't mentioned alongside the hidden state. But in the source code, `initialize_carry` seems to return a tuple containing both the cell state and the hidden state:
`return (c, h)`
Additionally, `__call__` also seems to return both the cell state and the hidden state in the carry:
`return (new_c, new_h), new_h`
Did I misunderstand something? If not, should the documentation be updated to clarify that the carry includes both the cell state and the hidden state?
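To make the convention concrete, here is a minimal NumPy sketch of an LSTM step that threads the carry as a `(c, h)` tuple, mirroring the `return (c, h)` and `return (new_c, new_h), new_h` lines quoted above. It is not Flax's implementation: the fused gate weights, the gate order, and the names `initialize_carry`, `lstm_step` and `init_params` here are hypothetical stand-ins chosen for illustration.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def initialize_carry(batch_size, hidden_size):
    """Zero-initialised carry returned as a (cell, hidden) tuple (hypothetical helper)."""
    c = np.zeros((batch_size, hidden_size))
    h = np.zeros((batch_size, hidden_size))
    return (c, h)


def init_params(rng, input_size, hidden_size):
    """Hypothetical fused parameters: one projection producing all four gates."""
    return {
        "W_x": rng.normal(scale=0.1, size=(input_size, 4 * hidden_size)),
        "W_h": rng.normal(scale=0.1, size=(hidden_size, 4 * hidden_size)),
        "b": np.zeros(4 * hidden_size),
    }


def lstm_step(params, carry, x):
    """One LSTM step; returns ((new_c, new_h), new_h) like the quoted source line."""
    c, h = carry
    # Pre-activations for input (i), forget (f), candidate (g) and output (o) gates.
    z = x @ params["W_x"] + h @ params["W_h"] + params["b"]
    i, f, g, o = np.split(z, 4, axis=-1)
    # The new cell state depends on the previous cell state c ...
    new_c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    # ... while the hidden state (the per-step output) is derived from new_c.
    new_h = sigmoid(o) * np.tanh(new_c)
    return (new_c, new_h), new_h


rng = np.random.default_rng(0)
batch, input_size, hidden_size, steps = 2, 3, 5, 4
params = init_params(rng, input_size, hidden_size)
carry = initialize_carry(batch, hidden_size)
xs = rng.normal(size=(steps, batch, input_size))

outputs = []
for x in xs:
    carry, y = lstm_step(params, carry, x)  # carry holds (c, h); y is just h
    outputs.append(y)

c_final, h_final = carry
```

The cell state `c` has to travel with `h` from one step to the next because `new_c` is computed from the previous `c`; the per-step output, on the other hand, is only `new_h`. That is why a carry shaped like `(c, h)` plus a separate output `h` is a natural fit for this kind of loop.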
Anyway, thanks for the great library!