LeapLabTHU / MLLA

Official repository of MLLA (NeurIPS 2024)

Confusion about shortcut #8

Closed kyrie-23 closed 3 months ago

kyrie-23 commented 3 months ago

Thanks for the amazing work! As far as I can tell, MLLA adopts the shortcut mechanism from Swin Transformer (judging from the code); however, neither the architecture in Fig. 3a nor the baseline in the table shows a shortcut mechanism, and the shortcut is not discussed in Sec. 5.3 either. From my point of view, MLLA doesn't count the shortcut as a new integration because Swin Transformer already had it. As a result, the baseline (Swin Transformer) is treated inconsistently: it does have a shortcut mechanism but is reported as if it doesn't. Please correct me if I'm wrong! Thanks again for the amazing work!

tian-qing001 commented 3 months ago

Hi @kyrie-23, thanks for your interest, but it seems that you have some misunderstandings.


As shown in Fig. 3a, Transformer, Mamba, and MLLA all have shortcuts. But these shortcuts are not the "shortcuts" we discuss in Sec. 4.1, 4.2, and 5.2.

The "shortcuts" in Sec. 4.1, 4.2, and 5.2 specifically refer to the $D\odot x_i$ term in the selective SSM formula, which is not illustrated in Fig. 3a since it is part of the SSM operation.

Please refer to Sec. 4.1 of our paper.
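For readers following along, the distinction can be made concrete with a toy sketch of a selective-SSM-style recurrence. This is a hypothetical NumPy illustration with a diagonal per-channel state, not the paper's implementation; the point is only where the $D\odot x_i$ "shortcut" term sits:

```python
import numpy as np

def ssm_scan(x, A, B, C, D):
    """Toy per-channel SSM recurrence: h_i = A*h_{i-1} + B*x_i, y_i = C*h_i + D*x_i."""
    L, d = x.shape
    h = np.zeros(d)
    ys = []
    for i in range(L):
        h = A * h + B * x[i]      # state update
        y = C * h + D * x[i]      # D * x_i is the "shortcut" term discussed above
        ys.append(y)
    return np.stack(ys)
```

Setting `D` to zeros removes this shortcut entirely; the outputs then differ from the `D = 1` case by exactly `x`, which shows it is a direct input-to-output path inside the SSM, not a block-level residual.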

kyrie-23 commented 3 months ago

Thanks for pointing that out. The shortcut I'm referring to is circled below; does it correspond to the drop_path in the MLLA block below? If so, the shortcut design is not illustrated in Fig. 3, and it was not carried over into the MLLA design. Please correct me if I'm wrong!

def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        x = x + self.cpe1(x.reshape(B, H, W, C).permute(0, 3, 1, 2)).flatten(2).permute(0, 2, 1)
        shortcut = x
        x = self.norm1(x)
        act_res = self.act(self.act_proj(x))
        x = self.in_proj(x).view(B, H, W, C)
        x = self.act(self.dwc(x.permute(0, 3, 1, 2))).permute(0, 2, 3, 1).view(B, L, C)
        # Linear Attention
        x = self.attn(x)
        x = self.out_proj(x * act_res)
        x = shortcut + self.drop_path(x)
        x = x + self.cpe2(x.reshape(B, H, W, C).permute(0, 3, 1, 2)).flatten(2).permute(0, 2, 1)
        # FFN
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
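As a side note on the `drop_path` call in the snippet above: it is stochastic depth applied to the residual branch, separate from the `shortcut` tensor itself. A rough sketch of the behavior (a toy NumPy version for illustration, not the repo's actual module):

```python
import numpy as np

def drop_path(x, drop_prob, training, rng):
    """Stochastic depth: randomly zero the residual branch per sample; rescale survivors."""
    if not training or drop_prob == 0.0:
        return x
    keep = rng.random(x.shape[0]) >= drop_prob              # per-sample keep mask
    mask = keep.reshape(-1, *([1] * (x.ndim - 1))) / (1.0 - drop_prob)
    return x * mask
```

In eval mode the branch passes through unchanged, so `shortcut + self.drop_path(x)` reduces to a plain residual connection.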

Thanks for the discussion!

tian-qing001 commented 3 months ago

@kyrie-23 The "shortcut" you referred to is:

def forward(self, x):
        ...
        act_res = self.act(self.act_proj(x))
        ...
        x = self.out_proj(x * act_res)
        ...
        return x

I think our code is easy to understand.
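For anyone skimming the thread: the `act_res` branch highlighted above is a multiplicative gate on the attention output, not an additive residual. A minimal sketch of the pattern, with hypothetical shapes, assuming a SiLU activation, and with `attn` standing in for the linear-attention step:

```python
import numpy as np

def gated_branch(x, W_act, W_in, W_out, attn):
    """Gate pattern: out_proj(attn(in_proj(x)) * silu(act_proj(x)))."""
    silu = lambda z: z / (1.0 + np.exp(-z))   # SiLU / swish
    act_res = silu(x @ W_act)                 # gating branch ("act_res")
    y = attn(x @ W_in)                        # attention branch
    return (y * act_res) @ W_out              # elementwise gate, then output projection
```

With identity projections and an identity `attn`, this reduces to `x * silu(x)`, making it easy to see that the branch rescales the output elementwise rather than adding a skip path.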