liamdm / FlowTransformer


Fine Tuning Trouble #5

Open JamesNetworkingRoute opened 10 months ago

JamesNetworkingRoute commented 10 months ago

Hi Liam, I am trying to use your models as a base model for another application, so the first step was to implement the get_config method inside the TransformerEncoderBlock:

def get_config(self):
    # Serialise the constructor arguments so Keras can rebuild the layer on load
    config = super().get_config()
    config.update({
        'input_dimension': self.input_dimension,
        'inner_dimension': self.inner_dimension,
        'num_heads': self.num_heads,
        'dropout_rate': self.dropout_rate,
        'use_conv': self.use_conv,
        'prefix': self.prefix,
        'attn_implementation': self.attn_implementation
    })
    return config
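
(As an aside, since get_config returns the same keys as the constructor arguments, Keras's default from_config should be able to rebuild the layer from them. An alternative to passing custom_objects at load time would be to register the class as serializable, roughly like the sketch below; the package string is just a placeholder, not something from the repository.)

import tensorflow as tf

# Sketch only, not the repository's code: registering the layer lets
# load_model resolve it without an explicit custom_objects mapping.
@tf.keras.utils.register_keras_serializable(package="FlowTransformer")
class TransformerEncoderBlock(tf.keras.layers.Layer):
    # ... existing __init__, call and get_config ...
    pass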

Then, I've saved the model in main.py:

m.save('custom_base_model.h5')

And finally, I've created another file where I try to load the model by passing the TransformerEncoderBlock class as a custom object:

import tensorflow as tf
from implementations.transformers.basic.encoder_block import TransformerEncoderBlock
loaded_model = tf.keras.models.load_model(
    "custom_base_model.h5",
    custom_objects={'TransformerEncoderBlock': TransformerEncoderBlock}
)

The first problem was caused by this line in the TransformerEncoderBlock class: super().__init__(name=f"{prefix}transformer_encoder"), which I fixed by forwarding **kwargs:

kwargs['name'] = f"{prefix}transformer_encoder"
super().__init__(**kwargs)
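
For context, the adjusted constructor then looks roughly like this. This is a sketch assuming __init__ takes the same arguments that get_config serialises; the exact signature in encoder_block.py may differ:

import tensorflow as tf

class TransformerEncoderBlock(tf.keras.layers.Layer):
    def __init__(self, input_dimension, inner_dimension, num_heads,
                 dropout_rate, use_conv, prefix, attn_implementation, **kwargs):
        # Keras passes base-layer kwargs such as 'name' back in when it
        # rebuilds the layer from a saved config; overwrite it with the
        # prefix-derived name and forward the rest to the base class.
        kwargs['name'] = f"{prefix}transformer_encoder"
        super().__init__(**kwargs)
        self.input_dimension = input_dimension
        self.inner_dimension = inner_dimension
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.use_conv = use_conv
        self.prefix = prefix
        self.attn_implementation = attn_implementation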

But now, when I try to load the model, I get this error:

FlowTransformer\implementations\transformers\basic\encoder_block.py", line 113, in call  *
        x = inputs + attention_output
    ValueError: Dimensions must be equal, but are 290 and 128 for '{{node block_0_transformer_encoder/add}} = AddV2[T=DT_FLOAT](Placeholder, block_0_transformer_encoder/block_0_attention_dropout/Identity)' with input shapes: [?,8,290], [?,?,128].
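
For reference, a minimal standalone snippet (not FlowTransformer's code) reproduces the same incompatibility: the residual add requires the attention output's last dimension to match the block's input feature dimension, and here they are 128 and 290 respectively. The MultiHeadAttention settings below are placeholders; only the 290 and 128 values come from the traceback above.

import tensorflow as tf

inputs = tf.zeros((1, 8, 290))
attn = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=64, output_shape=128)
attention_output = attn(inputs, inputs)   # shape (1, 8, 128)
try:
    x = inputs + attention_output         # same add as encoder_block.py line 113
except tf.errors.InvalidArgumentError as e:
    print(e)                              # incompatible shapes: [1,8,290] vs [1,8,128]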

What do you suggest for loading your models so they can be used for fine-tuning?

liamdm commented 9 months ago

Sorry for the delay on this, I've been away over Christmas. I'll be able to look into this soon and get back to you, James.