keras-team / keras


Outdated documentation regarding Stateful RNNs #20327

Open ageron opened 1 week ago

ageron commented 1 week ago

The documentation for the base RNN layer contains the following explanation, which is outdated:

Note on using statefulness in RNNs:

You can set RNN layers to be 'stateful', which means that the states computed for the samples in one batch will be reused as initial states for the samples in the next batch. This assumes a one-to-one mapping between samples in different successive batches.

To enable statefulness:

- Specify stateful=True in the layer constructor.
- Specify a fixed batch size for your model, by passing:
  - if sequential model: batch_input_shape=(...) to the first layer in your model;
  - else, for a functional model with 1 or more Input layers: batch_shape=(...) to all the first layers in your model. This is the expected shape of your inputs, including the batch size. It should be a tuple of integers, e.g. (32, 10, 100).
- Specify shuffle=False when calling fit().

To reset the states of your model, call .reset_states() on either a specific layer, or on your entire model.

Note on specifying the initial state of RNNs:

You can specify the initial state of RNN layers symbolically by calling them with the keyword argument initial_state. The value of initial_state should be a tensor or list of tensors representing the initial state of the RNN layer.

You can specify the initial state of RNN layers numerically by calling reset_states with the keyword argument states. The value of states should be a numpy array or list of numpy arrays representing the initial state of the RNN layer.

In particular:

- the batch_input_shape argument no longer exists in Keras 3 (the batch size is now declared on the Input layer via batch_shape),
- and reset_states() was renamed to reset_state().
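
For reference, here is roughly what the current usage looks like (a minimal sketch; the shapes and layer sizes are made up):

```python
import keras

# The fixed batch size is now declared on the Input layer via batch_shape
# (the old batch_input_shape argument is gone).
inputs = keras.Input(batch_shape=(32, 10, 100))
lstm = keras.layers.LSTM(64, return_sequences=True, stateful=True)
outputs = lstm(inputs)
model = keras.Model(inputs, outputs)

# Symbolic initial state: pass initial_state when calling the layer.
h0 = keras.Input(batch_shape=(32, 64))
c0 = keras.Input(batch_shape=(32, 64))
outputs2 = keras.layers.LSTM(64)(inputs, initial_state=[h0, c0])

# Clearing the stored state between independent sequences:
lstm.reset_state()
```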

However, I was unable to build a stateful RNN; I get the following exception:

```
Epoch 1/10
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-100-aaee251b8e39> in <cell line: 3>()
      1 model.compile(loss="sparse_categorical_crossentropy", optimizer="nadam",
      2               metrics=["accuracy"])
----> 3 history = model.fit(stateful_train_set, validation_data=stateful_valid_set,
      4                     epochs=10, callbacks=[ResetStatesCallback(), model_ckpt])

1 frames
/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
    120             # To get the full stack trace, call:
    121             # `keras.config.disable_traceback_filtering()`
--> 122             raise e.with_traceback(filtered_tb) from None
    123         finally:
    124             del filtered_tb

/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
    120             # To get the full stack trace, call:
    121             # `keras.config.disable_traceback_filtering()`
--> 122             raise e.with_traceback(filtered_tb) from None
    123         finally:
    124             del filtered_tb

ValueError: Exception encountered when calling GRU.call().

Input tensor `sequential_10_1/gru_6_1/ReadVariableOp:0` enters the loop with shape (1, 128), but has shape (None, 128) after one iteration. To allow the shape to vary across iterations, use the `shape_invariants` argument of tf.while_loop to specify a less-specific shape.

Arguments received by GRU.call():
  • sequences=tf.Tensor(shape=(None, None, 16), dtype=float32)
  • initial_state=None
  • mask=None
  • training=True
```

I'm not sure whether this is a bug or whether I'm not implementing a stateful RNN correctly in Keras 3. If someone could explain how to build one, I'd be happy to update the documentation.

fchollet commented 5 days ago

Thanks for the report. Do you have a Colab or code snippet to reproduce the error?

Best I can tell, the docs look outdated. The two corrections you made are accurate.

ageron commented 5 days ago

Thanks @fchollet , here's a little code snippet to reproduce the error:

```python
import keras
import numpy as np

model = keras.Sequential([
    keras.layers.Input(batch_shape=[1, 10, 3]),
    keras.layers.LSTM(10, return_sequences=True, stateful=True),
    keras.layers.LSTM(10, return_sequences=True, stateful=True),
    keras.layers.Dense(5)
])

model.compile(loss="mse", optimizer="sgd")

X_train = np.random.rand(100, 10, 3)
y_train = np.random.rand(100, 10, 5)
model.fit(X_train, y_train, epochs=1)
```

I'm getting the same exception as above.

```
Epoch 1/5
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
 in ()
     13 X_train = np.random.rand(100, 10, 3)
     14 y_train = np.random.rand(100, 10, 5)
---> 15 model.fit(X_train, y_train, epochs=1)

1 frames
/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
    120             # To get the full stack trace, call:
    121             # `keras.config.disable_traceback_filtering()`
--> 122             raise e.with_traceback(filtered_tb) from None
    123         finally:
    124             del filtered_tb

/usr/local/lib/python3.10/dist-packages/keras/src/utils/traceback_utils.py in error_handler(*args, **kwargs)
    120             # To get the full stack trace, call:
    121             # `keras.config.disable_traceback_filtering()`
--> 122             raise e.with_traceback(filtered_tb) from None
    123         finally:
    124             del filtered_tb

ValueError: Exception encountered when calling LSTM.call().

Input tensor `sequential_2_1/lstm_4_1/ReadVariableOp:0` enters the loop with shape (1, 10), but has shape (None, 10) after one iteration. To allow the shape to vary across iterations, use the `shape_invariants` argument of tf.while_loop to specify a less-specific shape.

Arguments received by LSTM.call():
  • sequences=tf.Tensor(shape=(None, 10, 3), dtype=float32)
  • initial_state=None
  • mask=None
  • training=True
```

ageron commented 5 days ago

Here's a gist notebook with the code above.

fchollet commented 5 days ago

Thanks for the code. The origin of the issue is a discrepancy between the batch size specified in Input and the batch size the model effectively receives: if you pass raw numpy data to fit(), it gets chunked into batches whose size is set by the batch_size argument (32 by default).

You can just pass batch_size=1 in fit() to fix it (or otherwise use a generator-like or tf.data.Dataset-like data source).

We should have a check somewhere to guard against such a mismatch.
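
For example, something along these lines should work (a sketch reusing the model, X_train, and y_train from your snippet):

```python
import tensorflow as tf

# Batch the raw arrays so each batch matches the fixed batch size declared
# in Input(batch_shape=[1, 10, 3]); drop_remainder=True discards any smaller
# final batch that would otherwise violate the stateful model's batch_shape.
train_ds = tf.data.Dataset.from_tensor_slices((X_train, y_train)).batch(
    1, drop_remainder=True
)
model.fit(train_ds, epochs=1)
```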

ageron commented 4 days ago

Ah got it, thanks François. 👍

Indeed, the following code works fine:

```python
model = keras.Sequential([
    keras.layers.Input(batch_shape=[1, 10, 3]),
    keras.layers.LSTM(10, return_sequences=True, stateful=True),
    keras.layers.LSTM(10, return_sequences=True, stateful=True),
    keras.layers.Dense(5)
])

model.compile(loss="mse", optimizer="sgd")

X_train = np.random.rand(100, 10, 3)
y_train = np.random.rand(100, 10, 5)
model.fit(X_train, y_train, epochs=1, batch_size=1)  # matches batch_shape[0]
```

So it's just a documentation issue; I'll update the name of this issue.
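
For reference, the ResetStatesCallback used in my first traceback can be implemented along these lines in Keras 3 (a sketch, assuming the layer-level reset_state() method mentioned above):

```python
class ResetStatesCallback(keras.callbacks.Callback):
    """Resets the states of all stateful RNN layers at the start of each
    epoch, so state from one pass over the data doesn't leak into the next."""

    def on_epoch_begin(self, epoch, logs=None):
        for layer in self.model.layers:
            if getattr(layer, "stateful", False):
                layer.reset_state()  # renamed from reset_states() in Keras 3
```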

jeffcarp commented 3 days ago

Thanks, I'll send a doc update. @fchollet where would we implement this check?