MambaMorph: a Mamba-based Framework for Medical MR-CT Deformable Registration
MIT License
In other programs, the data used is 3D image data. If I want to use the class MambaLayer(nn.Module) part of the code, is it necessary to divide the original data into patches first? #8
import torch
import torch.nn as nn
from mamba_ssm import Mamba


class MambaLayer(nn.Module):
    def __init__(self, dim, d_state=16, d_conv=4, expand=2, downsample=None):
        super().__init__()
        self.dim = dim
        self.norm = nn.LayerNorm(dim)
        self.mamba = Mamba(
            d_model=dim,      # Model dimension d_model
            d_state=d_state,  # SSM state expansion factor
            d_conv=d_conv,    # Local convolution width
            expand=expand,    # Block expansion factor
        )

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(dim=dim, norm_layer=nn.LayerNorm, reduce_factor=4)
        else:
            self.downsample = None

    def forward(self, x, H, W, T):
        # x: token sequence (B, L, C), channels last; L is expected to be H*W*T
        B, C = x.shape[0], x.shape[-1]
        assert C == self.dim
        x_norm = self.norm(x)
        # Run the Mamba block in float32 when the input is half precision
        if x_norm.dtype == torch.float16:
            x_norm = x_norm.type(torch.float32)
        x_mamba = self.mamba(x_norm)
        x = x_mamba.type(x.dtype)
        if self.downsample is not None:
            x_down = self.downsample(x, H, W, T)
            # Grid size after halving, rounded up: (H + 1) // 2 == ceil(H / 2)
            Wh, Ww, Wt = (H + 1) // 2, (W + 1) // 2, (T + 1) // 2
            return x, H, W, T, x_down, Wh, Ww, Wt
        else:
            return x, H, W, T, x, H, W, T
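Judging from the code above, MambaLayer does not split anything into patches itself. Both LayerNorm(dim) and the assert act on the last axis, and the mamba_ssm Mamba block takes a (batch, length, dim) sequence. This means a 3D volume must first be flattened into H*W*T tokens with dim channels. Patching is therefore not required by the layer, but without it a 128 x 128 x 128 volume becomes 2,097,152 tokens. A strided patch embedding is a common way to shorten the sequence and project the input channels to dim.

Below is a minimal sketch of that step, assuming PyTorch and a channels-first volume (B, C_in, D1, D2, D3). PatchEmbed3D and tokens_to_volume are hypothetical helpers, not part of the MambaMorph repository. The flatten order (T varies fastest) assumes the downsample module reshapes tokens as (B, H, W, T, C), the usual Swin-style convention. Check this against the repo before relying on it.

```python
import torch
import torch.nn as nn


class PatchEmbed3D(nn.Module):
    """Hypothetical helper (not from MambaMorph): volume -> token sequence.

    Input:  (B, in_chans, D1, D2, D3)
    Output: tokens of shape (B, H*W*T, dim) plus the token-grid sizes H, W, T.
    """

    def __init__(self, in_chans=1, dim=48, patch_size=4):
        super().__init__()
        # Non-overlapping patches: kernel size equals stride
        self.proj = nn.Conv3d(in_chans, dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)                       # (B, dim, H, W, T)
        _, _, H, W, T = x.shape
        tokens = x.flatten(2).transpose(1, 2)  # (B, H*W*T, dim), T varies fastest
        return tokens, H, W, T


def tokens_to_volume(tokens, H, W, T):
    """Hypothetical inverse of the flatten step: (B, H*W*T, C) -> (B, C, H, W, T)."""
    B, L, C = tokens.shape
    assert L == H * W * T, "sequence length does not match the token grid"
    return tokens.transpose(1, 2).reshape(B, C, H, W, T)


if __name__ == "__main__":
    vol = torch.randn(2, 1, 64, 64, 64)        # e.g. two single-channel 64^3 volumes
    embed = PatchEmbed3D(in_chans=1, dim=48, patch_size=4)
    tokens, H, W, T = embed(vol)
    print(tokens.shape, H, W, T)               # torch.Size([2, 4096, 48]) 16 16 16
    # tokens, H, W, T could now be passed to MambaLayer(dim=48).forward(tokens, H, W, T)
    feat = tokens_to_volume(tokens, H, W, T)   # back to (2, 48, 16, 16, 16)
```

With patch_size=1 the Conv3d reduces to a per-voxel linear projection, which amounts to no patching at all. Larger patch sizes trade spatial resolution for a shorter sequence. The returned H, W, T are the token-grid sizes to pass to MambaLayer.forward, not the original voxel dimensions.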