keras-team / keras

Deep Learning for humans
http://keras.io/
Apache License 2.0

Example Doubt #19683

Closed emi-dm closed 1 week ago

emi-dm commented 1 week ago

Can someone explain to me why the CLS token is not included in this example, and how I could include it for any backend? https://keras.io/examples/vision/image_classification_with_vision_transformer/

sineeli commented 1 week ago

Hi @emi-dm,

> This design is inherited from the Transformer model for text, and we use it throughout the main paper. An initial attempt at using only image-patch embeddings, globally average-pooling (GAP) them, followed by a linear classifier—just like ResNet's final feature map—performed very poorly. However, we found that this is neither due to the extra token, nor to the GAP operation. Instead, the difference in performance is fully explained by the requirement for a different learning rate.

Taken from the ViT paper.

As described in the paper, ViT can be constructed either with or without the CLS token. If you want to use a CLS token, create an extra token embedding of the ViT hidden dimension (d_model) and prepend it to the projected patches.

This new embedding can be implemented as a single, separate Keras layer holding one weight vector, and such a layer works with all backends.

Example

```python
import keras


class TokenLayer(keras.layers.Layer):
    """Prepends a learnable CLS token to a sequence of patch embeddings."""

    def build(self, input_shape):
        # A single learnable vector of size d_model (the last input axis).
        self.cls_token = self.add_weight(
            name="cls",
            shape=(1, 1, input_shape[-1]),
            initializer="zeros",
        )

    def call(self, inputs):
        # Broadcast the (1, 1, d_model) token to (batch, 1, d_model).
        cls_token = self.cls_token + keras.ops.zeros_like(inputs[:, 0:1])
        # Prepend the token: output is (batch, num_patches + 1, d_model).
        return keras.ops.concatenate([cls_token, inputs], axis=1)
```
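And if it helps to see where the layer plugs in, here is a minimal sketch, not the published example itself: `num_patches`, `d_model`, and `num_classes` are made-up sizes, and the transformer encoder blocks are elided.

```python
import keras

# Hypothetical sizes for illustration; in the keras.io example these come
# from the patch-creation and projection steps.
num_patches, d_model, num_classes = 64, 128, 10

inputs = keras.Input(shape=(num_patches, d_model))  # projected patches
x = TokenLayer()(inputs)  # (batch, num_patches + 1, d_model)
# ... transformer encoder blocks would go here ...
cls_features = x[:, 0]  # final representation of the CLS token
outputs = keras.layers.Dense(num_classes, activation="softmax")(cls_features)
model = keras.Model(inputs, outputs)
```

After the encoder stack you would classify from the CLS token's output (`x[:, 0]`) rather than pooling or flattening all patch embeddings as the linked example does.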

Thanks and hope this helps.

emi-dm commented 1 week ago

Thank you so much @sineeli!!! I hadn't dug deeply enough into the original paper, which is what caused my doubt! Really appreciated :)
