keras-team / tf-keras

The TensorFlow-specific implementation of the Keras API, which was the default Keras from 2019 to 2023.
Apache License 2.0

Cannot create MultiHeadAttention layer from config with weights #157

Open christian-steinmeyer opened 1 year ago

christian-steinmeyer commented 1 year ago

Describe the problem.

When a MultiHeadAttention layer is created from a config that includes weights, those weights are not set automatically, as they are for other layers.

Describe the current behavior. A ValueError is raised in set_weights, called from _maybe_build, because the number of provided weights does not match the number of weights the layer expects.

Describe the expected behavior. No ValueError. The layer's weights are initialized from the weights provided in the config and then used.

Standalone code to reproduce the issue.

import tensorflow as tf

if __name__ == '__main__':
    mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=64, name="mha")
    inputs = (tf.random.uniform((1, 2, 4)), tf.random.uniform((1, 2, 4)))
    mha(*inputs)
    weights = mha.get_weights()
    config = mha.get_config()
    config['weights'] = weights
    new_mha = tf.keras.layers.MultiHeadAttention.from_config(config)
    new_mha(*inputs)  # <-- fails

Traceback.

Traceback (most recent call last):
  File "<MY_SCRIPT>.py", line 11, in <module>
    new_mha(*inputs)
  File "<VIRTUAL_ENV>/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "<VIRTUAL_ENV>/lib/python3.10/site-packages/keras/src/engine/base_layer.py", line 1807, in set_weights
    raise ValueError(
ValueError: You called `set_weights(weights)` on layer "mha" with a weight list of length 8, but the layer was expecting 0 weights. Provided weights: [array([[[-0.02902268,  0.05308066,  0.10171321, -...
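To illustrate the ordering behind this traceback, here is a simplified pure-Python model. It is a toy sketch, not Keras code: ToyLayer, ToyMHA and their methods are hypothetical stand-ins that only borrow the names used in the traceback and in this thread. The model assumes that weights passed through the config are applied by set_weights inside _maybe_build right after build(), while the attention layer creates its variables only later, inside call().

```python
class ToyLayer:
    """Simplified stand-in for a layer base class (hypothetical, not Keras)."""

    def __init__(self, name, weights=None):
        self.name = name
        self._variables = []             # variables created so far
        self._initial_weights = weights  # values passed in via the config
        self.built = False

    @classmethod
    def from_config(cls, config):
        return cls(**config)

    def build(self):
        """Default build() creates no variables."""

    def get_weights(self):
        return list(self._variables)

    def set_weights(self, weights):
        if len(weights) != len(self._variables):
            raise ValueError(
                f'You called `set_weights(weights)` on layer "{self.name}" '
                f"with a weight list of length {len(weights)}, but the layer "
                f"was expecting {len(self._variables)} weights."
            )
        self._variables = list(weights)

    def _maybe_build(self):
        if not self.built:
            self.build()
            self.built = True
        # Config-provided weights are applied right after build(),
        # before call() has had a chance to create any variables.
        if self._initial_weights is not None:
            self.set_weights(self._initial_weights)
            self._initial_weights = None

    def __call__(self, *inputs):
        self._maybe_build()
        return self.call(*inputs)


class ToyMHA(ToyLayer):
    """Toy attention layer that only creates its variables inside call()."""

    def __init__(self, name, weights=None):
        super().__init__(name, weights)
        self._built_from_signature = False

    def _build_from_signature(self):
        self._variables = [0.0] * 8  # 8 variables, matching the traceback
        self._built_from_signature = True

    def call(self, query, value):
        if not self._built_from_signature:
            self._build_from_signature()
        return query


mha = ToyMHA("mha")
mha("q", "v")  # variables are created lazily here
config = {"name": "mha", "weights": mha.get_weights()}
new_mha = ToyMHA.from_config(config)
try:
    new_mha("q", "v")
except ValueError as err:
    print(err)
```

In this toy, the second call fails the same way as the report: a weight list of length 8 is applied to a layer that has 0 variables at that point, because variable creation only happens in call().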
tilakrayal commented 1 year ago

@christian-steinmeyer, this fails because MultiHeadAttention runs extra instructions when it is deserialized with from_config, and these are not run when the layer is initialized directly with num_heads and key_dim: https://github.com/keras-team/keras/blob/master/keras/layers/attention/multi_head_attention.py#L321

Without this line, the layer is marked as unbuilt (_built_from_signature = False), so it tries to create new variables when called. This is why the layer appears to have newly initialized weights instead of the checkpointed weights.

For example, a custom layer that wraps MultiHeadAttention can serialize the inner layer explicitly, so that it is restored through from_config:


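import tensorflow as tf
from tensorflow import keras


class MyCustomMhaLayer(keras.layers.Layer):
    def __init__(self, embed_dim=None, num_heads=None, mha=None, **kwargs):
        super().__init__(**kwargs)
        if mha is None:
            # Fresh layer: create a new MultiHeadAttention sublayer.
            self.mha = keras.layers.MultiHeadAttention(
                num_heads=num_heads, key_dim=embed_dim
            )
        else:
            # Restored layer: reuse the sublayer deserialized in from_config.
            self.mha = mha

    def call(self, x, training=None):
        # Self-attention: query, value and key are all x.
        return self.mha(x, x, x, training=training)

    def get_config(self):
        config = super().get_config()
        # Store the serialized MHA sublayer in this layer's config.
        config.update({"mha": tf.keras.layers.serialize(self.mha)})
        return config

    @classmethod
    def from_config(cls, config):
        # Rebuild the sublayer, which goes through MultiHeadAttention.from_config.
        config["mha"] = tf.keras.layers.deserialize(config["mha"])
        return super().from_config(config)

christian-steinmeyer commented 1 year ago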

Hey! Thanks for getting back to me! I'm not sure I fully understand your response. Are you pointing me in the right direction for a potential fix and providing a workaround until it is fixed?

For my use case, it would be enough to set self.built = True in _build_from_signature(). Then the layer wouldn't re-build in __call__(). I'm not sure about other implications of that change, though. What do you think?
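The effect of the proposed one-line change on the "build only once" guard can be sketched with a toy model. ToyAttention and weights_survive_first_call are hypothetical names, not Keras code, and the toy assumes that building the layer re-runs _build_from_signature() and re-initializes its variables. It only shows why marking the layer as built avoids a re-build on the first call; it says nothing about whether the change is safe in the real layer.

```python
class ToyAttention:
    """Hypothetical toy layer, not the Keras MultiHeadAttention class."""

    def __init__(self, mark_built_in_signature=False):
        # Toggles the proposed one-line change.
        self.mark_built_in_signature = mark_built_in_signature
        self.built = False
        self.weights = []

    def _build_from_signature(self):
        # Creates freshly initialized variables (all zeros in this toy).
        self.weights = [0.0] * 8
        if self.mark_built_in_signature:
            self.built = True  # the proposed change: self.built = True

    def build(self):
        self._build_from_signature()
        self.built = True

    def __call__(self, x):
        # Build only if the layer is not marked as built yet.
        if not self.built:
            self.build()
        return x


def weights_survive_first_call(mark_built_in_signature):
    layer = ToyAttention(mark_built_in_signature)
    layer._build_from_signature()  # e.g. run eagerly while deserializing
    saved = [1.0] * 8
    layer.weights = list(saved)    # restored values
    layer(0.0)                     # first call
    return layer.weights == saved


print(weights_survive_first_call(True))   # True: no re-build in __call__
print(weights_survive_first_call(False))  # False: re-build overwrites weights
```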

tilakrayal commented 1 year ago

@sachinprasadhs, I was able to reproduce the issue on TensorFlow v2.12, v2.13, and tf-nightly. Kindly find the gist of it here.

sampathweb commented 1 year ago

@christian-steinmeyer - The recommended workflow for MultiHeadAttention is:

  1. Load the model / layer from its config.
  2. Do a forward pass on the model / layer, which builds it.
  3. Update the weights with the set_weights method.
  4. Use the model / layer with the updated weights.

For example -

    config = mha.get_config()
    new_mha = tf.keras.layers.MultiHeadAttention.from_config(config)
    new_mha(*inputs) # To build the layer
    new_mha.set_weights(weights)
    new_mha(*inputs)

I verified these steps and can confirm they give the same results. Please confirm and close the issue if this resolves it for you.

christian-steinmeyer commented 1 year ago

Thanks for the recommendation. That does work. However, I don't think it resolves the issue. Layers in general accept weights via the config, but the MHA layer does not. It fails further down the line, and it takes some debugging to figure out that it doesn't support something other layers do. So I believe it should either check for weights in from_config and state explicitly that it doesn't support them, or, better, support them too. I already found a one-line fix and suggested it above. Perhaps that change has negative consequences, but could you confirm that you at least considered it?
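The first option, an explicit check in from_config, could look roughly like the sketch below. RejectConfigWeightsMixin, _ToyBase and StrictToyLayer are hypothetical names; the toy base class only stands in for a real layer so the example runs without TensorFlow. The mixin raises a clear error up front instead of letting the mismatch surface later in set_weights.

```python
class RejectConfigWeightsMixin:
    """Hypothetical mixin: fail fast when a config carries 'weights'."""

    @classmethod
    def from_config(cls, config):
        if "weights" in config:
            raise ValueError(
                f"{cls.__name__}.from_config() does not support 'weights' "
                "in the config. Call the layer once to build it, then use "
                "set_weights()."
            )
        # No weights: defer to the next from_config in the MRO.
        return super().from_config(config)


class _ToyBase:
    """Stand-in base layer so the sketch runs without TensorFlow."""

    def __init__(self, **kwargs):
        self.config = kwargs

    @classmethod
    def from_config(cls, config):
        return cls(**config)


class StrictToyLayer(RejectConfigWeightsMixin, _ToyBase):
    pass


layer = StrictToyLayer.from_config({"name": "mha"})
print(layer.config)
try:
    StrictToyLayer.from_config({"name": "mha", "weights": [1.0]})
except ValueError as err:
    print(err)
```

Placing the mixin first in the bases means its from_config runs before the layer's own; with TensorFlow, the same pattern should apply as class StrictMHA(RejectConfigWeightsMixin, tf.keras.layers.MultiHeadAttention).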

sampathweb commented 1 year ago

@christian-steinmeyer - Agreed, we can clarify this in the docstring of the MHA layer. We'll take that as an action item.