Guo-Stone / MambaMorph

MambaMorph: a Mamba-based Framework for Medical MR-CT Deformable Registration
MIT License

Is it necessary to divide 3D image data into patches before using class MambaLayer(nn.Module)? #8

Open lxy51 opened 4 months ago

lxy51 commented 4 months ago

In other programs, the data used is 3D image data. If I want to reuse the class MambaLayer(nn.Module) part of the code, is it necessary to divide the original data into patches first?

    import torch
    import torch.nn as nn
    from mamba_ssm import Mamba  # assumed source of Mamba; the kwargs below match its signature


    class MambaLayer(nn.Module):
        def __init__(self, dim, d_state=16, d_conv=4, expand=2, downsample=None):
            super().__init__()
            self.dim = dim
            self.norm = nn.LayerNorm(dim)
            self.mamba = Mamba(
                d_model=dim,      # Model dimension d_model
                d_state=d_state,  # SSM state expansion factor
                d_conv=d_conv,    # Local convolution width
                expand=expand,    # Block expansion factor
            )

            # patch merging layer
            if downsample is not None:
                self.downsample = downsample(dim=dim, norm_layer=nn.LayerNorm, reduce_factor=4)
            else:
                self.downsample = None

        def forward(self, x, H, W, T):
            # x is indexed as (B, ..., C): the channel axis comes last
            B, C = x.shape[0], x.shape[-1]
            assert C == self.dim
            x_norm = self.norm(x)
            # run the Mamba block in float32 when the input is half precision
            if x_norm.dtype == torch.float16:
                x_norm = x_norm.type(torch.float32)
            x_mamba = self.mamba(x_norm)
            x = x_mamba.type(x.dtype)

            if self.downsample is not None:
                x_down = self.downsample(x, H, W, T)
                # spatial sizes after downsampling, rounded up for odd inputs
                Wh, Ww, Wt = (H + 1) // 2, (W + 1) // 2, (T + 1) // 2
                return x, H, W, T, x_down, Wh, Ww, Wt
            else:
                return x, H, W, T, x, H, W, T
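
A note on the input layout, going only by the code quoted above: forward reads C from x.shape[-1] and applies LayerNorm and Mamba directly to x, so the layer appears to expect a channels-last token sequence of shape (B, L, C), presumably with L = H * W * T. (In mamba_ssm, Mamba takes input of shape (batch, seqlen, dim).) Whether a patch embedding has to come first is decided by the rest of the model, which is not shown here; patching mainly shortens L. The sketch below is an assumption, not code from MambaMorph: the helper names volume_to_tokens, tokens_to_volume and halved_size are hypothetical, and it only illustrates how a 3D feature volume could be flattened into that layout and restored.

```python
import torch


def volume_to_tokens(vol: torch.Tensor):
    """Flatten a (B, C, H, W, T) volume into (B, H*W*T, C) tokens."""
    B, C, H, W, T = vol.shape
    # flatten(2) merges H, W, T into one axis; transpose moves channels last
    tokens = vol.flatten(2).transpose(1, 2).contiguous()
    return tokens, (H, W, T)


def tokens_to_volume(tokens: torch.Tensor, size):
    """Inverse of volume_to_tokens: (B, H*W*T, C) -> (B, C, H, W, T)."""
    H, W, T = size
    B, L, C = tokens.shape
    assert L == H * W * T, "token count must equal H*W*T"
    return tokens.transpose(1, 2).reshape(B, C, H, W, T)


def halved_size(H, W, T):
    """Ceil-halved sizes, mirroring Wh, Ww, Wt in the quoted forward()."""
    return (H + 1) // 2, (W + 1) // 2, (T + 1) // 2


# Hypothetical feature volume: batch 1, 8 channels, 5 x 6 x 7 voxels.
vol = torch.randn(1, 8, 5, 6, 7)
tokens, size = volume_to_tokens(vol)      # tokens: (1, 5*6*7, 8)
restored = tokens_to_volume(tokens, size)  # back to (1, 8, 5, 6, 7)
```

With this layout, tokens and size could be passed as layer(tokens, *size), provided the channel count equals the layer's dim. Treating every voxel as a token makes L grow with the full volume, which is the usual reason to apply a patch embedding or strided convolution before the first layer.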