nengo / keras-lmu

Keras implementation of Legendre Memory Units
https://www.nengo.ai/keras-lmu/

`LMUFFT` cannot be created if `hidden_cell` is `Dense` #39

Open arvoelke opened 3 years ago

arvoelke commented 3 years ago

Versions:

Taking this example from the unit tests: https://github.com/nengo/keras-lmu/blob/ab0775791aa73f9d22780539594ef4bd7de0be25/keras_lmu/tests/test_layers.py#L158-L164 and modifying it as follows:

        out = layers.LMUFFT(
            1,  # memory_d
            2,  # order
            3,  # theta
            tf.keras.layers.Dense(4),  # was: tf.keras.layers.SimpleRNNCell(4)
            return_sequences=True,
        )(inp)

results in the error:

ValueError: in user code:

    /home/arvoelke/git/keras-lmu/keras_lmu/layers.py:660 call  *
        h = tf.keras.layers.TimeDistributed(self.hidden_cell)(
    /home/arvoelke/anaconda3/envs/*/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py:1012 __call__  **
        outputs = call_fn(inputs, *args, **kwargs)
    /home/arvoelke/anaconda3/envs/*/lib/python3.8/site-packages/tensorflow/python/keras/layers/wrappers.py:244 call
        output_shape = self.compute_output_shape(input_shape).as_list()
    /home/arvoelke/anaconda3/envs/*/lib/python3.8/site-packages/tensorflow/python/keras/layers/wrappers.py:188 compute_output_shape
        child_output_shape = self.layer.compute_output_shape(child_input_shape)
    /home/arvoelke/anaconda3/envs/*/lib/python3.8/site-packages/tensorflow/python/keras/layers/core.py:1218 compute_output_shape
        raise ValueError(

    ValueError: The innermost dimension of input_shape must be defined, but saw: (None, None)

../../anaconda3/envs/*/lib/python3.8/site-packages/tensorflow/python/autograph/impl/api.py:670: ValueError
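The traceback shows `TimeDistributed.compute_output_shape` removing the time axis and passing the remaining shape to `Dense.compute_output_shape`, which rejects an undefined innermost dimension. The pure-Python sketch below is a simplified model of that path under that reading of the traceback, not the actual Keras source. The function names are hypothetical, and only the error text is taken from the traceback.

```python
from typing import Optional, Tuple

Shape = Tuple[Optional[int], ...]


def dense_output_shape(input_shape: Shape, units: int) -> Shape:
    # Hypothetical, simplified model of Dense.compute_output_shape:
    # the innermost axis must be statically known.
    if input_shape[-1] is None:
        raise ValueError(
            "The innermost dimension of input_shape must be defined, "
            f"but saw: {input_shape}"
        )
    return input_shape[:-1] + (units,)


def time_distributed_output_shape(input_shape: Shape, units: int) -> Shape:
    # Hypothetical, simplified model of TimeDistributed.compute_output_shape:
    # drop the time axis, ask the wrapped layer, then re-insert the time axis.
    batch, timesteps = input_shape[0], input_shape[1]
    child_input_shape = (batch,) + input_shape[2:]
    child_output_shape = dense_output_shape(child_input_shape, units)
    return (child_output_shape[0], timesteps) + child_output_shape[1:]


print(time_distributed_output_shape((None, 10, 6), 4))  # (None, 10, 4)
try:
    time_distributed_output_shape((None, None, None), 4)
except ValueError as e:
    # The innermost dimension of input_shape must be defined, but saw: (None, None)
    print(e)
```

In this model, any input whose feature axis is unknown at shape-inference time reproduces the `(None, None)` message seen above.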
NarsimhaChilkuri commented 3 years ago

Although `TimeDistributed(Dense(...))` and `Dense(...)` are equivalent, replacing https://github.com/nengo/keras-lmu/blob/ab0775791aa73f9d22780539594ef4bd7de0be25/keras_lmu/layers.py#L660-L662 with

                h = self.hidden_cell(
                    h_in, training=training
                )

fixes the issue.
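The equivalence the comment relies on can be checked without TensorFlow. Here is a minimal NumPy sketch. It assumes a linear `Dense` layer (no activation) modeled as `inputs @ kernel + bias` with randomly drawn weights. The helpers `dense` and `time_distributed_dense` are hypothetical stand-ins, not Keras functions.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, timesteps, features, units = 2, 5, 3, 4
x = rng.standard_normal((batch, timesteps, features))
kernel = rng.standard_normal((features, units))
bias = rng.standard_normal(units)


def dense(inputs):
    # Contract only the innermost axis with the kernel; leading axes pass through.
    return inputs @ kernel + bias


def time_distributed_dense(inputs):
    # Apply the same dense map separately at each timestep, then stack along time.
    return np.stack([dense(inputs[:, t]) for t in range(inputs.shape[1])], axis=1)


direct = dense(x)
per_step = time_distributed_dense(x)
print(direct.shape)  # (2, 5, 4)
print(np.allclose(direct, per_step))  # True
```

Because the matrix product contracts only the last axis, the batch and time axes are carried through unchanged. For a `Dense` hidden cell, the `TimeDistributed` wrapper therefore adds nothing to the computation.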