michaelkhany / liquid_time_constant_networks

Code Repository for Liquid Time-Constant Networks (LTCs)
Apache License 2.0

tensorflow 2.16.2 #1

Open panjea opened 5 days ago

panjea commented 5 days ago

Works with TensorFlow 2.14.0, fails with TensorFlow 2.16.2:

Traceback (most recent call last):
  File "/home/rap/gir/gir.py", line 2015, in <module>
    import ctrnn_model
  File "/home/rap/gir/./liquid_time_constant_networks/ctrnn_model.py", line 7, in <module>
    import LTC4 as ltc # import the new ltc_models library
  File "/home/rap/gir/./liquid_time_constant_networks/LTC4.py", line 17, in <module>
    class LTCCell(tf.keras.layers.AbstractRNNCell):
AttributeError: module 'keras._tf_keras.keras.layers' has no attribute 'AbstractRNNCell'
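
For context: TensorFlow 2.16 made Keras 3 the default `tf.keras`, and the traceback shows that this Keras namespace has no `AbstractRNNCell`, which matches the 2.14.0 versus 2.16.2 behaviour reported here. A minimal sketch of a version check follows. The helper names `parse_version` and `uses_keras3_by_default` are hypothetical, and a `hasattr` check on `tf.keras.layers` is probably the more robust test in practice.

```python
import re


def parse_version(version: str) -> tuple:
    """Return the leading numeric (major, minor) parts of a version string."""
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return int(match.group(1)), int(match.group(2))


def uses_keras3_by_default(tf_version: str) -> bool:
    """Hypothetical helper: True for TensorFlow releases from 2.16 onward."""
    return parse_version(tf_version) >= (2, 16)


# In a real script the string would come from tf.__version__.
print(uses_keras3_by_default("2.14.0"))  # False
print(uses_keras3_by_default("2.16.2"))  # True
```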

panjea commented 5 days ago

I currently work around the issue by monkey-patching tensorflow.keras.layers:

import tensorflow as tf

# Copied from: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/layers/recurrent.py#L1003
class AbstractRNNCell(tf.keras.Layer):
    """Abstract object representing an RNN cell.

    See [the Keras RNN API guide](https://www.tensorflow.org/guide/keras/rnn)
    for details about the usage of RNN API.

    This is the base class for implementing RNN cells with custom behavior.

    Every `RNNCell` must have the properties below and implement `call` with
    the signature `(output, next_state) = call(input, state)`.

    Examples:

        class MinimalRNNCell(AbstractRNNCell):

            def __init__(self, units, **kwargs):
                self.units = units
                super(MinimalRNNCell, self).__init__(**kwargs)

            @property
            def state_size(self):
                return self.units

            def build(self, input_shape):
                self.kernel = self.add_weight(
                        shape=(input_shape[-1], self.units))
                self.recurrent_kernel = self.add_weight(
                        shape=(self.units, self.units))
                self.built = True

            def call(self, inputs, states):
                prev_output = states[0]
                h = backend.dot(inputs, self.kernel)
                output = h + backend.dot(prev_output, self.recurrent_kernel)
                return output, output

    This definition of cell differs from the definition used in the literature.
    In the literature, 'cell' refers to an object with a single scalar output.
    This definition refers to a horizontal array of such units.

    An RNN cell, in the most abstract setting, is anything that has
    a state and performs some operation that takes a matrix of inputs.
    This operation results in an output matrix with `self.output_size` columns.
    If `self.state_size` is an integer, this operation also results in a new
    state matrix with `self.state_size` columns.  If `self.state_size` is a
    (possibly nested tuple of) TensorShape object(s), then it should return a
    matching structure of Tensors having shape `[batch_size].concatenate(s)`
    for each `s` in `self.state_size`.
    """

    def call(self, inputs, states):
        """The function that contains the logic for one RNN step calculation.

        Args:
            inputs: the input tensor, which is a slice from the overall RNN
                input by the time dimension (usually the second dimension).
            states: the state tensor from previous step, which has the same
                shape as `(batch, state_size)`. In the case of timestep 0, it
                will be the initial state user specified, or zero filled
                tensor otherwise.

        Returns:
            A tuple of two tensors:
                1. output tensor for the current timestep, with size
                   `output_size`.
                2. state tensor for next step, which has the shape of
                   `state_size`.
        """
        raise NotImplementedError('Abstract method')

    @property
    def state_size(self):
        """size(s) of state(s) used by this cell.

        It can be represented by an Integer, a TensorShape or a tuple of
        Integers or TensorShapes.
        """
        raise NotImplementedError('Abstract method')

    @property
    def output_size(self):
        """Integer or TensorShape: size of outputs produced by this cell."""
        raise NotImplementedError('Abstract method')

    def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
        # NOTE: _generate_zero_filled_state_for_cell is a private helper from
        # the same recurrent.py. It is not defined in this snippet, so it has
        # to be copied as well (or this method overridden) before it is called.
        return _generate_zero_filled_state_for_cell(self, inputs, batch_size, dtype)

# Monkey patch: this must run before LTC4 is imported, because LTC4 looks up
# tf.keras.layers.AbstractRNNCell at import time (see the traceback above).
tf.keras.layers.AbstractRNNCell = AbstractRNNCell
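
The assignment above replaces the attribute unconditionally, including on TensorFlow 2.14.0 where the original class still exists. A guarded variant might look like the sketch below. It is only an illustration of the pattern: `patch_if_missing` and `FallbackCell` are hypothetical names, and `types.SimpleNamespace` objects stand in for `tf.keras.layers` so the example runs without TensorFlow installed.

```python
import types


def patch_if_missing(namespace, name, replacement):
    """Attach `replacement` as `namespace.<name>` only if it is absent.

    Returns True when the patch was applied, False when the attribute
    already existed and was left untouched.
    """
    if hasattr(namespace, name):
        return False
    setattr(namespace, name, replacement)
    return True


class FallbackCell:
    """Placeholder for the copied AbstractRNNCell class (assumption: demo only)."""


# Stand-ins for tf.keras.layers under the two TensorFlow versions.
old_layers = types.SimpleNamespace(AbstractRNNCell=object)  # attribute present
new_layers = types.SimpleNamespace()                        # attribute missing

print(patch_if_missing(old_layers, "AbstractRNNCell", FallbackCell))  # False
print(patch_if_missing(new_layers, "AbstractRNNCell", FallbackCell))  # True
```

With TensorFlow available, the equivalent call would be `patch_if_missing(tf.keras.layers, "AbstractRNNCell", AbstractRNNCell)`, placed before `import ctrnn_model`, since the traceback shows the base class being resolved while LTC4.py is imported.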