keras-team / keras-io

Keras documentation, hosted live at keras.io
Apache License 2.0

Error in Vision Transformer examples #1907

Open angelo-ml opened 1 month ago

angelo-ml commented 1 month ago

Issue Type

Documentation Bug

Source

source

Keras Version

2.14

Custom Code

Yes

OS Platform and Distribution

Ubuntu 22.04

Python version

3.10

GPU model and memory

Nvidia RTX4070 (12GB)

Current Behavior?

Hi, I've spotted a mistake in the Vision Transformer examples in Keras.io [3,4,5,6,7].

In all five examples, the authors build the ViT architecture with a single hyper-parameter, projection_dim, which serves both as the model's hidden dimension and as the (per-head) dimension of the queries, keys, and values in the multi-head attention layer. These two hyper-parameters should not be the same; according to [1], they are related by:

hidden dimension = number of heads * qkv dimension
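To make the relation concrete, here is a small pure-Python sketch that counts the trainable parameters of one `keras.layers.MultiHeadAttention` layer under both conventions. It assumes the Keras defaults (`use_bias=True`, `value_dim == key_dim`, and an output projection back to the input width); the helper name `mha_params` is hypothetical.

```python
def mha_params(model_dim, num_heads, key_dim):
    """Trainable parameters of keras.layers.MultiHeadAttention applied to
    inputs of width model_dim, assuming use_bias=True and value_dim == key_dim."""
    inner = num_heads * key_dim              # width of the concatenated heads
    qkv = 3 * (model_dim * inner + inner)    # query, key, value kernels + biases
    out = inner * model_dim + model_dim      # output projection back to model_dim
    return qkv + out


# Paper convention: hidden dimension = num_heads * qkv dimension (768 = 12 * 64).
paper = mha_params(model_dim=768, num_heads=12, key_dim=64)
# Keras.io convention: one projection_dim (64) used as model width AND key_dim.
keras_io = mha_params(model_dim=64, num_heads=12, key_dim=64)

print(paper)     # 2362368, i.e. 4 * 768**2 + 4 * 768
print(keras_io)  # 198976
```

In the Keras.io layout the heads concatenate to 768 features while the residual stream stays at 64, so the attention layer is 12 times wider inside than the model around it.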

One simple way to verify this issue is to count the model's trainable parameters. Using the architecture from the Keras.io examples with the same hyper-parameters as ViT-Base, the model has only about 15 million parameters, whereas ViT-Base has 86 million [2].
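These totals can be reproduced without building the model. The pure-Python sketch below adds up, layer by layer, the trainable parameters of the reproduction script further down (`create_vit_object_detector` with its settings: 224x224 RGB inputs, 16x16 patches, 12 heads, 12 blocks, an MLP ratio of 3, a `[3072, 3072]` head and a 1024-unit output layer). It assumes the standard Keras parameter layouts for `Dense`, `LayerNormalization`, `Embedding` and `MultiHeadAttention`; all helper names are hypothetical.

```python
def dense(n_in, n_out):
    return n_in * n_out + n_out  # kernel + bias


def layer_norm(dim):
    return 2 * dim  # gamma + beta


def mha(model_dim, num_heads, key_dim):
    inner = num_heads * key_dim
    # query/key/value projections into the heads, then output projection back
    return 3 * dense(model_dim, inner) + dense(inner, model_dim)


def vit_detector_params(hidden_dim, projection_dim=64, num_heads=12,
                        image_size=224, patch_size=16, channels=3,
                        mlp_ratio=3, transformer_layers=12,
                        mlp_head_units=(3072, 3072), num_outputs=1024):
    num_patches = (image_size // patch_size) ** 2
    patch_dim = patch_size * patch_size * channels
    # PatchEncoder: Dense projection + position Embedding
    total = dense(patch_dim, hidden_dim) + num_patches * hidden_dim
    # One transformer block: 2 LayerNorms, attention, 2-layer MLP
    block = (2 * layer_norm(hidden_dim)
             + mha(hidden_dim, num_heads, projection_dim)
             + dense(hidden_dim, hidden_dim * mlp_ratio)
             + dense(hidden_dim * mlp_ratio, hidden_dim))
    total += transformer_layers * block
    # Final LayerNorm, MLP head and output layer
    total += layer_norm(hidden_dim)
    width = hidden_dim
    for units in mlp_head_units:
        total += dense(width, units)
        width = units
    return total + dense(width, num_outputs)


print(vit_detector_params(hidden_dim=64))   # 15537344 (hidden_dim == projection_dim)
print(vit_detector_params(hidden_dim=768))  # 86581504 (hidden_dim == projection_dim * num_heads)
```

A large share of both totals sits outside the encoder: about 12.8M (first case) and 14.9M (second case) come from the MLP head and output layer. The script also uses an MLP ratio of 3 in each block, whereas ViT-Base uses 3072 = 4 x 768, so its 86M matches ViT-Base in size rather than layer for layer.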

To fix this issue:

  1. A hidden dimension parameter can be defined as: hidden_dim = projection_dim * num_heads

  2. The encoded patches should be projected to the hidden dimension instead of projection_dim: encoded_patches = PatchEncoder(num_patches, hidden_dim)(patches)

  3. The transformer_units should also use the hidden dimension: transformer_units = [hidden_dim * 2, hidden_dim]
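Applied to a single encoder block, the three steps above could look like the following sketch. It assumes Keras 2.14 or later (the `"gelu"` activation string and `keras.layers.MultiHeadAttention`), hard-codes ViT-Base sizes, and feeds the block an input that is already projected to `hidden_dim` (which is what step 2 produces); the function name `transformer_block` is hypothetical.

```python
import keras
from keras import layers

# Hyper-parameters (ViT-Base sizes)
projection_dim = 64                               # per-head query/key/value size
num_heads = 12
num_patches = 196                                 # (224 // 16) ** 2
hidden_dim = projection_dim * num_heads           # step 1: 768
transformer_units = [hidden_dim * 2, hidden_dim]  # step 3


def transformer_block(x):
    """Pre-norm encoder block; x has shape (batch, num_patches, hidden_dim)."""
    x1 = layers.LayerNormalization(epsilon=1e-6)(x)
    # key_dim stays per head; the layer projects back to x's width (hidden_dim).
    attention_output = layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=projection_dim, dropout=0.1
    )(x1, x1)
    x2 = layers.Add()([attention_output, x])
    x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
    for units in transformer_units:
        x3 = layers.Dense(units, activation="gelu")(x3)
        x3 = layers.Dropout(0.1)(x3)
    return layers.Add()([x3, x2])


# Step 2: the block receives patches already encoded to hidden_dim,
# e.g. the output of PatchEncoder(num_patches, hidden_dim).
inputs = keras.Input(shape=(num_patches, hidden_dim))
block = keras.Model(inputs, transformer_block(inputs))
print(block.count_params())  # 4727040
```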

Then, if the same hyper-parameters are used as in the original paper, the number of trainable parameters will match that of ViT-Base.
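For reference, the same bookkeeping for the ViT-B/16 backbone of [2] (hidden size 768, 12 layers, 12 heads, MLP size 3072, a learnable class token and 197 position embeddings, classification head excluded) gives about 86M. This is a sketch: it assumes the patch embedding is a 16x16 linear projection with bias and that every LayerNorm has a scale and an offset; `vit_backbone_params` is a hypothetical helper.

```python
def dense(n_in, n_out):
    return n_in * n_out + n_out  # kernel + bias


def vit_backbone_params(hidden=768, depth=12, mlp_dim=3072,
                        image_size=224, patch_size=16, channels=3):
    num_patches = (image_size // patch_size) ** 2
    total = dense(patch_size * patch_size * channels, hidden)  # patch embedding
    total += hidden                       # class token
    total += (num_patches + 1) * hidden   # position embeddings (patches + class token)
    # With num_heads * head_dim == hidden, q, k, v and output are all hidden x hidden.
    block = (4 * dense(hidden, hidden)
             + dense(hidden, mlp_dim) + dense(mlp_dim, hidden)
             + 2 * 2 * hidden)            # two LayerNorms (scale + offset)
    total += depth * block
    total += 2 * hidden                   # final LayerNorm
    return total


print(vit_backbone_params())              # 85798656, about 86M
print(vit_backbone_params(mlp_dim=1536))  # 57468672, MLP width of hidden_dim * 2
```

With the MLP width from step 3 (`hidden_dim * 2` = 1536) the same backbone comes to about 57.5M, so reaching the 86M of ViT-Base also requires the paper's MLP size of 3072.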

I understand that the authors may have used alternative versions of the original model, but this particular modification can significantly change the behaviour of the model.

If you need any further information, please let me know.

Best wishes, Angelos

[1] See Table 3 in the original Transformer paper: https://arxiv.org/pdf/1706.03762
[2] https://arxiv.org/pdf/2010.11929
[3] https://keras.io/examples/vision/image_classification_with_vision_transformer/
[4] https://keras.io/examples/vision/vit_small_ds/
[5] https://keras.io/examples/vision/object_detection_using_vision_transformer/
[6] https://keras.io/examples/vision/token_learner/
[7] https://keras.io/examples/vision/vit_small_ds/

Standalone code to reproduce the issue or tutorial link

# Below is the ViT model function (create_vit_object_detector, adapted from [5])
# with the same hyper-parameters as ViT-Base, plus an extra hidden_dimension hyper-parameter.
# If hidden_dimension is set equal to projection_dim (as implied by the Keras.io examples),
# the model has about 15M parameters.
# If it is set to 768 (= projection_dim * num_heads), it has about 86M parameters, like the original model.

# The code uses TensorFlow 2.14.

#%% Import libraries
import keras
from keras import layers
import tensorflow as tf

#%% define required functions and classes
def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=keras.activations.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

class Patches(layers.Layer):
    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        input_shape = tf.shape(images)
        batch_size = input_shape[0]
        height = input_shape[1]
        width = input_shape[2]
        channels = input_shape[3]
        num_patches_h = height // self.patch_size
        num_patches_w = width // self.patch_size
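        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            # "VALID" drops partial patches, so the reshape below always
            # matches num_patches_h * num_patches_w
            padding="VALID",
        )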
        patches = tf.reshape(
            patches,
            (
                batch_size,
                num_patches_h * num_patches_w,
                self.patch_size * self.patch_size * channels,
            ),
        )
        return patches

    def get_config(self):
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config

class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    # Serialize only this layer's own constructor arguments so that the
    # model can be saved and reloaded without relying on global variables
    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "num_patches": self.num_patches,
                "projection_dim": self.projection_dim,
            }
        )
        return config

    def call(self, patch):
        positions = tf.expand_dims(
            tf.range(start=0, limit=self.num_patches, delta=1), axis=0
        )
        projected_patches = self.projection(patch)
        encoded = projected_patches + self.position_embedding(positions)
        return encoded

def create_vit_object_detector(
    input_shape,
    patch_size,
    num_patches,
    projection_dim,
    num_heads,
    transformer_units,
    transformer_layers,
    mlp_head_units,
    hidden_dimension
):

    inputs = keras.Input(shape=input_shape)
    # Create patches
    patches = Patches(patch_size)(inputs)
    # Encode patches
    encoded_patches = PatchEncoder(num_patches, hidden_dimension)(patches)

    # Create multiple layers of the Transformer block.
    for _ in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        # Create a multi-head attention layer.
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, encoded_patches])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # Skip connection 2.
        encoded_patches = layers.Add()([x3, x2])

    # Create a [batch_size, hidden_dimension] tensor.
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.GlobalAveragePooling1D()(representation)
    representation = layers.Dropout(0.3)(representation)
    # Add MLP.
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.3)

    # Final dense layer (1024 outputs in this reproduction; a bounding-box head would output 4)
    bounding_box = layers.Dense(1024)(features)

    # return Keras model.
    return keras.Model(inputs=inputs, outputs=bounding_box)

#%% model parameters
image_size = 224
patch_size = 16
input_shape = (image_size, image_size, 3)  # input image shape
num_patches = (image_size // patch_size) ** 2
projection_dim = 64
num_heads = 12
hidden_dimension = projection_dim * num_heads

# Size of the transformer layers
transformer_units = [
    hidden_dimension * 3,
    hidden_dimension,
]
transformer_layers = 12
mlp_head_units = [3072, 3072]  # Size of the dense layers

vit_object_detector = create_vit_object_detector(
    input_shape,
    patch_size,
    num_patches,
    projection_dim,
    num_heads,
    transformer_units,
    transformer_layers,
    mlp_head_units,
    hidden_dimension,
)

vit_object_detector.summary()

Relevant log output

No response

tobilab commented 6 days ago

Any progress here?