tensorflow / addons

Useful extra functionality for TensorFlow 2.x maintained by SIG-addons
Apache License 2.0

`trainable` flag may not be toggled for an ESN layer #2809

Closed · ghost closed this issue 1 year ago

ghost commented 1 year ago

System information

Describe the bug

Setting the `trainable` flag to False for an ESN layer in a Keras model (to make the reservoir of the resulting echo state network fixed) does not work. In particular, running

import tensorflow as tf
import tensorflow_addons as tfa

model = tf.keras.models.Sequential([
    tf.keras.layers.InputLayer(),
    tfa.layers.ESN(units=1000, spectral_radius=0.99),
    tf.keras.layers.Dense(1, kernel_initializer="lecun_normal")
])
model.layers[1].trainable = False

or

model = tf.keras.models.Sequential([
    tf.keras.layers.InputLayer(),
    tfa.layers.ESN(units=1000, spectral_radius=0.99, trainable=False),
    tf.keras.layers.Dense(1, kernel_initializer="lecun_normal")
])

throws

raise ValueError("as_list() is not defined on an unknown TensorShape.")
ValueError: as_list() is not defined on an unknown TensorShape.
ghost commented 1 year ago

The issue was just the InputLayer; in its absence, both strategies above work.
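For reference, a minimal sketch of the working setup described above (same imports as in the report); with the InputLayer removed, the ESN layer sits at index 0:

model = tf.keras.models.Sequential([
    tfa.layers.ESN(units=1000, spectral_radius=0.99),
    tf.keras.layers.Dense(1, kernel_initializer="lecun_normal")
])
model.layers[0].trainable = False  # index 0 now, since the InputLayer is gone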

pfaz69 commented 1 year ago

Hi, I don't think setting trainable to True or False makes any difference for an ESN layer, because esn_cell.py hardcodes trainable=False for the state matrix (self.recurrent_kernel), the input matrix (self.kernel), and the bias (self.bias). See the block starting at line 159:

    self.recurrent_kernel = self.add_weight(
        name="recurrent_kernel",
        shape=[self.units, self.units],
        initializer=_esn_recurrent_initializer,
        trainable=False,
        dtype=self.dtype,
    )
    self.kernel = self.add_weight(
        name="kernel",
        shape=[input_size, self.units],
        initializer=self.kernel_initializer,
        trainable=False,
        dtype=self.dtype,
    )

    if self.use_bias:
        self.bias = self.add_weight(
            name="bias",
            shape=[self.units],
            initializer=self.bias_initializer,
            trainable=False,
            dtype=self.dtype,
        )
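A quick way to confirm this (a minimal sketch, assuming the imports from the report and a hypothetical feature size of 10) is to build the layer on a dummy batch and inspect its weight lists; the reservoir weights all land in non_trainable_weights even though the layer's trainable flag defaults to True:

esn = tfa.layers.ESN(units=1000, spectral_radius=0.99)  # trainable left at its default (True)
_ = esn(tf.zeros((1, 5, 10)))  # dummy batch (batch=1, time=5, features=10), just to build the layer
print([w.name for w in esn.trainable_weights])      # expected: [] -- nothing to train
print([w.name for w in esn.non_trainable_weights])  # recurrent_kernel, kernel, bias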