yitu-opensource / T2T-ViT

ICCV2021, Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet

Attention scale value from pretrained model #62

TooTouch opened this issue 3 years ago

TooTouch commented 3 years ago

Hello

Thank you for providing the pretrained weights.

The qk_scale was defined as embed_dim ** -0.5 in models/transformer_block.py. However, as far as I know, the standard attention scale value is (embed_dim // num_heads) ** -0.5, i.e. head_dim ** -0.5.

@register_model
def t2t_vit_7(pretrained=False, **kwargs): # adopt performer for tokens to token
    if pretrained:
        kwargs.setdefault('qk_scale', 256 ** -0.5)
    model = T2T_ViT(tokens_type='performer', embed_dim=256, depth=7, num_heads=4, mlp_ratio=2., **kwargs)
    model.default_cfg = default_cfgs['T2t_vit_7']
    if pretrained:
        load_pretrained(
            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
    return model
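If I understand correctly, for this configuration the two choices differ by a factor of 2. Here is a minimal sketch (illustrative variable names only, not the repository's code) comparing the value baked into the pretrained weights with the usual per-head scale:

# Illustration only: t2t_vit_7 uses embed_dim=256 and num_heads=4.
embed_dim = 256
num_heads = 4
head_dim = embed_dim // num_heads        # 64

pretrained_scale = embed_dim ** -0.5     # 256 ** -0.5 = 0.0625, the qk_scale passed above
per_head_scale = head_dim ** -0.5        # 64 ** -0.5 = 0.125, the usual head_dim ** -0.5

print(pretrained_scale, per_head_scale)  # 0.0625 0.125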

Please check whether I'm right, or let me know if this was intentional.