LeapLabTHU / Agent-Attention

Official repository of Agent Attention (ECCV2024)

Thank you for your excellent work, I would like to ask where the A in your paper is in the code? #3

Closed xlnn closed 8 months ago

xlnn commented 8 months ago

Thank you for your excellent work. Where does the A from your paper appear in the code? Is it in the class below? I could not find anything corresponding to A in class AgentAttention of agent_transformer/models/agent_swin.py.

tian-qing001 commented 8 months ago

Hi @xlnn, thank you for your interest in our work. In the code, the variable agent_tokens corresponds to the Agent Tokens ($A$), and the variable agent_v corresponds to the Agent Features ($V_A$).
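As a rough guide to how those two variables fit together, the forward pass of the AgentAttention class posted in this thread can be written per head as below. The symbols $B_1$ and $B_2$ (the summed position_bias and agent_bias terms), $d$ (head_dim), and $\mathrm{DWC}$ (the depthwise convolution self.dwc) are notation chosen here for illustration and may not match the paper exactly.

```latex
\begin{aligned}
A   &= \mathrm{Pool}(Q)
    && \text{(\texttt{agent\_tokens})} \\
V_A &= \mathrm{softmax}\!\left(\frac{A K^{\top}}{\sqrt{d}} + B_1\right) V
    && \text{(\texttt{agent\_v})} \\
O   &= \mathrm{softmax}\!\left(\frac{Q A^{\top}}{\sqrt{d}} + B_2\right) V_A + \mathrm{DWC}(V)
    && \text{(\texttt{x} before \texttt{self.proj})}
\end{aligned}
```

Read this way, $A$ is never a learned parameter of its own: it is computed from the queries on every forward pass, which may be why it is easy to miss when searching the class for a variable named A.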

xlnn commented 8 months ago

Thank you for your answer, but I have another question. I replaced the attention in Swin Transformer with your AgentAttention:

class AgentAttention(nn.Module):
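    r""" Agent attention module used in place of Swin's window based multi-head self attention (W-MSA).
    Queries are pooled into agent tokens, which first aggregate information from the keys and values
    and then broadcast it back to the queries.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Accepted for interface compatibility; this implementation always uses head_dim ** -0.5
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        shift_size (int, optional): Shift size stored on the module; not used in forward. Default: 0
        agent_num (int, optional): Number of agent tokens; expected to be a perfect square. Default: 49
    """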

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
                 shift_size=0, agent_num=49, **kwargs):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)
        self.shift_size = shift_size

        self.agent_num = agent_num
        self.dwc = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=(3, 3), padding=1, groups=dim)
        self.an_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7))
        self.na_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7))
        self.ah_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, window_size[0], 1))
        self.aw_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, 1, window_size[1]))
        self.ha_bias = nn.Parameter(torch.zeros(1, num_heads, window_size[0], 1, agent_num))
        self.wa_bias = nn.Parameter(torch.zeros(1, num_heads, 1, window_size[1], agent_num))
        trunc_normal_(self.an_bias, std=.02)
        trunc_normal_(self.na_bias, std=.02)
        trunc_normal_(self.ah_bias, std=.02)
        trunc_normal_(self.aw_bias, std=.02)
        trunc_normal_(self.ha_bias, std=.02)
        trunc_normal_(self.wa_bias, std=.02)
        pool_size = int(agent_num ** 0.5)
        self.pool = nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size))

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None.
                Accepted for interface compatibility; not used in this implementation.
        """
        b, n, c = x.shape
        h = int(n ** 0.5)
        w = int(n ** 0.5)
        num_heads = self.num_heads
        head_dim = c // num_heads
        qkv = self.qkv(x).reshape(b, n, 3, c).permute(2, 0, 1, 3)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
        # q, k, v: b, n, c

        # Agent tokens A: the queries pooled over the window down to agent_num tokens
        agent_tokens = self.pool(q.reshape(b, h, w, c).permute(0, 3, 1, 2)).reshape(b, c, -1).permute(0, 2, 1)
        q = q.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
        k = k.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
        v = v.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
        agent_tokens = agent_tokens.reshape(b, self.agent_num, num_heads, head_dim).permute(0, 2, 1, 3)

        # Agent aggregation: agent tokens attend to the keys, producing the agent features V_A (agent_v)
        position_bias1 = nn.functional.interpolate(self.an_bias, size=self.window_size, mode='bilinear')
        position_bias1 = position_bias1.reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1)
        position_bias2 = (self.ah_bias + self.aw_bias).reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1)
        position_bias = position_bias1 + position_bias2
        agent_attn = self.softmax((agent_tokens * self.scale) @ k.transpose(-2, -1) + position_bias)
        agent_attn = self.attn_drop(agent_attn)
        agent_v = agent_attn @ v

        # Agent broadcast: queries attend to the agent tokens and read from agent_v
        agent_bias1 = nn.functional.interpolate(self.na_bias, size=self.window_size, mode='bilinear')
        agent_bias1 = agent_bias1.reshape(1, num_heads, self.agent_num, -1).permute(0, 1, 3, 2).repeat(b, 1, 1, 1)
        agent_bias2 = (self.ha_bias + self.wa_bias).reshape(1, num_heads, -1, self.agent_num).repeat(b, 1, 1, 1)
        agent_bias = agent_bias1 + agent_bias2
        q_attn = self.softmax((q * self.scale) @ agent_tokens.transpose(-2, -1) + agent_bias)
        q_attn = self.attn_drop(q_attn)
        x = q_attn @ agent_v

        x = x.transpose(1, 2).reshape(b, n, c)
        # Depthwise convolution over v, added to the attention output
        v = v.transpose(1, 2).reshape(b, h, w, c).permute(0, 3, 1, 2)
        x = x + self.dwc(v).permute(0, 2, 3, 1).reshape(b, n, c)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
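Several reshapes in the class above only succeed for particular argument combinations, so when swapping it into another backbone it may help to check the configuration up front. The sketch below is a hypothetical helper (check_agent_attention_config is not part of the Agent-Attention repository) that mirrors the arithmetic in __init__ and forward using only the standard library.

```python
import math


def check_agent_attention_config(dim, num_heads, window_size, agent_num, n):
    """Return a list of problems that would break the reshapes in forward().

    Hypothetical helper for illustration; `n` is the number of tokens per window.
    """
    problems = []
    # head_dim = dim // num_heads, and q/k/v are reshaped to (num_heads, head_dim).
    if dim % num_heads != 0:
        problems.append("dim must be divisible by num_heads")
    # pool_size = int(agent_num ** 0.5), and the pooled map is reshaped to
    # agent_num tokens, so agent_num has to be a perfect square.
    pool_size = math.isqrt(agent_num)
    if pool_size * pool_size != agent_num:
        problems.append("agent_num must be a perfect square")
    # h = w = int(n ** 0.5), so each window must be a square grid of tokens.
    side = math.isqrt(n)
    if side * side != n:
        problems.append("n must be a perfect square")
    # The bias tables are resized to window_size, so they only line up with
    # the attention maps when n == Wh * Ww.
    if n != window_size[0] * window_size[1]:
        problems.append("n must equal window_size[0] * window_size[1]")
    return problems
```

For example, check_agent_attention_config(96, 3, (7, 7), 49, 49) returns an empty list, while passing agent_num=50 reports that agent_num must be a perfect square.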

I used the pretrained weights /home/class1/work/modify/G/checkpoints/swin_tiny_patch4_window7_224.pth (https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_tiny_patch4_window7_224.pth), and the dataset is in COCO format. After swapping in your AgentAttention, the following error occurred:

python-BaseException
Traceback (most recent call last):
  File "/home/class1/.pycharm_helpers/pydev/pydevd.py", line 1491, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/home/class1/.pycharm_helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/home/class1/work/modify/G/tools/train.py", line 276, in <module>
    main()
  File "/home/class1/work/modify/G/tools/train.py", line 239, in main
    model.init_weights()
  File "/home/class1/.conda/envs/mm100/lib/python3.7/site-packages/mmcv/runner/base_module.py", line 117, in init_weights
    m.init_weights()
  File "/home/class1/work/modify/G/mmdet/models/backbones/swin_test.py", line 1296, in init_weights
    table_current = self.state_dict()[table_key]
KeyError: 'stages.0.blocks.0.attn.w_msa.relative_position_bias_table'

Do you know how to solve this? Thank you!

tian-qing001 commented 8 months ago

I appreciate your efforts with our Agent-Swin model. However, attempting to load weights meant for Swin will naturally lead to a mismatch. To ensure compatibility, please use the pretrained weights that we provide.
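The traceback fails at self.state_dict()[table_key], which suggests the Swin checkpoint carries a relative_position_bias_table entry that the modified model no longer has. A quick way to see this kind of mismatch is to compare the two sets of parameter names. The sketch below uses plain lists of key names so it runs without PyTorch; diff_state_dict_keys is a hypothetical helper, and apart from the key taken from the traceback, the example names are guesses based on the posted class.

```python
def diff_state_dict_keys(checkpoint_keys, model_keys):
    """Split parameter names into those only the checkpoint has and those only the model has."""
    checkpoint_keys = set(checkpoint_keys)
    model_keys = set(model_keys)
    return {
        "only_in_checkpoint": sorted(checkpoint_keys - model_keys),
        "only_in_model": sorted(model_keys - checkpoint_keys),
    }


# Illustrative key names only. The first checkpoint key comes from the
# traceback; the remaining names are assumptions based on the posted class.
swin_checkpoint = [
    "stages.0.blocks.0.attn.w_msa.relative_position_bias_table",
    "stages.0.blocks.0.attn.w_msa.qkv.weight",
]
agent_model = [
    "stages.0.blocks.0.attn.w_msa.qkv.weight",
    "stages.0.blocks.0.attn.w_msa.an_bias",
    "stages.0.blocks.0.attn.w_msa.dwc.weight",
]

report = diff_state_dict_keys(swin_checkpoint, agent_model)
print(report["only_in_checkpoint"])
print(report["only_in_model"])
```

With real files, the two inputs would typically be the keys of the dictionary returned by torch.load(path, map_location="cpu") and model.state_dict().keys(). Checkpoints often nest the weights under a key such as model or state_dict, so that level may need to be unwrapped first.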

xlnn commented 8 months ago

Thanks. I am using mmdetection for object detection. Which of your pretrained weights do you think I should use? Thank you.

tian-qing001 commented 8 months ago

You can find the pretrained weights on ImageNet here for general downstream task training. Additionally, specific pretrained weights for the object detection model are available here.

xlnn commented 8 months ago

Thank you for your reply! Does agent_swin.py at downstream/detection/mmdet/models/backbones/agent_swin.py apply agent attention to Swin Transformer? Looking forward to your reply, thank you!

tian-qing001 commented 8 months ago

Yes. This is the relevant code for object detection using Agent-Swin.