Hi, the Conv-SSM Block drawn in the paper has a 1x1 Conv2D block that merges the outputs of the two branches. But in the code of MedMamba.py I see:
```python
def forward(self, input: torch.Tensor):
    # Split channels into two halves: one for the conv branch, one for the SSM branch
    input_left, input_right = input.chunk(2, dim=-1)
    # SSM branch
    x = self.drop_path(self.self_attention(self.ln_1(input_right)))
    # Conv branch (permute to channels-first for Conv2d, then back)
    input_left = input_left.permute(0, 3, 1, 2).contiguous()
    input_left = self.conv33conv33conv11(input_left)
    input_left = input_left.permute(0, 2, 3, 1).contiguous()
    # Merge branches: concatenate, then shuffle channels instead of a 1x1 conv
    output = torch.cat((input_left, x), dim=-1)
    output = channel_shuffle(output, groups=2)
    return output + input  # residual connection
```
There you use `channel_shuffle` instead of the 1x1 convolution the figure suggests, e.g. `self.finalconv11 = nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=1, stride=1)`.
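For context, here is a minimal sketch of what I understand `channel_shuffle` to do, assuming the channels-last `(B, H, W, C)` layout implied by the `dim=-1` chunk/cat calls in the forward above (the repo's actual helper may differ). It interleaves channels across groups in the ShuffleNet style, so the conv and SSM halves mix without any learned parameters:

```python
import torch

def channel_shuffle(x: torch.Tensor, groups: int) -> torch.Tensor:
    # Assumed channels-last (B, H, W, C) layout, matching the dim=-1
    # operations in the forward above; MedMamba's own helper may differ.
    B, H, W, C = x.shape
    # Reshape channels into (groups, C // groups), swap the two axes,
    # and flatten back, so channels from the two halves are interleaved.
    x = x.view(B, H, W, groups, C // groups)
    x = x.transpose(3, 4).contiguous()
    return x.view(B, H, W, C)
```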
Could you please explain that choice?