Open SoulAttacker opened 3 years ago
class MixerBlock(nn.Module):
    """Single Mixer layer: token mixing followed by channel mixing.

    Both mixing steps follow the same pattern: layer-normalise the step's
    input, run it through an ``MlpBlock``, and add the result back onto the
    un-normalised input of that step (residual / skip connection).

    NOTE(review): one ``nn.LayerNorm`` instance (``self.ln``) is applied
    before both mixing steps, so its affine parameters are shared between
    them -- confirm that this sharing is intended.

    Args:
        tokens_mlp_dim: first positional argument given to the token-mixing
            ``MlpBlock`` (presumably the number of tokens -- TODO confirm
            against ``MlpBlock``).
        channels_mlp_dim: size normalised by the LayerNorm and first
            positional argument given to the channel-mixing ``MlpBlock``.
        tokens_hidden_dim: ``mlp_dim`` of the token-mixing ``MlpBlock``.
        channels_hidden_dim: ``mlp_dim`` of the channel-mixing ``MlpBlock``.
    """

    def __init__(
        self,
        tokens_mlp_dim=16,
        channels_mlp_dim=1024,
        tokens_hidden_dim=32,
        channels_hidden_dim=1024,
    ):
        super().__init__()
        self.ln = nn.LayerNorm(channels_mlp_dim)
        self.tokens_mlp_block = MlpBlock(
            tokens_mlp_dim,
            mlp_dim=tokens_hidden_dim,
        )
        self.channels_mlp_block = MlpBlock(
            channels_mlp_dim,
            mlp_dim=channels_hidden_dim,
        )

    def forward(self, x):
        """Apply token mixing, then channel mixing, each with a skip connection.

        Args:
            x: input of shape (bs, tokens, channels).

        Returns:
            Tensor of shape (bs, tokens, channels).
        """
        # --- token mixing ---
        # Swap the last two axes so the MLP acts along the token axis.
        normed_t = self.ln(x).transpose(1, 2)            # (bs, channels, tokens)
        token_mixed = self.tokens_mlp_block(normed_t)    # (bs, channels, tokens)
        token_mixed = token_mixed.transpose(1, 2)        # (bs, tokens, channels)

        # First skip connection: add the token-mixing output onto the raw input.
        residual = x + token_mixed                       # (bs, tokens, channels)

        # --- channel mixing ---
        # Second skip connection is taken from `residual` (the output of the
        # first skip connection), not from the original input `x`.
        channel_mixed = self.channels_mlp_block(self.ln(residual))
        return residual + channel_mixed                  # (bs, tokens, channels)
That is right, thank you for your suggestions.
My pleasure! 😁
Dear Author: Hello. I have a question about this code. After reading the paper, I found that the skip-connection here is
And the code here should be
Looking forward to your reply! Best wishes!