**Open** · christian-steinmeyer opened this issue 1 year ago
@christian-steinmeyer, the reason this fails is that the MHA layer runs extra build logic when deserialized via `from_config`, logic that is not run when the layer is initialized directly with `num_heads` and `key_dim`: https://github.com/keras-team/keras/blob/master/keras/layers/attention/multi_head_attention.py#L321

Without this line, the layer is marked as unbuilt (`_built_from_signature = False`), so it will try to create new variables when called. This is why the layer appears to have freshly initialized weights instead of the checkpointed ones.
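The effect can be observed directly (a minimal sketch; the layer sizes and input shapes are illustrative, not taken from the report):

```python
import numpy as np
import tensorflow as tf

mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=4)
x = tf.random.normal((1, 3, 4))
out = mha(x, x)  # the first call builds the layer and creates its variables

# Rebuilding from config alone carries over the architecture but no weights:
clone = tf.keras.layers.MultiHeadAttention.from_config(mha.get_config())
# The clone creates fresh, randomly initialized variables on its first call,
# so its output will (almost surely) differ from the original's.
clone_out = clone(x, x)
print(np.allclose(out.numpy(), clone_out.numpy()))
```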
```python
import tensorflow as tf
from tensorflow import keras

class MyCustomMhaLayer(keras.layers.Layer):
    def __init__(self, embed_dim=None, num_heads=None, mha=None, **kwargs):
        super().__init__(**kwargs)
        if mha is None:
            self.mha = keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        else:
            # Reuse an already-deserialized MHA layer instead of creating a new one.
            self.mha = mha

    def call(self, x, training=None):
        return self.mha(x, x, x, training=training)

    def get_config(self):
        config = super().get_config()
        config.update({"mha": keras.layers.serialize(self.mha)})
        return config

    @classmethod
    def from_config(cls, config):
        config["mha"] = keras.layers.deserialize(config["mha"])
        return super().from_config(config)
```
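For completeness, a roundtrip with a wrapper like the one above might look as follows (the class is restated so the snippet runs standalone; layer sizes and shapes are assumed):

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Restated wrapper (same idea as the class above) so this snippet is self-contained.
class MyCustomMhaLayer(keras.layers.Layer):
    def __init__(self, embed_dim=None, num_heads=None, mha=None, **kwargs):
        super().__init__(**kwargs)
        # Reuse an already-deserialized MHA layer if one is provided.
        self.mha = mha if mha is not None else keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim)

    def call(self, x, training=None):
        return self.mha(x, x, x, training=training)

    def get_config(self):
        config = super().get_config()
        config.update({"mha": keras.layers.serialize(self.mha)})
        return config

    @classmethod
    def from_config(cls, config):
        config["mha"] = keras.layers.deserialize(config["mha"])
        return super().from_config(config)

layer = MyCustomMhaLayer(embed_dim=4, num_heads=2)
x = tf.random.normal((1, 3, 4))
y = layer(x)

# Config roundtrip; the weights still have to be restored explicitly.
clone = MyCustomMhaLayer.from_config(layer.get_config())
clone(x)  # build the clone so set_weights has variables to overwrite
clone.set_weights(layer.get_weights())
assert np.allclose(y.numpy(), clone(x).numpy())
```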
Hey! Thanks for getting back to me! I'm not sure I fully understand your response. Are you pointing me toward a potential fix, and providing me with a workaround until it is fixed?
For my use case, it'd be enough to set `self.built = True` in `_build_from_signature()`. Then it wouldn't re-build in `__call__()`. I'm not sure about other implications of that change, though. What do you think?
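Until a fix lands upstream, the one-line change suggested above can be emulated with a small subclass. This is only a sketch: `_build_from_signature` is a private hook of the TF 2.x implementation, so this may break across Keras versions.

```python
import tensorflow as tf

class PatchedMHA(tf.keras.layers.MultiHeadAttention):
    """MHA that marks itself as built once its signature-based build has run,
    so a deserialized instance does not re-create variables in __call__."""

    def _build_from_signature(self, *args, **kwargs):
        super()._build_from_signature(*args, **kwargs)
        self.built = True  # the one-line fix proposed in this thread
```

Because the subclass behaves identically otherwise, it can be dropped in wherever `MultiHeadAttention` is used.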
@sachinprasadhs, I was able to reproduce the issue on TensorFlow v2.12, v2.13, and tf-nightly. Kindly find the gist of it here.
@christian-steinmeyer - The recommended workflow for `MultiHeadAttention` is to rebuild the layer from its config, call it once to create its variables, and then restore the weights with the `set_weights` method. For example:
```python
config = mha.get_config()
new_mha = tf.keras.layers.MultiHeadAttention.from_config(config)
new_mha(*inputs)  # call once to build the layer's variables
new_mha.set_weights(weights)  # weights from the original layer, e.g. mha.get_weights()
new_mha(*inputs)
```
I verified these steps and can confirm they give the same results. Please confirm, and close the issue if this resolves it for you.
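Filled out end to end, the recommended steps might look like this (layer sizes and inputs are assumed for illustration):

```python
import numpy as np
import tensorflow as tf

mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=4)
query = value = tf.random.normal((1, 3, 4))
reference = mha(query, value)   # build the original and record its output
weights = mha.get_weights()     # checkpointed weights to restore later

config = mha.get_config()
new_mha = tf.keras.layers.MultiHeadAttention.from_config(config)
new_mha(query, value)           # call once so the clone has variables to overwrite
new_mha.set_weights(weights)    # restore the checkpointed weights
restored = new_mha(query, value)

assert np.allclose(reference.numpy(), restored.numpy())
```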
Thanks for the recommendation. That does work. However, I believe it does not resolve the issue. Layers in general allow setting the weights via config, but this doesn't work for the MHA layer. It fails further down the line and requires a fair bit of debugging before one figures out that it doesn't support something other layers do. So I believe it should either perform a check in `from_config` and be explicit about not supporting weights in the config, or better: support it, too. I already found a one-line fix and suggested it above. Perhaps there are negative consequences to that change, but could you verify that you at least considered it?
@christian-steinmeyer - Agreed, we could clarify this in the docstring of the MHA layer. We will take that as an action item.
**System information.**

**Describe the problem.**
When creating a `MultiHeadAttention` layer from config and providing weights, the automatic setting of weights does not work (as it does in other layers).

**Describe the current behavior.**
`ValueError` in `set_weights`, called from `_maybe_build`, due to non-matching weight lengths.

**Describe the expected behavior.**
No `ValueError`. The weights of the layer are initialized from the weights provided in the config and then used.
**Contributing.**
**Standalone code to reproduce the issue.**
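A minimal reproduction, reconstructed from the description above (layer sizes and shapes are assumed; this is not the issue's original snippet):

```python
import tensorflow as tf

mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=4)
x = tf.random.normal((1, 3, 4))
mha(x, x)  # build the original layer

config = mha.get_config()
weights = mha.get_weights()

# Many layers accept checkpointed weights via the `weights` argument and
# apply them automatically in _maybe_build; per this report, that path
# breaks for MHA with a ValueError about non-matching weight lengths.
try:
    new_mha = tf.keras.layers.MultiHeadAttention.from_config(
        {**config, "weights": weights}
    )
    new_mha(x, x)
except (ValueError, TypeError) as err:
    print("Failed to restore weights via config:", err)
```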
**Traceback.**