Open · MoetezKd opened this issue 7 months ago
That may depend on which version of the code you are using. I have not tested latency and memory usage of the UNet variant with the latest code, but in classification, VMamba has improved a lot. VMamba should also be relatively faster than a transformer on bigger images.
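A back-of-envelope way to see why the gap should widen with image size (an illustration with assumed numbers, not measurements: patch size 4 and `d_state = 16` are placeholders): self-attention cost grows quadratically in the token count, while a Mamba-style selective scan grows linearly.

```python
def token_count(image_size: int, patch_size: int = 4) -> int:
    """Number of tokens for a square image split into square patches."""
    return (image_size // patch_size) ** 2

def attention_cost(n_tokens: int) -> int:
    """Pairwise token interactions: O(N^2)."""
    return n_tokens * n_tokens

def scan_cost(n_tokens: int, d_state: int = 16) -> int:
    """Sequential state-space scan: O(N * d_state)."""
    return n_tokens * d_state

for size in (224, 512, 1024):
    n = token_count(size)
    # the attention/scan cost ratio is n / d_state, so it grows with resolution
    print(size, n, attention_cost(n) / scan_cost(n))
```

At 224px this ratio is already 196x and it quadruples each time the side length doubles, which is consistent with the claim above; actual wall-clock speed of course also depends on kernel quality.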
Time is tested on 1xA100 with batch_size 128; the config file is vssm1/vssm_tiny_224_0220.yaml. GPU memory is taken from the log.
The experiments (arXiv 2401.10166) done before #20240119 used mamba-ssm + group parallel. The experiments done since #20240201 use sscore + fused cross scan + fused cross merge. We plan to use ssoflex + fused cross scan + fused cross merge + input16output32 in the future.
name | GPU Memory | time (s/10iter) |
---|---|---|
mamba-ssm + sequence scan | 25927M | 0.6585s |
mamba-ssm + group parallel | 25672M | 0.4860s |
mamba-ssm + float16 | 20439M | 0.4195s |
mamba-ssm + fused cross scan | 25675M | 0.4820s |
mamba-ssm + fused cross scan + fused cross merge | 25596M | 0.4020s |
sscore + fused cross scan + fused cross merge | 24984M | 0.3930s |
sscore + fused cross scan + fused cross merge + forward nrow | 24984M | 0.4090s |
sscore + fused cross scan + fused cross merge + backward nrow | 24984M | 0.4490s |
sscore + fused cross scan + fused cross merge + forward nrow + backward nrow | 24984M | 0.4640s |
ssoflex + fused cross scan + fused cross merge | 24986M | 0.3940s |
ssoflex + fused cross scan + fused cross merge + input fp16 + output fp32 | 19842M | 0.3650s |
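For reference, a minimal sketch of how s/10iter numbers like these can be measured (a hypothetical helper, not the repo's own benchmark script): time a callable over a fixed number of iterations with `time.perf_counter`, after a warmup; on GPU, pass `torch.cuda.synchronize` as `sync` so asynchronous kernel launches are included in each reading.

```python
import time

def time_per_10_iters(step, n_iters: int = 100, warmup: int = 10, sync=None):
    """Average wall-clock seconds per 10 calls of `step()`.

    `sync` is an optional callable (e.g. torch.cuda.synchronize) invoked
    before each clock reading so pending GPU work is actually counted.
    """
    for _ in range(warmup):      # warm up caches / autotuners before timing
        step()
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(n_iters):
        step()
    if sync is not None:
        sync()
    elapsed = time.perf_counter() - start
    return 10.0 * elapsed / n_iters
```

With torch, `step` would be one forward+backward over a batch of 128, and the memory column would come from something like `torch.cuda.max_memory_allocated()` after resetting the peak counter.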
Thanks, I updated to the recent version; however, when I load the pretrained checkpoint it says this:
Failed loading checkpoint form net/classification/vssmtiny_dp02_ckpt_epoch_258.pth: Error(s) in loading state_dict for Backbone_VSSM: size mismatch for patch_embed.0.weight: copying a param with shape torch.Size([48, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([96, 3, 4, 4]). size mismatch for patch_embed.0.bias: copying a param with shape torch.Size([48]) from checkpoint, the shape in current model is torch.Size([96]). size mismatch for patch_embed.2.weight: copying a param with shape torch.Size([48]) from checkpoint, the shape in current model is torch.Size([96]). size mismatch for patch_embed.2.bias: copying a param with shape torch.Size([48]) from checkpoint, the shape in current model is torch.Size([96]). size mismatch for layers.0.blocks.0.op.x_proj_weight: copying a param with shape torch.Size([4, 8, 192]) from checkpoint, the shape in current model is torch.Size([4, 38, 192]). size mismatch for layers.0.blocks.0.op.A_logs: copying a param with shape torch.Size([768, 1]) from checkpoint, the shape in current model is torch.Size([768, 16]). size mismatch for layers.0.blocks.1.op.x_proj_weight: copying a param with shape torch.Size([4, 8, 192]) from checkpoint, the shape in current model is torch.Size([4, 38, 192]). size mismatch for layers.0.blocks.1.op.A_logs: copying a param with shape torch.Size([768, 1]) from checkpoint, the shape in current model is torch.Size([768, 16]). size mismatch for layers.0.downsample.1.weight: copying a param with shape torch.Size([192, 96, 3, 3]) from checkpoint, the shape in current model is torch.Size([192, 96, 2, 2]). size mismatch for layers.1.blocks.0.op.x_proj_weight: copying a param with shape torch.Size([4, 14, 384]) from checkpoint, the shape in current model is torch.Size([4, 44, 384]). size mismatch for layers.1.blocks.0.op.A_logs: copying a param with shape torch.Size([1536, 1]) from checkpoint, the shape in current model is torch.Size([1536, 16]). 
size mismatch for layers.1.blocks.1.op.x_proj_weight: copying a param with shape torch.Size([4, 14, 384]) from checkpoint, the shape in current model is torch.Size([4, 44, 384]). size mismatch for layers.1.blocks.1.op.A_logs: copying a param with shape torch.Size([1536, 1]) from checkpoint, the shape in current model is torch.Size([1536, 16]). size mismatch for layers.1.downsample.1.weight: copying a param with shape torch.Size([384, 192, 3, 3]) from checkpoint, the shape in current model is torch.Size([384, 192, 2, 2]). size mismatch for layers.2.blocks.0.op.x_proj_weight: copying a param with shape torch.Size([4, 26, 768]) from checkpoint, the shape in current model is torch.Size([4, 56, 768]). size mismatch for layers.2.blocks.0.op.A_logs: copying a param with shape torch.Size([3072, 1]) from checkpoint, the shape in current model is torch.Size([3072, 16]). size mismatch for layers.2.blocks.1.op.x_proj_weight: copying a param with shape torch.Size([4, 26, 768]) from checkpoint, the shape in current model is torch.Size([4, 56, 768]). size mismatch for layers.2.blocks.1.op.A_logs: copying a param with shape torch.Size([3072, 1]) from checkpoint, the shape in current model is torch.Size([3072, 16]). size mismatch for layers.2.blocks.2.op.x_proj_weight: copying a param with shape torch.Size([4, 26, 768]) from checkpoint, the shape in current model is torch.Size([4, 56, 768]). size mismatch for layers.2.blocks.2.op.A_logs: copying a param with shape torch.Size([3072, 1]) from checkpoint, the shape in current model is torch.Size([3072, 16]). size mismatch for layers.2.blocks.3.op.x_proj_weight: copying a param with shape torch.Size([4, 26, 768]) from checkpoint, the shape in current model is torch.Size([4, 56, 768]). size mismatch for layers.2.blocks.3.op.A_logs: copying a param with shape torch.Size([3072, 1]) from checkpoint, the shape in current model is torch.Size([3072, 16]). 
size mismatch for layers.2.downsample.1.weight: copying a param with shape torch.Size([768, 384, 3, 3]) from checkpoint, the shape in current model is torch.Size([768, 384, 2, 2]). size mismatch for layers.3.blocks.0.op.x_proj_weight: copying a param with shape torch.Size([4, 50, 1536]) from checkpoint, the shape in current model is torch.Size([4, 80, 1536]). size mismatch for layers.3.blocks.0.op.A_logs: copying a param with shape torch.Size([6144, 1]) from checkpoint, the shape in current model is torch.Size([6144, 16]). size mismatch for layers.3.blocks.1.op.x_proj_weight: copying a param with shape torch.Size([4, 50, 1536]) from checkpoint, the shape in current model is torch.Size([4, 80, 1536]). size mismatch for layers.3.blocks.1.op.A_logs: copying a param with shape torch.Size([6144, 1]) from checkpoint, the shape in current model is torch.Size([6144, 16]).
I also noticed this when loading pretrained _IncompatibleKeys(missing_keys=['layers.0.blocks.0.norm2.weight', 'layers.0.blocks.0.norm2.bias', 'layers.0.blocks.0.mlp.fc1.weight', 'layers.0.blocks.0.mlp.fc1.bias', 'layers.0.blocks.0.mlp.fc2.weight', 'layers.0.blocks.0.mlp.fc2.bias', 'layers.0.blocks.1.norm2.weight', 'layers.0.blocks.1.norm2.bias', 'layers.0.blocks.1.mlp.fc1.weight', 'layers.0.blocks.1.mlp.fc1.bias', 'layers.0.blocks.1.mlp.fc2.weight', 'layers.0.blocks.1.mlp.fc2.bias', 'layers.0.downsample.1.weight', 'layers.0.downsample.1.bias', 'layers.0.downsample.3.weight', 'layers.0.downsample.3.bias', 'layers.1.blocks.0.norm2.weight', 'layers.1.blocks.0.norm2.bias', 'layers.1.blocks.0.mlp.fc1.weight', 'layers.1.blocks.0.mlp.fc1.bias', 'layers.1.blocks.0.mlp.fc2.weight', 'layers.1.blocks.0.mlp.fc2.bias', 'layers.1.blocks.1.norm2.weight', 'layers.1.blocks.1.norm2.bias', 'layers.1.blocks.1.mlp.fc1.weight', 'layers.1.blocks.1.mlp.fc1.bias', 'layers.1.blocks.1.mlp.fc2.weight', 'layers.1.blocks.1.mlp.fc2.bias', 'layers.1.downsample.1.weight', 'layers.1.downsample.1.bias', 'layers.1.downsample.3.weight', 'layers.1.downsample.3.bias', 'layers.2.blocks.0.norm2.weight', 'layers.2.blocks.0.norm2.bias', 'layers.2.blocks.0.mlp.fc1.weight', 'layers.2.blocks.0.mlp.fc1.bias', 'layers.2.blocks.0.mlp.fc2.weight', 'layers.2.blocks.0.mlp.fc2.bias', 'layers.2.blocks.1.norm2.weight', 'layers.2.blocks.1.norm2.bias', 'layers.2.blocks.1.mlp.fc1.weight', 'layers.2.blocks.1.mlp.fc1.bias', 'layers.2.blocks.1.mlp.fc2.weight', 'layers.2.blocks.1.mlp.fc2.bias', 'layers.2.blocks.2.norm2.weight', 'layers.2.blocks.2.norm2.bias', 'layers.2.blocks.2.mlp.fc1.weight', 'layers.2.blocks.2.mlp.fc1.bias', 'layers.2.blocks.2.mlp.fc2.weight', 'layers.2.blocks.2.mlp.fc2.bias', 'layers.2.blocks.3.norm2.weight', 'layers.2.blocks.3.norm2.bias', 'layers.2.blocks.3.mlp.fc1.weight', 'layers.2.blocks.3.mlp.fc1.bias', 'layers.2.blocks.3.mlp.fc2.weight', 'layers.2.blocks.3.mlp.fc2.bias', 
'layers.2.blocks.4.norm2.weight', 'layers.2.blocks.4.norm2.bias', 'layers.2.blocks.4.mlp.fc1.weight', 'layers.2.blocks.4.mlp.fc1.bias', 'layers.2.blocks.4.mlp.fc2.weight', 'layers.2.blocks.4.mlp.fc2.bias', 'layers.2.blocks.5.norm2.weight', 'layers.2.blocks.5.norm2.bias', 'layers.2.blocks.5.mlp.fc1.weight', 'layers.2.blocks.5.mlp.fc1.bias', 'layers.2.blocks.5.mlp.fc2.weight', 'layers.2.blocks.5.mlp.fc2.bias', 'layers.2.blocks.6.norm2.weight', 'layers.2.blocks.6.norm2.bias', 'layers.2.blocks.6.mlp.fc1.weight', 'layers.2.blocks.6.mlp.fc1.bias', 'layers.2.blocks.6.mlp.fc2.weight', 'layers.2.blocks.6.mlp.fc2.bias', 'layers.2.blocks.7.norm2.weight', 'layers.2.blocks.7.norm2.bias', 'layers.2.blocks.7.mlp.fc1.weight', 'layers.2.blocks.7.mlp.fc1.bias', 'layers.2.blocks.7.mlp.fc2.weight', 'layers.2.blocks.7.mlp.fc2.bias', 'layers.2.blocks.8.norm2.weight', 'layers.2.blocks.8.norm2.bias', 'layers.2.blocks.8.mlp.fc1.weight', 'layers.2.blocks.8.mlp.fc1.bias', 'layers.2.blocks.8.mlp.fc2.weight', 'layers.2.blocks.8.mlp.fc2.bias', 'layers.2.downsample.1.weight', 'layers.2.downsample.1.bias', 'layers.2.downsample.3.weight', 'layers.2.downsample.3.bias', 'layers.3.blocks.0.norm2.weight', 'layers.3.blocks.0.norm2.bias', 'layers.3.blocks.0.mlp.fc1.weight', 'layers.3.blocks.0.mlp.fc1.bias', 'layers.3.blocks.0.mlp.fc2.weight', 'layers.3.blocks.0.mlp.fc2.bias', 'layers.3.blocks.1.norm2.weight', 'layers.3.blocks.1.norm2.bias', 'layers.3.blocks.1.mlp.fc1.weight', 'layers.3.blocks.1.mlp.fc1.bias', 'layers.3.blocks.1.mlp.fc2.weight', 'layers.3.blocks.1.mlp.fc2.bias', 'outnorm0.weight', 'outnorm0.bias', 'outnorm1.weight', 'outnorm1.bias', 'outnorm2.weight', 'outnorm2.bias', 'outnorm3.weight', 'outnorm3.bias'], unexpected_keys=['classifier.norm.weight', 'classifier.norm.bias', 'classifier.head.weight', 'classifier.head.bias', 'layers.0.downsample.reduction.weight', 'layers.0.downsample.norm.weight', 'layers.0.downsample.norm.bias', 'layers.1.downsample.reduction.weight', 
'layers.1.downsample.norm.weight', 'layers.1.downsample.norm.bias', 'layers.2.blocks.9.norm.weight', 'layers.2.blocks.9.norm.bias', 'layers.2.blocks.9.op.x_proj_weight', 'layers.2.blocks.9.op.dt_projs_weight', 'layers.2.blocks.9.op.dt_projs_bias', 'layers.2.blocks.9.op.A_logs', 'layers.2.blocks.9.op.Ds', 'layers.2.blocks.9.op.in_proj.weight', 'layers.2.blocks.9.op.conv2d.weight', 'layers.2.blocks.9.op.conv2d.bias', 'layers.2.blocks.9.op.out_norm.weight', 'layers.2.blocks.9.op.out_norm.bias', 'layers.2.blocks.9.op.out_proj.weight', 'layers.2.blocks.10.norm.weight', 'layers.2.blocks.10.norm.bias', 'layers.2.blocks.10.op.x_proj_weight', 'layers.2.blocks.10.op.dt_projs_weight', 'layers.2.blocks.10.op.dt_projs_bias', 'layers.2.blocks.10.op.A_logs', 'layers.2.blocks.10.op.Ds', 'layers.2.blocks.10.op.in_proj.weight', 'layers.2.blocks.10.op.conv2d.weight', 'layers.2.blocks.10.op.conv2d.bias', 'layers.2.blocks.10.op.out_norm.weight', 'layers.2.blocks.10.op.out_norm.bias', 'layers.2.blocks.10.op.out_proj.weight', 'layers.2.blocks.11.norm.weight', 'layers.2.blocks.11.norm.bias', 'layers.2.blocks.11.op.x_proj_weight', 'layers.2.blocks.11.op.dt_projs_weight', 'layers.2.blocks.11.op.dt_projs_bias', 'layers.2.blocks.11.op.A_logs', 'layers.2.blocks.11.op.Ds', 'layers.2.blocks.11.op.in_proj.weight', 'layers.2.blocks.11.op.conv2d.weight', 'layers.2.blocks.11.op.conv2d.bias', 'layers.2.blocks.11.op.out_norm.weight', 'layers.2.blocks.11.op.out_norm.bias', 'layers.2.blocks.11.op.out_proj.weight', 'layers.2.blocks.12.norm.weight', 'layers.2.blocks.12.norm.bias', 'layers.2.blocks.12.op.x_proj_weight', 'layers.2.blocks.12.op.dt_projs_weight', 'layers.2.blocks.12.op.dt_projs_bias', 'layers.2.blocks.12.op.A_logs', 'layers.2.blocks.12.op.Ds', 'layers.2.blocks.12.op.in_proj.weight', 'layers.2.blocks.12.op.conv2d.weight', 'layers.2.blocks.12.op.conv2d.bias', 'layers.2.blocks.12.op.out_norm.weight', 'layers.2.blocks.12.op.out_norm.bias', 'layers.2.blocks.12.op.out_proj.weight', 
'layers.2.blocks.13.norm.weight', 'layers.2.blocks.13.norm.bias', 'layers.2.blocks.13.op.x_proj_weight', 'layers.2.blocks.13.op.dt_projs_weight', 'layers.2.blocks.13.op.dt_projs_bias', 'layers.2.blocks.13.op.A_logs', 'layers.2.blocks.13.op.Ds', 'layers.2.blocks.13.op.in_proj.weight', 'layers.2.blocks.13.op.conv2d.weight', 'layers.2.blocks.13.op.conv2d.bias', 'layers.2.blocks.13.op.out_norm.weight', 'layers.2.blocks.13.op.out_norm.bias', 'layers.2.blocks.13.op.out_proj.weight', 'layers.2.blocks.14.norm.weight', 'layers.2.blocks.14.norm.bias', 'layers.2.blocks.14.op.x_proj_weight', 'layers.2.blocks.14.op.dt_projs_weight', 'layers.2.blocks.14.op.dt_projs_bias', 'layers.2.blocks.14.op.A_logs', 'layers.2.blocks.14.op.Ds', 'layers.2.blocks.14.op.in_proj.weight', 'layers.2.blocks.14.op.conv2d.weight', 'layers.2.blocks.14.op.conv2d.bias', 'layers.2.blocks.14.op.out_norm.weight', 'layers.2.blocks.14.op.out_norm.bias', 'layers.2.blocks.14.op.out_proj.weight', 'layers.2.blocks.15.norm.weight', 'layers.2.blocks.15.norm.bias', 'layers.2.blocks.15.op.x_proj_weight', 'layers.2.blocks.15.op.dt_projs_weight', 'layers.2.blocks.15.op.dt_projs_bias', 'layers.2.blocks.15.op.A_logs', 'layers.2.blocks.15.op.Ds', 'layers.2.blocks.15.op.in_proj.weight', 'layers.2.blocks.15.op.conv2d.weight', 'layers.2.blocks.15.op.conv2d.bias', 'layers.2.blocks.15.op.out_norm.weight', 'layers.2.blocks.15.op.out_norm.bias', 'layers.2.blocks.15.op.out_proj.weight', 'layers.2.blocks.16.norm.weight', 'layers.2.blocks.16.norm.bias', 'layers.2.blocks.16.op.x_proj_weight', 'layers.2.blocks.16.op.dt_projs_weight', 'layers.2.blocks.16.op.dt_projs_bias', 'layers.2.blocks.16.op.A_logs', 'layers.2.blocks.16.op.Ds', 'layers.2.blocks.16.op.in_proj.weight', 'layers.2.blocks.16.op.conv2d.weight', 'layers.2.blocks.16.op.conv2d.bias', 'layers.2.blocks.16.op.out_norm.weight', 'layers.2.blocks.16.op.out_norm.bias', 'layers.2.blocks.16.op.out_proj.weight', 'layers.2.blocks.17.norm.weight', 'layers.2.blocks.17.norm.bias', 
'layers.2.blocks.17.op.x_proj_weight', 'layers.2.blocks.17.op.dt_projs_weight', 'layers.2.blocks.17.op.dt_projs_bias', 'layers.2.blocks.17.op.A_logs', 'layers.2.blocks.17.op.Ds', 'layers.2.blocks.17.op.in_proj.weight', 'layers.2.blocks.17.op.conv2d.weight', 'layers.2.blocks.17.op.conv2d.bias', 'layers.2.blocks.17.op.out_norm.weight', 'layers.2.blocks.17.op.out_norm.bias', 'layers.2.blocks.17.op.out_proj.weight', 'layers.2.blocks.18.norm.weight', 'layers.2.blocks.18.norm.bias', 'layers.2.blocks.18.op.x_proj_weight', 'layers.2.blocks.18.op.dt_projs_weight', 'layers.2.blocks.18.op.dt_projs_bias', 'layers.2.blocks.18.op.A_logs', 'layers.2.blocks.18.op.Ds', 'layers.2.blocks.18.op.in_proj.weight', 'layers.2.blocks.18.op.conv2d.weight', 'layers.2.blocks.18.op.conv2d.bias', 'layers.2.blocks.18.op.out_norm.weight', 'layers.2.blocks.18.op.out_norm.bias', 'layers.2.blocks.18.op.out_proj.weight', 'layers.2.blocks.19.norm.weight', 'layers.2.blocks.19.norm.bias', 'layers.2.blocks.19.op.x_proj_weight', 'layers.2.blocks.19.op.dt_projs_weight', 'layers.2.blocks.19.op.dt_projs_bias', 'layers.2.blocks.19.op.A_logs', 'layers.2.blocks.19.op.Ds', 'layers.2.blocks.19.op.in_proj.weight', 'layers.2.blocks.19.op.conv2d.weight', 'layers.2.blocks.19.op.conv2d.bias', 'layers.2.blocks.19.op.out_norm.weight', 'layers.2.blocks.19.op.out_norm.bias', 'layers.2.blocks.19.op.out_proj.weight', 'layers.2.blocks.20.norm.weight', 'layers.2.blocks.20.norm.bias', 'layers.2.blocks.20.op.x_proj_weight', 'layers.2.blocks.20.op.dt_projs_weight', 'layers.2.blocks.20.op.dt_projs_bias', 'layers.2.blocks.20.op.A_logs', 'layers.2.blocks.20.op.Ds', 'layers.2.blocks.20.op.in_proj.weight', 'layers.2.blocks.20.op.conv2d.weight', 'layers.2.blocks.20.op.conv2d.bias', 'layers.2.blocks.20.op.out_norm.weight', 'layers.2.blocks.20.op.out_norm.bias', 'layers.2.blocks.20.op.out_proj.weight', 'layers.2.blocks.21.norm.weight', 'layers.2.blocks.21.norm.bias', 'layers.2.blocks.21.op.x_proj_weight', 
'layers.2.blocks.21.op.dt_projs_weight', 'layers.2.blocks.21.op.dt_projs_bias', 'layers.2.blocks.21.op.A_logs', 'layers.2.blocks.21.op.Ds', 'layers.2.blocks.21.op.in_proj.weight', 'layers.2.blocks.21.op.conv2d.weight', 'layers.2.blocks.21.op.conv2d.bias', 'layers.2.blocks.21.op.out_norm.weight', 'layers.2.blocks.21.op.out_norm.bias', 'layers.2.blocks.21.op.out_proj.weight', 'layers.2.blocks.22.norm.weight', 'layers.2.blocks.22.norm.bias', 'layers.2.blocks.22.op.x_proj_weight', 'layers.2.blocks.22.op.dt_projs_weight', 'layers.2.blocks.22.op.dt_projs_bias', 'layers.2.blocks.22.op.A_logs', 'layers.2.blocks.22.op.Ds', 'layers.2.blocks.22.op.in_proj.weight', 'layers.2.blocks.22.op.conv2d.weight', 'layers.2.blocks.22.op.conv2d.bias', 'layers.2.blocks.22.op.out_norm.weight', 'layers.2.blocks.22.op.out_norm.bias', 'layers.2.blocks.22.op.out_proj.weight', 'layers.2.blocks.23.norm.weight', 'layers.2.blocks.23.norm.bias', 'layers.2.blocks.23.op.x_proj_weight', 'layers.2.blocks.23.op.dt_projs_weight', 'layers.2.blocks.23.op.dt_projs_bias', 'layers.2.blocks.23.op.A_logs', 'layers.2.blocks.23.op.Ds', 'layers.2.blocks.23.op.in_proj.weight', 'layers.2.blocks.23.op.conv2d.weight', 'layers.2.blocks.23.op.conv2d.bias', 'layers.2.blocks.23.op.out_norm.weight', 'layers.2.blocks.23.op.out_norm.bias', 'layers.2.blocks.23.op.out_proj.weight', 'layers.2.blocks.24.norm.weight', 'layers.2.blocks.24.norm.bias', 'layers.2.blocks.24.op.x_proj_weight', 'layers.2.blocks.24.op.dt_projs_weight', 'layers.2.blocks.24.op.dt_projs_bias', 'layers.2.blocks.24.op.A_logs', 'layers.2.blocks.24.op.Ds', 'layers.2.blocks.24.op.in_proj.weight', 'layers.2.blocks.24.op.conv2d.weight', 'layers.2.blocks.24.op.conv2d.bias', 'layers.2.blocks.24.op.out_norm.weight', 'layers.2.blocks.24.op.out_norm.bias', 'layers.2.blocks.24.op.out_proj.weight', 'layers.2.blocks.25.norm.weight', 'layers.2.blocks.25.norm.bias', 'layers.2.blocks.25.op.x_proj_weight', 'layers.2.blocks.25.op.dt_projs_weight', 
'layers.2.blocks.25.op.dt_projs_bias', 'layers.2.blocks.25.op.A_logs', 'layers.2.blocks.25.op.Ds', 'layers.2.blocks.25.op.in_proj.weight', 'layers.2.blocks.25.op.conv2d.weight', 'layers.2.blocks.25.op.conv2d.bias', 'layers.2.blocks.25.op.out_norm.weight', 'layers.2.blocks.25.op.out_norm.bias', 'layers.2.blocks.25.op.out_proj.weight', 'layers.2.blocks.26.norm.weight', 'layers.2.blocks.26.norm.bias', 'layers.2.blocks.26.op.x_proj_weight', 'layers.2.blocks.26.op.dt_projs_weight', 'layers.2.blocks.26.op.dt_projs_bias', 'layers.2.blocks.26.op.A_logs', 'layers.2.blocks.26.op.Ds', 'layers.2.blocks.26.op.in_proj.weight', 'layers.2.blocks.26.op.conv2d.weight', 'layers.2.blocks.26.op.conv2d.bias', 'layers.2.blocks.26.op.out_norm.weight', 'layers.2.blocks.26.op.out_norm.bias', 'layers.2.blocks.26.op.out_proj.weight', 'layers.2.downsample.reduction.weight', 'layers.2.downsample.norm.weight', 'layers.2.downsample.norm.bias'])
Are you using the right config for that checkpoint? It seems that you are loading the state dict of a tiny model into a small or base model.
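A quick way to diagnose this kind of mismatch before loading is to diff the two state dicts by shape. Below is a small duck-typed helper (hypothetical, not part of the repo); anything whose values have a `.shape` attribute works, including torch tensors, and the stand-in class here only exists so the example runs on its own.

```python
def diff_state_dicts(ckpt_sd, model_sd):
    """Classify keys: shape mismatches, missing from ckpt, unexpected in ckpt."""
    mismatched = {
        k: (tuple(ckpt_sd[k].shape), tuple(model_sd[k].shape))
        for k in ckpt_sd.keys() & model_sd.keys()
        if tuple(ckpt_sd[k].shape) != tuple(model_sd[k].shape)
    }
    missing = sorted(model_sd.keys() - ckpt_sd.keys())
    unexpected = sorted(ckpt_sd.keys() - model_sd.keys())
    return mismatched, missing, unexpected

class FakeTensor:
    """Stand-in with a .shape attribute so this sketch runs without torch."""
    def __init__(self, *shape):
        self.shape = shape

# Mirrors the first mismatch in the log above: 48-channel tiny weights vs 96.
ckpt = {"patch_embed.0.weight": FakeTensor(48, 3, 3, 3)}
model = {"patch_embed.0.weight": FakeTensor(96, 3, 4, 4),
         "norm.weight": FakeTensor(96)}
print(diff_state_dicts(ckpt, model))
```

If `mismatched` is dominated by halved/doubled channel counts, as in the log above, the config and checkpoint almost certainly belong to different model sizes.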
This is what I am using to load it now. No mismatch message shows up anymore, but I don't know how to integrate the config file:
```python
build = import_abspy(
    "models",
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "./classification/"),
)

def MambaBlock(pretrained):
    mod = torch.load(pretrained, map_location='cpu')
    Backbone_VSSM: nn.Module = build.vmamba.Backbone_VSSM()
    Backbone_VSSM._load_from_state_dict(
        state_dict=mod['model'], prefix='', local_metadata=None, strict=True,
        missing_keys=[], unexpected_keys=[], error_msgs=[],
    )
    return Backbone_VSSM
```
I think the code in detection/model.py may help.
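As a sketch of the usual pattern (my assumption about the approach, not the exact contents of detection/model.py): call the public `model.load_state_dict(..., strict=False)` after dropping checkpoint entries whose shapes do not match, rather than calling the internal `_load_from_state_dict` hook directly, which silently skips the checks that produce the missing/unexpected-key report. The filter below is duck-typed so it runs here without torch; the torch usage is shown in comments.

```python
def filter_matching_weights(ckpt_sd, model_sd):
    """Keep only checkpoint entries the model also has, with identical shapes."""
    kept, dropped = {}, []
    for k, v in ckpt_sd.items():
        if k in model_sd and tuple(v.shape) == tuple(model_sd[k].shape):
            kept[k] = v
        else:
            dropped.append(k)
    return kept, dropped

# With torch this would be used roughly as:
#   ckpt = torch.load(path, map_location='cpu')['model']
#   kept, dropped = filter_matching_weights(ckpt, model.state_dict())
#   model.load_state_dict(kept, strict=False)   # then inspect `dropped`
```

Printing `dropped` is also a cheap sanity check that the config actually matches the checkpoint, instead of discovering the mismatch through a wall of size-mismatch errors.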
Hi, thanks for the contribution. I was working on a UNet-like architecture with VMamba as the encoder, but it turns out that its latency and memory usage are higher than those of a transformer encoder with approximately the same number of parameters. Do you think this is because the code is under-optimized and at an early stage?