SHI-Labs / Compact-Transformers

Escaping the Big Data Paradigm with Compact Transformers, 2021 (Train your Vision Transformers in 30 mins on CIFAR-10 with a single GPU!)
https://arxiv.org/abs/2104.05704
Apache License 2.0

Question: Why is sequence pooling more effective than a class token? #42

Closed (eware-godaddy closed this issue 2 years ago)

eware-godaddy commented 2 years ago

Hi, I've read your excellent paper a few times now, but I'm struggling with the intuition for why the sequence pooling approach should be substantially different from, or superior to, a standard class token.

We've known for some time that class tokens don't give good embedding representations, which is why it's common to mean-pool the token embeddings instead, as in Sentence Transformers.

But using another set of weights to compute attention weights for a pooled representation doesn't sound that different from what a class token should already be doing.
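
For concreteness, here's my rough mental model of the sequence pooling step as a PyTorch sketch (the names are mine, not the repository's actual code):

```python
import torch
import torch.nn as nn

class SeqPool(nn.Module):
    """Attention-based pooling: a learned linear layer scores each token,
    and the softmaxed scores weight a sum over the sequence."""
    def __init__(self, dim: int):
        super().__init__()
        self.score = nn.Linear(dim, 1)  # the "extra set of weights"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, dim) -- the encoder's output tokens
        w = torch.softmax(self.score(x), dim=1)  # (batch, seq_len, 1)
        return (w * x).sum(dim=1)                # (batch, dim) pooled representation
```

So, to my eye, this is just one extra linear layer producing attention weights over the final token sequence.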

A class token should in theory be able to pool representations from the previous layer.

What's the intuition, and what is different in the math or the architecture that makes this approach more effective and leads to more efficient model sizes?

I've been scratching my head about it for a number of days and thought I'd ask.

Thanks again for this great paper and library, and for enabling powerful transformers with far fewer parameters!

alihassanijr commented 2 years ago

Hi, thank you for your interest.

The idea is, as you pointed out, to create better representations for the classifier. The sequence of tokens has to be pooled into a fixed-size representation. Global average pooling (the mean over tokens) and the class token are among the most commonly used methods for that. However, pooling over the tokens in a more careful way can increase performance, not drastically, but noticeably.

A class token can theoretically pool over the tokens, but prepending it to the sequence from the very beginning limits the power of self-attention in the earlier layers to a certain degree. By removing this token, the encoder layers can focus on performing self-attention over the "patches" (or, in the case of CCT, feature maps) only. This helps model performance, specifically in data-efficient settings, as larger-scale training makes the difference almost unnoticeable. This has also been shown by works concurrent to ours.
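
To make the structural difference concrete, here is a simplified sketch using PyTorch's built-in encoder (not our actual implementation; the dimensions and layer counts are arbitrary):

```python
import torch
import torch.nn as nn

dim, num_patches, batch = 256, 64, 8
layer = nn.TransformerEncoderLayer(d_model=dim, nhead=4, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2)

patches = torch.randn(batch, num_patches, dim)      # patch / feature-map tokens only

# Class-token route: a learned token is prepended before the first layer,
# so every layer's self-attention also has to handle this non-patch position.
cls_token = nn.Parameter(torch.zeros(1, 1, dim))
x = torch.cat([cls_token.expand(batch, -1, -1), patches], dim=1)
cls_out = encoder(x)[:, 0]                          # classify from token 0

# Sequence-pooling route: the encoder only ever sees patch tokens;
# the learned pooling is applied once, after the last layer.
score = nn.Linear(dim, 1)
out = encoder(patches)
weights = torch.softmax(score(out), dim=1)
pooled = (weights * out).sum(dim=1)                 # classify from this
```

In the second route, every attention map is spent entirely on the patch (or feature-map) tokens, and the learned pooling weights only come into play after the final encoder layer.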

Since the objective here is small-scale learning, what we are trying to accomplish is to make a very powerful network less susceptible to overfitting, while also allowing for quicker convergence. Taking away a class token that is injected right at the beginning, and instead pooling over the sequence through attention, helps with that.

I should also add that we conducted extensive experiments on different pooling methods beyond the conventional class token and sequence pooling, and sequence pooling does have an edge over the rest in small-scale training. These differences do tend to shrink when many augmentations and training techniques are involved, which we disabled in our earliest experiments.

I hope this clarifies things.

eware-godaddy commented 2 years ago

Thanks @alihassanijr for the thorough response. I hadn't considered how the class token might be distracting to the transformer, especially on smaller datasets. Thanks again for the great work.