xmu-xiaoma666 / External-Attention-pytorch

🍀 PyTorch implementations of various Attention mechanisms, MLPs, re-parameterization, and convolution modules, helpful for further understanding the papers. ⭐⭐⭐

Maybe an error in mlp/mlp_mixer.py #23

Open SoulAttacker opened 3 years ago

SoulAttacker commented 3 years ago

Dear Author, hello. I think there may be an error here. After reading the paper, I find that the skip connections in the Mixer block are

$$\mathbf{U}_{*,i} = \mathbf{X}_{*,i} + \mathbf{W}_2\,\sigma\big(\mathbf{W}_1\,\mathrm{LayerNorm}(\mathbf{X})_{*,i}\big), \qquad \mathbf{Y}_{j,*} = \mathbf{U}_{j,*} + \mathbf{W}_4\,\sigma\big(\mathbf{W}_3\,\mathrm{LayerNorm}(\mathbf{U})_{j,*}\big)$$

i.e. the channel-mixing MLP should take the LayerNorm of the token-mixed output $\mathbf{U}$, and its skip connection should add $\mathbf{U}$ back rather than the original input `x`.

So the code here should be:

```python
from torch import nn  # MlpBlock is defined earlier in mlp/mlp_mixer.py

class MixerBlock(nn.Module):
    def __init__(self, tokens_mlp_dim=16, channels_mlp_dim=1024,
                 tokens_hidden_dim=32, channels_hidden_dim=1024):
        super().__init__()
        self.ln = nn.LayerNorm(channels_mlp_dim)
        self.tokens_mlp_block = MlpBlock(tokens_mlp_dim, mlp_dim=tokens_hidden_dim)
        self.channels_mlp_block = MlpBlock(channels_mlp_dim, mlp_dim=channels_hidden_dim)

    def forward(self, x):
        """
        x: (bs, tokens, channels)
        """
        ### token mixing
        y = self.ln(x)
        y = y.transpose(1, 2)                 # (bs, channels, tokens)
        y = self.tokens_mlp_block(y)          # (bs, channels, tokens)
        ### channel mixing
        y = y.transpose(1, 2)                 # (bs, tokens, channels)
        # fixme: start
        out = x + y                           # first skip connection: U = x + token-mixing branch
        y = self.ln(out)                      # (bs, tokens, channels)
        y = out + self.channels_mlp_block(y)  # second skip connection adds U, not x
        # fixme: end
        return y
```
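
To make the snippet runnable end to end, here is a minimal sketch with a quick shape check. The `MlpBlock` below is only my assumption of the helper defined earlier in `mlp/mlp_mixer.py` (two linear layers with a GELU in between), not necessarily the exact repo implementation:

```python
import torch
from torch import nn

class MlpBlock(nn.Module):
    # Assumed stand-in for the MlpBlock defined earlier in mlp/mlp_mixer.py:
    # Linear(input_dim -> mlp_dim) -> GELU -> Linear(mlp_dim -> input_dim).
    def __init__(self, input_dim, mlp_dim=512):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, mlp_dim)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(mlp_dim, input_dim)

    def forward(self, x):
        return self.fc2(self.gelu(self.fc1(x)))

# Quick shape check of the proposed MixerBlock above:
# the output should keep the (bs, tokens, channels) shape of the input.
block = MixerBlock(tokens_mlp_dim=16, channels_mlp_dim=1024,
                   tokens_hidden_dim=32, channels_hidden_dim=1024)
x = torch.randn(2, 16, 1024)   # (bs=2, tokens=16, channels=1024)
y = block(x)
print(y.shape)                 # expected: torch.Size([2, 16, 1024])
```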

Looking forward to your reply! Best wishes!

xmu-xiaoma666 commented 3 years ago

That is right, thank you for your suggestion.

SoulAttacker commented 3 years ago

My pleasure! 😁