martinsbruveris / tensorflow-image-models

TensorFlow port of PyTorch Image Models (timm) - image models with pretrained weights.
https://tfimm.readthedocs.io/en/latest/
Apache License 2.0

Incompatible shapes: [4] vs. [4,196] when fine-tuning ViT #62

Closed: lorenzo-park closed this issue 2 years ago

lorenzo-park commented 2 years ago

Hi, I was building a model using ViT by iterating through its layers, but I got the error `Incompatible shapes: [4] vs. [4,196]` when I called `model.fit`. Any ideas where the mismatch is happening? Alternatively, I would be grateful if you could show me how to debug it (I am new to TensorFlow). Here is the function that builds the ViT model for fine-tuning:

```python
import tensorflow as tf
import tfimm

# `config` (IMAGE_SIZE, N_CLASSES, LR) and `strategy` (a tf.distribute
# strategy) are defined elsewhere in the script.
def get_model(img_size=config.IMAGE_SIZE):
    with strategy.scope():
        inp = tf.keras.layers.Input(shape=[img_size, img_size, 3], name="inp1")
        label = tf.keras.layers.Input(shape=(), name="inp2")

        # Pretrained ViT with the classification head disabled (nb_classes=0).
        vit_model = tfimm.create_model(
            "vit_base_patch16_224_miil_in21k", pretrained="timm", nb_classes=0
        )

        # Apply the ViT sub-layers one at a time.
        x = inp
        for layer in vit_model.layers:
            x = layer(x)

            # Some modification will be made here playing with x

        x = tf.keras.layers.Dense(config.N_CLASSES)(x)
        output = tf.keras.layers.Softmax(dtype="float32")(x)
        model = tf.keras.models.Model(inputs=[inp, label], outputs=[output])

        opt = tf.keras.optimizers.Adam(learning_rate=config.LR)

        model.compile(
            optimizer=opt,
            loss=[tf.keras.losses.SparseCategoricalCrossentropy()],
            metrics=[
                tf.keras.metrics.SparseCategoricalAccuracy(),
                tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5),
            ],
        )

    return model
```
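The numbers in the error can be traced with a little arithmetic. With 224×224 inputs and 16×16 patches (as the name `vit_base_patch16_224` suggests), a ViT sees (224 / 16)² = 14² = 196 patch tokens, which matches the 196 in the error. The shape-only sketch below assumes a batch size of 4 and assumes that the features reaching the `Dense` head still carry a per-token axis; the helper names and the class count are hypothetical.

```python
def num_patch_tokens(img_size: int, patch_size: int) -> int:
    """Number of non-overlapping patch tokens for a square image."""
    if img_size % patch_size != 0:
        raise ValueError("img_size must be divisible by patch_size")
    per_side = img_size // patch_size
    return per_side * per_side


def dense_output_shape(input_shape: tuple, units: int) -> tuple:
    """A Keras Dense layer acts on the last axis only; leading axes are kept."""
    return input_shape[:-1] + (units,)


batch_size = 4      # assumption: matches the [4] in the error message
embed_dim = 768     # feature width of a ViT-Base model
n_classes = 1000    # hypothetical stand-in for config.N_CLASSES

tokens = num_patch_tokens(224, 16)
print(tokens)  # 196

# Pooled features (one vector per image): the usual classifier input.
pooled = (batch_size, embed_dim)
# Unpooled features (one vector per patch token): what the layer loop may yield.
per_token = (batch_size, tokens, embed_dim)

print(dense_output_shape(pooled, n_classes))     # (4, 1000)
print(dense_output_shape(per_token, n_classes))  # (4, 196, 1000)

# Sparse labels have shape (batch_size,). An argmax over the class axis gives
# predictions of shape (4,) when pooled but (4, 196) when not, which is the
# kind of mismatch reported as "[4] vs. [4,196]".
labels = (batch_size,)
print(labels, dense_output_shape(per_token, n_classes)[:-1])  # (4,) (4, 196)
```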
martinsbruveris commented 2 years ago

Perhaps the problem lies in these lines:

```python
for layer in vit_model.layers:
    x = layer(x)
    # Some modification will be made here playing with x
```

The models in tfimm are not functional models. They are subclassed models with a custom implementation of `call()`; see, e.g., the ViT implementation.

I suspect the modifications you are introducing don't play nicely with what `call()` expects. Note, for example, that ViT adds the positional encodings as weights, not as Keras layers, so a loop over `vit_model.layers` never applies them.
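To see why looping over `.layers` can diverge from calling the model, here is a toy, shape-only model in plain Python. It is not tfimm's code: the class, its two "layers", and the order of steps in `call()` are hypothetical, loosely following a standard ViT forward pass (patch embedding, prepend a class token, add positional embeddings, run the transformer blocks, keep the class token). The class token and positional embeddings are modelled as weights used inside `call()`, so they never show up in `layers`.

```python
class ToySubclassedViT:
    """Shape-only stand-in for a subclassed ViT (hypothetical, not tfimm)."""

    def __init__(self, img_size=224, patch_size=16, embed_dim=768):
        self.num_patches = (img_size // patch_size) ** 2
        self.embed_dim = embed_dim
        # Weights used inside call(); they are not layers.
        self.cls_token_shape = (1, 1, embed_dim)
        self.pos_embed_shape = (1, self.num_patches + 1, embed_dim)
        # Only these appear when iterating over the model's layers.
        self.layers = [self.patch_embed, self.transformer_block]

    def patch_embed(self, shape):
        batch, _, _, _ = shape  # (batch, height, width, channels)
        return (batch, self.num_patches, self.embed_dim)

    def transformer_block(self, shape):
        return shape  # attention + MLP keep the token layout unchanged

    def call(self, shape):
        x = self.patch_embed(shape)
        # Prepend the class token: one extra token per image.
        x = (x[0], x[1] + 1, x[2])
        # Positional embeddings are added elementwise, so the token and
        # channel axes must match.
        assert x[1:] == self.pos_embed_shape[1:]
        x = self.transformer_block(x)
        # Keep only the class token as the image representation.
        return (x[0], x[2])


model = ToySubclassedViT()
images = (4, 224, 224, 3)

print(model.call(images))  # (4, 768)

x = images
for layer in model.layers:
    x = layer(x)
print(x)  # (4, 196, 768)
```

In this toy, the loop skips the class token, the positional embeddings, and the final pooling step, so the features handed to a new head still contain one vector per patch.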

lorenzo-park commented 2 years ago

Thank you for your reply. Now I understand the difference between subclassed and functional models, and it seems the approach above does not work for a subclassed model. What I am really trying to do is access the intermediate outputs of the ViT model, especially the attention values of the MHA layers. Do you have any suggestions?

lorenzo-park commented 2 years ago

In the end, I modified the original code and added a forward function that returns the MHA layer output.
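The idea of a forward function that also returns the attention can be illustrated with a single attention head. Below is a plain-Python sketch of scaled dot-product attention, softmax(QKᵀ/√d)V, that returns the attention weights next to the output. The function name and the `return_weights` flag are hypothetical, and this is not the tfimm implementation, which works on batched multi-head tensors.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q, k, v, return_weights=False):
    """Single-head scaled dot-product attention on lists of token vectors.

    q, k: lists of n vectors of length d; v: list of n vectors of length d_v.
    Returns the output tokens and, optionally, the n x n attention weights.
    """
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    weights = []
    for qi in q:
        # Similarity of query token i to every key token, scaled by 1/sqrt(d).
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        weights.append(softmax(scores))
    # Each output token is the attention-weighted average of the value vectors.
    out = [
        [sum(w * vj[c] for w, vj in zip(w_row, v)) for c in range(len(v[0]))]
        for w_row in weights
    ]
    return (out, weights) if return_weights else out


# Three tokens with 2-dimensional features (toy values).
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out, weights = attention(tokens, tokens, tokens, return_weights=True)
for row in weights:
    print([round(w, 3) for w in row])
```

In a Keras model the same idea means returning the weights computed inside the attention layer instead of discarding them; `tf.keras.layers.MultiHeadAttention`, for instance, accepts `return_attention_scores=True` when called.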