MzeroMiko / VMamba

VMamba: Visual State Space Models, code is based on mamba
MIT License
2.07k stars, 124 forks

Questions on cross merge #159

Open ShixuanGu opened 5 months ago

ShixuanGu commented 5 months ago

Many thanks for open-sourcing this great project! Here's a question about the CrossMerge function:

In the PyTorch version, CrossMerge seems to take a 2D feature map as input, while in the paper it takes the 1D features output by the S6 block. Does the S6 block reshape the 1D features back to (H, W)? If so, why is this necessary, given that CrossMerge reshapes its input back to 1D for the merge anyway?

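```python
import torch


class CrossMerge(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        # ys: outputs of the four scan directions, viewed as (B, K=4, D, H, W)
        B, K, D, H, W = ys.shape
        ctx.shape = (H, W)  # kept on ctx for use in backward (not quoted here)
        ys = ys.view(B, K, D, -1)  # flatten back to sequences of length L = H * W
        # un-flip the two reversed directions (2, 3) and add them to directions 0, 1
        ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
        # direction 1 is column-major: view it as (W, H), transpose to (H, W), flatten
        y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
        return y  # (B, D, H * W)
```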
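To see where CrossMerge actually needs H and W, here is a minimal, self-contained sketch (not taken from the VMamba repo). It assumes the four scan directions implied by the forward pass quoted above: row-major, column-major, and the reverses of both. The helpers `cross_scan` and `cross_merge` are hypothetical names. With an identity standing in for the S6 block, merging the four scans should return 4× the flattened input, and undoing the column-major direction is the one step that requires (H, W).

```python
import torch


def cross_scan(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (B, D, H, W) -> (B, 4, D, H*W), one sequence per direction."""
    row_major = x.flatten(2)                  # direction 0: scan rows left-to-right
    col_major = x.transpose(2, 3).flatten(2)  # direction 1: scan columns top-to-bottom
    xs = torch.stack([row_major, col_major], dim=1)    # (B, 2, D, L)
    return torch.cat([xs, xs.flip(dims=[-1])], dim=1)  # directions 2, 3: reversed scans


def cross_merge(ys: torch.Tensor) -> torch.Tensor:
    """Same arithmetic as the quoted CrossMerge.forward: (B, 4, D, H, W) -> (B, D, H*W)."""
    B, K, D, H, W = ys.shape
    ys = ys.reshape(B, K, D, -1)
    # undo the reversal of directions 2 and 3, then add them to directions 0 and 1
    ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1])
    # direction 1 is column-major: read it as (W, H) and transpose back to (H, W)
    col = ys[:, 1].reshape(B, D, W, H).transpose(2, 3).reshape(B, D, -1)
    return ys[:, 0] + col


B, D, H, W = 2, 3, 4, 5
x = torch.randn(B, D, H, W)
xs = cross_scan(x)            # stand-in for the S6 input, (B, 4, D, L)
ys = xs.view(B, 4, D, H, W)   # identity "S6"; only a view, no data copy
y = cross_merge(ys)
print(torch.allclose(y, 4 * x.flatten(2)))  # True
```

Without H and W, the `reshape(B, D, W, H)` for direction 1 could not be written, because a flat length L alone does not determine how it factors into rows and columns.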

MzeroMiko commented 5 months ago

Because CrossMerge needs to know the spatial shape (H, W) of its output, and we do not want to pass extra parameters into the function. Meanwhile, the output of the S6 block only needs to be viewed as (..., H, W), not reshaped: a view only changes the tensor's shape metadata, so the underlying data is not modified at all.
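The "viewed, not reshaped" point can be checked directly. A minimal sketch, assuming the S6 output is a contiguous tensor of shape (B, K, D, H*W) (the sizes below are made up for illustration): `Tensor.view` shares storage with the original tensor, so CrossMerge can read H and W from `ys.shape` instead of taking them as extra arguments, at no memory cost.

```python
import torch

B, K, D, H, W = 2, 4, 8, 6, 7
s6_out = torch.randn(B, K, D, H * W)  # hypothetical S6 output: flattened sequences, L = 42

# From s6_out alone, L = 42 could be 6x7, 7x6, 3x14, ...; the view records (H, W)
ys = s6_out.view(B, K, D, H, W)       # same memory, new shape metadata
print(ys.data_ptr() == s6_out.data_ptr())  # True: no copy was made
print(ys.shape[-2:])                        # torch.Size([6, 7])

# Writing through the view changes the original, confirming shared storage
ys[0, 0, 0, 0, 0] = 123.0
print(s6_out[0, 0, 0, 0].item())            # 123.0
```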