UCDvision / sima

Official implementation for "SimA: Simple Softmax-free Attention for Vision Transformers"
MIT License

Replace multi-head attention in decoder #3

Status: Open. Mareeta26 opened this issue 2 years ago

Mareeta26 commented 2 years ago

Hi, may I know whether I can use SimA instead of multi-head attention in a decoder to reduce complexity?

Thanks!

soroush-abbasi commented 2 years ago

Hi!

You can replace any self-attention module with SimA. It may require some hyperparameter tuning for specific models. We tried it on CvT, ViT and XCiT. It also works with the DINO loss (self-supervised). I plan to try it with MAE in the future. When you say decoder, which model are you referring to (e.g., the decoder of DETR or MAE)?

Thanks! Have a good day!

Mareeta26 commented 2 years ago

@soroush-abbasi Thanks for the reply. I meant the decoder of ConvTransformer, which incorporates convolutions into the transformer.

Mareeta26 commented 2 years ago

@soroush-abbasi Also, is it possible to share the code for ViT with SimA? Thank you in advance!

soroush-abbasi commented 2 years ago

I guess it should work with the decoder of ConvTransformer. You can simply replace the self-attention with SimA (the SimA class below). To run with a ViT/DeiT architecture, replace these classes in sima.py with the versions below (this removes the LPI layer and the class-attention layers):

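from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath, Mlp, trunc_normal_
# ConvPatchEmbed and PositionalEncodingFourier are the helpers already defined in sima.py.

class SimA(nn.Module):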

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
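        self.num_heads = num_heads
        # qk_scale and attn_drop are kept for signature compatibility with regular attention;
        # SimA has no softmax, and forward() never uses them.

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)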
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

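        # SimA: L1-normalize each channel of q and k across the N tokens (dim=-2)
        # instead of applying softmax to q @ k^T.
        k = F.normalize(k, p=1.0, dim=-2)
        q = F.normalize(q, p=1.0, dim=-2)
        head_dim = C // self.num_heads
        if N < head_dim:
            # Fewer tokens than channels: (q @ k^T) @ v, cost ~ N^2 * head_dim
            x = ((q @ k.transpose(-2, -1)) @ v).transpose(1, 2).reshape(B, N, C)
        else:
            # Otherwise: q @ (k^T @ v), cost ~ N * head_dim^2 (linear in N)
            x = (q @ (k.transpose(-2, -1) @ v)).transpose(1, 2).reshape(B, N, C)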

        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    @torch.jit.ignore
    def no_weight_decay(self):
        return {}

class SimABlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0.,
                 attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 num_tokens=196, eta=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = SimA(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
            proj_drop=drop
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)

        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer,
                       drop=drop)

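        # eta=None (the default) would make eta * torch.ones(dim) fail; fall back to 1.0 (assumed value).
        eta = 1.0 if eta is None else eta
        # LayerScale-style learnable per-channel scaling of each residual branch
        self.gamma1 = nn.Parameter(eta * torch.ones(dim), requires_grad=True)
        self.gamma2 = nn.Parameter(eta * torch.ones(dim), requires_grad=True)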

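    def forward(self, x, H, W):
        # H and W are unused here; they are kept so forward_features can call every block as blk(x, Hp, Wp).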
        x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x)))
        x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
        return x

class SimAVisionTransformer(nn.Module):

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768,
                 depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
                 cls_attn_layers=2, use_pos=True, patch_proj='linear', eta=None, tokens_norm=False):

        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed = ConvPatchEmbed(img_size=img_size, embed_dim=embed_dim,
                                          patch_size=patch_size)

        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [drop_path_rate for i in range(depth)]
        self.blocks = nn.ModuleList([
            SimABlock(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i],
                norm_layer=norm_layer, num_tokens=num_patches, eta=eta)
            for i in range(depth)])

        self.norm = norm_layer(embed_dim)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.pos_embeder = PositionalEncodingFourier(dim=embed_dim)
        self.use_pos = use_pos

        # Classifier head
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token', 'dist_token'}

    def forward_features(self, x):
        B, C, H, W = x.shape

        x, (Hp, Wp) = self.patch_embed(x)

        if self.use_pos:
            pos_encoding = self.pos_embeder(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
            x = x + pos_encoding

        x = self.pos_drop(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        for blk in self.blocks:
            x = blk(x, Hp, Wp)

        x = self.norm(x)[:, 0]
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)

        if self.training:
            return x, x
        else:
            return x
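
The `if` branch in `SimA.forward` only chooses the cheaper multiplication order; it does not change the result. A minimal single-head NumPy sketch (no batch, heads, or projections; `l1_normalize` and `sima_core` are hypothetical helper names, not part of sima.py) illustrates this:

```python
import numpy as np


def l1_normalize(x, axis, eps=1e-12):
    """L1-normalize along `axis`, like F.normalize(x, p=1.0, dim=axis)."""
    norm = np.abs(x).sum(axis=axis, keepdims=True)
    return x / np.maximum(norm, eps)


def sima_core(q, k, v):
    """Softmax-free SimA attention for one head.

    q, k, v: arrays of shape (N, d) = (tokens, head_dim).
    Each channel of q and k is L1-normalized across the N tokens (axis 0),
    which corresponds to dim=-2 in the PyTorch code above.
    """
    n, d = q.shape
    q = l1_normalize(q, axis=0)
    k = l1_normalize(k, axis=0)
    if n < d:
        # (q k^T) v: builds an N x N matrix, cost ~ N^2 * d
        return (q @ k.T) @ v
    # q (k^T v): builds a d x d matrix, cost ~ N * d^2 (linear in N)
    return q @ (k.T @ v)


rng = np.random.default_rng(0)
N, d = 196, 64  # e.g. 14x14 patches and a 64-dim head
q, k, v = (rng.standard_normal((N, d)) for _ in range(3))

out = sima_core(q, k, v)
qn, kn = l1_normalize(q, 0), l1_normalize(k, 0)
# Matrix multiplication is associative, so both orders give the same output.
print(np.allclose(out, (qn @ kn.T) @ v))  # True
```

For N = 196 and head_dim = 64, each matrix product costs about N²·d ≈ 2.46M multiply-adds in the first order versus N·d² ≈ 0.80M in the second, which is why the code uses q @ (k^T @ v) whenever N ≥ head_dim.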

Mareeta26 commented 2 years ago

@soroush-abbasi Thank you! So, don't we need the SimABlock class for ConvTransformer? What is its purpose? Can you please explain? My input to the self-attention module is a 5D tensor, e.g. (8, 19, 128, 16, 16). How should I modify the SimA class for such an input?

soroush-abbasi commented 2 years ago

SimABlock is a regular transformer block that contains both a self-attention layer (SimA) and an MLP layer. As long as you replace the self-attention in your code with SimA, you should be fine, I think. You need to figure out which dimension of your input is the sequence (N) and which is the token dimension (D). If your features have already been split into heads, you also need to find the ordering of B (batch size), H (heads), D (dimension per head after splitting) and N (sequence length, i.e. number of tokens).

Sometimes tokens are not flattened. For example, one can view an image feature map as a set of tokens arranged in 2D. If you have a 512x16x16 feature map, you can flatten the last two dimensions to get 512x256, i.e. 256 tokens of dimension 512 (D=512, N=256). I guess the last two dimensions are the image feature maps in your case, but I'm not sure. Unfortunately, I'm not familiar with ConvTransformer.
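
As a sketch of the flattening described above, the NumPy example below turns the 5D input from the question into the (B, N, C) layout that `SimA.forward` expects. The layout (batch, extra axis such as time, channels, height, width) and `num_heads = 8` are assumptions; the thread does not confirm what each axis means in ConvTransformer.

```python
import numpy as np

# Assumed layout (not confirmed in the thread): (B, T, C, H, W) =
# (batch, extra axis such as time, channels, height, width).
x = np.arange(8 * 19 * 128 * 16 * 16, dtype=np.float32).reshape(8, 19, 128, 16, 16)
B, T, C, H, W = x.shape
num_heads = 8  # assumed; C must be divisible by num_heads
head_dim = C // num_heads

# 1) Fold T into the batch and flatten the 16x16 map into N = H*W = 256 tokens
#    of dimension D = C = 128: shape (B*T, N, C) = (152, 256, 128).
tokens = x.reshape(B * T, C, H * W).transpose(0, 2, 1)

# 2) Inside SimA.forward the channels are split into heads, giving
#    (B*T, num_heads, N, head_dim) = (152, 8, 256, 16).
heads = tokens.reshape(B * T, H * W, num_heads, head_dim).transpose(0, 2, 1, 3)

# 3) After attention, undo the reshapes to get back to the 5D layout.
restored = tokens.transpose(0, 2, 1).reshape(B, T, C, H, W)

print(tokens.shape, heads.shape)  # (152, 256, 128) (152, 8, 256, 16)
```

Token n of a feature map corresponds to pixel (n // W, n % W), so tokens[0, 17] holds all 128 channels at row 1, column 1 of the first map.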

Thanks!

Mareeta26 commented 2 years ago

Sure, thanks for the reply!!

Mareeta26 commented 2 years ago

@soroush-abbasi Can we use SimA for masked self-attention?

soroush-abbasi commented 2 years ago

Hi,

It's a little complicated. Before the QKV dot products, we L1-normalize each channel of Q and K across the token dimension. Because this normalization sums over tokens, every token affects the normalized values of all other tokens. Therefore, if you want to mask tokens, you need to apply the mask before the L1 normalization. Please let me know if you have more questions.
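
A minimal NumPy sketch of "mask before L1 normalization", assuming a padding mask (some tokens are simply absent). `sima_padding_masked` is a hypothetical helper, not code from the repo. Causal masking, where each query sees a different prefix of tokens, would need per-query normalization sums and is not covered here.

```python
import numpy as np


def l1_normalize(x, axis, eps=1e-12):
    # Same operation as F.normalize(x, p=1.0, dim=axis) in the PyTorch code
    return x / np.maximum(np.abs(x).sum(axis=axis, keepdims=True), eps)


def sima_padding_masked(q, k, v, keep):
    """Single-head SimA with a padding mask.

    q, k, v: (N, d) arrays; keep: boolean array of shape (N,), False = padding.
    """
    m = keep[:, None].astype(q.dtype)  # (N, 1): 1.0 for real tokens, 0.0 for padding
    # Mask BEFORE normalizing, so padded tokens do not enter the per-channel
    # L1 sums that every real token is divided by.
    qn = l1_normalize(q * m, axis=0)
    kn = l1_normalize(k * m, axis=0)
    # Padded rows of kn are now zero, so they also drop out of kn^T @ v.
    return qn @ (kn.T @ v)


rng = np.random.default_rng(0)
N, d, n_real = 10, 4, 7
q, k, v = (rng.standard_normal((N, d)) for _ in range(3))
keep = np.arange(N) < n_real  # last 3 tokens are padding

masked = sima_padding_masked(q, k, v, keep)
# Reference: run SimA on the real tokens only (no padding at all).
reference = sima_padding_masked(q[:n_real], k[:n_real], v[:n_real], keep[:n_real])
print(np.allclose(masked[:n_real], reference))  # True

# Masking only AFTER normalization (padded tokens still in the L1 sums) differs:
naive = l1_normalize(q, 0)[:n_real] @ ((l1_normalize(k, 0) * keep[:, None]).T @ v)
print(np.allclose(naive, reference))  # False
```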

Thanks! Have a great day!