Amshaker / SwiftFormer

[ICCV'23] Official repository of paper SwiftFormer: Efficient Additive Attention for Transformer-based Real-time Mobile Vision Applications

EfficientAdditiveAttnetion in QKV interactive perspective. #8

Open lartpang opened 1 year ago

lartpang commented 1 year ago

Although the concept of a "value" does not appear in the paper's description or in the code, the implementation is actually very similar to the interaction form in MobileViT-v2.

As shown below, I have commented and organized the author's code.

As we can see, the interaction between Q and K is implicitly folded into Q's own transformation, and the "key" in the code acts more like a "value".

import torch
import torch.nn as nn
import einops

# https://github.com/Amshaker/SwiftFormer/blob/cd1f854e59f9e010279f8ff657a991d71ed9f13f/models/swiftformer.py#L141C1-L181C19
class EfficientAdditiveAttnetion(nn.Module):
    """
    Efficient Additive Attention module for SwiftFormer.
    Input: tensor in shape [B, N, D]
    Output: tensor in shape [B, N, D]
    """
    def __init__(self, in_dims=512, token_dim=256, num_heads=2):
        super().__init__()
        self.to_query = nn.Linear(in_dims, token_dim * num_heads)
        self.to_key = nn.Linear(in_dims, token_dim * num_heads)
        self.w_g = nn.Parameter(torch.randn(token_dim * num_heads, 1))
        self.scale_factor = token_dim ** -0.5
        self.Proj = nn.Linear(token_dim * num_heads, token_dim * num_heads)
        self.final = nn.Linear(token_dim * num_heads, token_dim)

    def forward(self, x):
        query = self.to_query(x)
        query = torch.nn.functional.normalize(query, dim=-1) #BxNxD      

        # convert query to the context vector
        query_weight = query @ self.w_g # BxNx1 (BxNxD @ Dx1)       
        A = query_weight * self.scale_factor # BxNx1
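        # normalize the scores across the N tokens (dim=1) to get per-token attention weights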
        A = torch.nn.functional.normalize(A, dim=1) # BxNx1        

        # similar to the interaction of query and key in MobileViT-v2,
        # and here A can be seen as "query" and query as "key"
        G = torch.sum(A * query, dim=1) # BxD
        # broadcast the global context vector to every token position
        G = einops.repeat(G, "b d -> b repeat d", repeat=query.shape[1]) # BxNxD

        key = self.to_key(x)
        key = torch.nn.functional.normalize(key, dim=-1) #BxNxD
        # here key can be seen as "value"
        out = self.Proj(G * key) + query #BxNxD
        return self.final(out) # BxNxD
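A quick shape check of the module above (my own minimal sketch, not from the repo; it uses token_dim == in_dims and num_heads == 1 so that the output width matches the input width):

attn = EfficientAdditiveAttnetion(in_dims=512, token_dim=512, num_heads=1)
x = torch.randn(2, 196, 512)  # B x N x D
out = attn(x)
print(out.shape)  # torch.Size([2, 196, 512]); note that self.final returns token_dim channels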
Amshaker commented 1 year ago

Hi @lartpang, thank you for your insights. SwiftFormer and MobileViT-v2 are somewhat similar in how they compute the interactions, as we have already shown in the attention comparison figure. However, there are two major differences:

(1) We build on additive attention, where learnable weights ("self.w_g") learn where to attend. There are no such learnable attention weights inside the linear attention of MobileViT-v2.

(2) We eliminate the need for a third interaction (what we call the key-value interaction in the paper). In MobileViT-v2, the attention weights (the "context vector") are shared by using a third branch "V". In our case, we revise this interaction and replace it with a linear transformation and a skip connection with the Q matrix. The skip connection shares the global context weights with the input Q instead of requiring a third branch.
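To make the contrast concrete, here is a rough functional sketch of the two interaction patterns (simplified: biases and the final output projections are omitted, and the MobileViT-v2 side is a paraphrase of its separable self-attention rather than the official code):

import torch
import torch.nn.functional as F

def mobilevitv2_style_interaction(x, w_i, w_k, w_v):
    # Paraphrased separable self-attention: a third branch "V" receives the
    # shared context vector (assumed form, not the official implementation).
    scores = torch.softmax(x @ w_i, dim=1)          # B x N x 1, context scores
    context = torch.sum(scores * (x @ w_k), dim=1)  # B x D, shared context vector
    return context.unsqueeze(1) * F.relu(x @ w_v)   # context broadcast onto the V branch

def swiftformer_style_interaction(x, w_q, w_k, w_g, w_proj, scale):
    # The interaction from the code above: a linear transformation plus a
    # skip connection with Q replaces the third branch.
    q = F.normalize(x @ w_q, dim=-1)                # B x N x D
    k = F.normalize(x @ w_k, dim=-1)                # B x N x D
    A = F.normalize((q @ w_g) * scale, dim=1)       # B x N x 1, learnable attention weights via w_g
    G = torch.sum(A * q, dim=1, keepdim=True)       # B x 1 x D, global context vector
    return (G * k) @ w_proj + q                     # no third branch; Q re-enters via the skip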

To summarize, there is a common factor between them, as we already showed in the attention comparison figure, but there are two major differences.

I hope it is clear now.

Best regards, Abdelrahman.