LeapLabTHU / Agent-Attention

Official repository of Agent Attention (ECCV2024)

Other approaches to design agent_tokens? #30

Closed AlfieDiCaprio closed 6 months ago

AlfieDiCaprio commented 6 months ago

Hi, thank you for your excellent work! I would like to migrate this work to a Transformer model that solves combinatorial optimization problems, e.g. the traveling salesman problem. Since this type of problem does not involve operations such as pooling, DWC, or bias, I have the following questions:

  1. Are there other ways to design the agent_tokens? As mentioned in the paper, they can be "set to a set of learnable parameters", but I don't quite understand how this should be implemented in code.
  2. Is Agent Attention a plug-and-play module that can be integrated into other Transformers? It seems somewhat difficult because of the differing tensor dimensions.

I would be extremely grateful for any advice you could provide, and thank you so much for sharing such great work!

tian-qing001 commented 6 months ago

Hi @AlfieDiCaprio, thank you for your interest in our work.

  1. There are many ways to obtain agent tokens. The core idea is downsampling all N tokens to n agent tokens. To set agent tokens as a set of learnable parameters, you can refer to the following code:

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
                 shift_size=0, agent_num=49, **kwargs):
    
        ... ...
    
        # learnable agent tokens, shared across the batch: 1, num_heads, agent_num, head_dim
        # (here head_dim = dim // num_heads)
        self.agent_tokens = nn.Parameter(torch.zeros(1, num_heads, agent_num, head_dim))
        trunc_normal_(self.agent_tokens, std=.02)
    
    def forward(self, x, mask=None):
        b, n, c = x.shape
        h = int(n ** 0.5)
        w = int(n ** 0.5)
        num_heads = self.num_heads
        head_dim = c // num_heads
        qkv = self.qkv(x).reshape(b, n, 3, c).permute(2, 0, 1, 3)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        # q, k, v: b, n, c
    
        # split heads: b, num_heads, n, head_dim
        q = q.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
        k = k.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
        v = v.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
        # expand the learnable agent tokens to the batch: b, num_heads, agent_num, head_dim
        agent_tokens = self.agent_tokens.repeat(b, 1, 1, 1)
    
        ... ...
    
        return x

    Please note that the code is incomplete.

  2. Agent Attention is a plug-and-play module and can be applied to various Transformer models with an appropriate agent token design; see the sketch below for one way to adapt it to plain 1D token sequences.
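
For reference, here is a minimal, self-contained sketch of how a full module could look for plain 1D token sequences (e.g. TSP node embeddings), keeping only the learnable agent tokens and the two softmax attention steps (agent aggregation, then agent broadcast) and dropping the vision-specific pooling, DWC, and agent bias. The class name `AgentAttention1D`, the hyperparameter values, and the use of `torch.nn.init.trunc_normal_` are illustrative choices rather than the repository's exact implementation:

    import torch
    import torch.nn as nn


    class AgentAttention1D(nn.Module):
        """Softmax agent attention for generic 1D token sequences (no pooling, DWC, or bias)."""

        def __init__(self, dim, num_heads=8, agent_num=49, qkv_bias=True, attn_drop=0., proj_drop=0.):
            super().__init__()
            assert dim % num_heads == 0, "dim must be divisible by num_heads"
            self.num_heads = num_heads
            self.head_dim = dim // num_heads
            self.scale = self.head_dim ** -0.5

            # learnable agent tokens, shared across the batch: 1, num_heads, agent_num, head_dim
            self.agent_tokens = nn.Parameter(torch.zeros(1, num_heads, agent_num, self.head_dim))
            nn.init.trunc_normal_(self.agent_tokens, std=.02)

            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
            self.attn_drop = nn.Dropout(attn_drop)
            self.proj = nn.Linear(dim, dim)
            self.proj_drop = nn.Dropout(proj_drop)

        def forward(self, x):
            # x: b, n, c token embeddings (e.g. TSP node features)
            b, n, c = x.shape
            qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            q, k, v = qkv[0], qkv[1], qkv[2]  # each: b, num_heads, n, head_dim

            agent_tokens = self.agent_tokens.expand(b, -1, -1, -1)  # b, num_heads, agent_num, head_dim

            # step 1, agent aggregation: agent tokens attend to keys/values
            agent_attn = (agent_tokens * self.scale) @ k.transpose(-2, -1)  # b, num_heads, agent_num, n
            agent_attn = self.attn_drop(agent_attn.softmax(dim=-1))
            agent_v = agent_attn @ v  # b, num_heads, agent_num, head_dim

            # step 2, agent broadcast: queries attend to the agent tokens
            q_attn = (q * self.scale) @ agent_tokens.transpose(-2, -1)  # b, num_heads, n, agent_num
            q_attn = self.attn_drop(q_attn.softmax(dim=-1))
            x = q_attn @ agent_v  # b, num_heads, n, head_dim

            x = x.transpose(1, 2).reshape(b, n, c)
            return self.proj_drop(self.proj(x))


    # usage: drop-in replacement for standard self-attention in an encoder layer
    attn = AgentAttention1D(dim=128, num_heads=8, agent_num=16)
    tokens = torch.randn(32, 100, 128)  # 32 instances, 100 cities, 128-d embeddings
    out = attn(tokens)                  # 32, 100, 128

Because attention is computed against n agent tokens instead of all N tokens, the cost drops from quadratic in N to roughly O(Nn), and n can be tuned independently of the sequence length.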

AlfieDiCaprio commented 6 months ago

Thank you for your time and I'm grateful for your reply!