MzeroMiko / VMamba

VMamba: Visual State Space Models. Code is based on Mamba.
MIT License

Latency / memory consumption #56

Open MoetezKd opened 7 months ago

MoetezKd commented 7 months ago

Hi, thanks for the contribution. I was working on a UNet-like architecture with VMamba as the encoder, but it turns out that its latency and memory usage are higher than those of a transformer encoder with approximately the same number of parameters. Do you think this is because the code is still at an early stage and under-optimized?

MzeroMiko commented 7 months ago

That may depend on which version of the code you are using. I have not tested the latency and memory usage in a UNet with the latest version of the code, but in classification, VMamba has improved a lot. Also, VMamba becomes relatively faster than a transformer as the image size grows.
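For anyone benchmarking this themselves, the comparison is resolution-sensitive, so it is worth measuring at several input sizes. Below is a minimal profiling sketch in plain PyTorch (not code from this repo); `model` can be any encoder, e.g. a VMamba or transformer backbone built from your own config:

```python
import time
import torch

@torch.no_grad()
def profile(model, sizes=(224, 512, 1024), batch=8, iters=10):
    """Report per-iteration latency and peak GPU memory at several resolutions."""
    model = model.cuda().eval()
    for s in sizes:
        x = torch.randn(batch, 3, s, s, device="cuda")
        for _ in range(3):                      # warm-up iterations
            model(x)
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()
        t0 = time.time()
        for _ in range(iters):
            model(x)
        torch.cuda.synchronize()
        mem = torch.cuda.max_memory_allocated() / 2**20
        print(f"{s}x{s}: {(time.time() - t0) / iters:.4f} s/iter, peak {mem:.0f} MiB")
```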


The History of Speed Up

Time is measured on 1xA100 with batch_size 128; the config file is vssm1/vssm_tiny_224_0220.yaml. GPU memory is taken from the training log.

The experiments (arXiv 2401.10166) done before #20240119 used mamba-ssm + group-parallel. The experiments done since #20240201 use sscore + fused cross scan + fused cross merge. We plan to use ssoflex + fused cross scan + fused cross merge + input16output32 in the future.

| name | GPU memory | time (s / 10 iters) |
| --- | --- | --- |
| mamba-ssm + sequence scan | 25927M | 0.6585 |
| mamba-ssm + group parallel | 25672M | 0.4860 |
| mamba-ssm + float16 | 20439M | 0.4195 |
| mamba-ssm + fused cross scan | 25675M | 0.4820 |
| mamba-ssm + fused cross scan + fused cross merge | 25596M | 0.4020 |
| sscore + fused cross scan + fused cross merge | 24984M | 0.3930 |
| sscore + fused cross scan + fused cross merge + forward nrow | 24984M | 0.4090 |
| sscore + fused cross scan + fused cross merge + backward nrow | 24984M | 0.4490 |
| sscore + fused cross scan + fused cross merge + forward nrow + backward nrow | 24984M | 0.4640 |
| ssoflex + fused cross scan + fused cross merge | 24986M | 0.3940 |
| ssoflex + fused cross scan + fused cross merge + input fp16 + output fp32 | 19842M | 0.3650 |
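The last row's "input fp16 + output fp32" means feeding the scan half-precision inputs while handing float32 back to the rest of the network. A generic sketch of that pattern (an illustration, not the actual ssoflex implementation):

```python
import torch

def fp16_in_fp32_out(scan_fn, *tensors):
    # Cast inputs to fp16 to cut activation memory and bandwidth,
    # then promote the result back to fp32 for the layers that follow.
    out = scan_fn(*(t.half() for t in tensors))
    return out.float()
```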
MoetezKd commented 7 months ago

Thanks, I updated to the recent version; however, when I load the pretrained checkpoint it says this:

```
Failed loading checkpoint form net/classification/vssmtiny_dp02_ckpt_epoch_258.pth:
Error(s) in loading state_dict for Backbone_VSSM:
size mismatch for patch_embed.0.weight: copying a param with shape torch.Size([48, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([96, 3, 4, 4]).
size mismatch for patch_embed.0.bias: copying a param with shape torch.Size([48]) from checkpoint, the shape in current model is torch.Size([96]).
size mismatch for patch_embed.2.weight: copying a param with shape torch.Size([48]) from checkpoint, the shape in current model is torch.Size([96]).
size mismatch for patch_embed.2.bias: copying a param with shape torch.Size([48]) from checkpoint, the shape in current model is torch.Size([96]).
size mismatch for layers.0.blocks.0.op.x_proj_weight: copying a param with shape torch.Size([4, 8, 192]) from checkpoint, the shape in current model is torch.Size([4, 38, 192]).
size mismatch for layers.0.blocks.0.op.A_logs: copying a param with shape torch.Size([768, 1]) from checkpoint, the shape in current model is torch.Size([768, 16]).
size mismatch for layers.0.blocks.1.op.x_proj_weight: copying a param with shape torch.Size([4, 8, 192]) from checkpoint, the shape in current model is torch.Size([4, 38, 192]).
size mismatch for layers.0.blocks.1.op.A_logs: copying a param with shape torch.Size([768, 1]) from checkpoint, the shape in current model is torch.Size([768, 16]).
size mismatch for layers.0.downsample.1.weight: copying a param with shape torch.Size([192, 96, 3, 3]) from checkpoint, the shape in current model is torch.Size([192, 96, 2, 2]).
size mismatch for layers.1.blocks.0.op.x_proj_weight: copying a param with shape torch.Size([4, 14, 384]) from checkpoint, the shape in current model is torch.Size([4, 44, 384]).
size mismatch for layers.1.blocks.0.op.A_logs: copying a param with shape torch.Size([1536, 1]) from checkpoint, the shape in current model is torch.Size([1536, 16]).
size mismatch for layers.1.blocks.1.op.x_proj_weight: copying a param with shape torch.Size([4, 14, 384]) from checkpoint, the shape in current model is torch.Size([4, 44, 384]).
size mismatch for layers.1.blocks.1.op.A_logs: copying a param with shape torch.Size([1536, 1]) from checkpoint, the shape in current model is torch.Size([1536, 16]).
size mismatch for layers.1.downsample.1.weight: copying a param with shape torch.Size([384, 192, 3, 3]) from checkpoint, the shape in current model is torch.Size([384, 192, 2, 2]).
size mismatch for layers.2.blocks.0.op.x_proj_weight: copying a param with shape torch.Size([4, 26, 768]) from checkpoint, the shape in current model is torch.Size([4, 56, 768]).
size mismatch for layers.2.blocks.0.op.A_logs: copying a param with shape torch.Size([3072, 1]) from checkpoint, the shape in current model is torch.Size([3072, 16]).
size mismatch for layers.2.blocks.1.op.x_proj_weight: copying a param with shape torch.Size([4, 26, 768]) from checkpoint, the shape in current model is torch.Size([4, 56, 768]).
size mismatch for layers.2.blocks.1.op.A_logs: copying a param with shape torch.Size([3072, 1]) from checkpoint, the shape in current model is torch.Size([3072, 16]).
size mismatch for layers.2.blocks.2.op.x_proj_weight: copying a param with shape torch.Size([4, 26, 768]) from checkpoint, the shape in current model is torch.Size([4, 56, 768]).
size mismatch for layers.2.blocks.2.op.A_logs: copying a param with shape torch.Size([3072, 1]) from checkpoint, the shape in current model is torch.Size([3072, 16]).
size mismatch for layers.2.blocks.3.op.x_proj_weight: copying a param with shape torch.Size([4, 26, 768]) from checkpoint, the shape in current model is torch.Size([4, 56, 768]).
size mismatch for layers.2.blocks.3.op.A_logs: copying a param with shape torch.Size([3072, 1]) from checkpoint, the shape in current model is torch.Size([3072, 16]).
size mismatch for layers.2.downsample.1.weight: copying a param with shape torch.Size([768, 384, 3, 3]) from checkpoint, the shape in current model is torch.Size([768, 384, 2, 2]).
size mismatch for layers.3.blocks.0.op.x_proj_weight: copying a param with shape torch.Size([4, 50, 1536]) from checkpoint, the shape in current model is torch.Size([4, 80, 1536]).
size mismatch for layers.3.blocks.0.op.A_logs: copying a param with shape torch.Size([6144, 1]) from checkpoint, the shape in current model is torch.Size([6144, 16]).
size mismatch for layers.3.blocks.1.op.x_proj_weight: copying a param with shape torch.Size([4, 50, 1536]) from checkpoint, the shape in current model is torch.Size([4, 80, 1536]).
size mismatch for layers.3.blocks.1.op.A_logs: copying a param with shape torch.Size([6144, 1]) from checkpoint, the shape in current model is torch.Size([6144, 16]).
```

MoetezKd commented 7 months ago

I also noticed this when loading the pretrained weights:

```
_IncompatibleKeys(
missing_keys=[
  'layers.0.blocks.0.norm2.weight', 'layers.0.blocks.0.norm2.bias', 'layers.0.blocks.0.mlp.fc1.weight', 'layers.0.blocks.0.mlp.fc1.bias', 'layers.0.blocks.0.mlp.fc2.weight', 'layers.0.blocks.0.mlp.fc2.bias',
  'layers.0.blocks.1.norm2.weight', 'layers.0.blocks.1.norm2.bias', 'layers.0.blocks.1.mlp.fc1.weight', 'layers.0.blocks.1.mlp.fc1.bias', 'layers.0.blocks.1.mlp.fc2.weight', 'layers.0.blocks.1.mlp.fc2.bias',
  'layers.0.downsample.1.weight', 'layers.0.downsample.1.bias', 'layers.0.downsample.3.weight', 'layers.0.downsample.3.bias',
  'layers.1.blocks.0.norm2.weight', 'layers.1.blocks.0.norm2.bias', 'layers.1.blocks.0.mlp.fc1.weight', 'layers.1.blocks.0.mlp.fc1.bias', 'layers.1.blocks.0.mlp.fc2.weight', 'layers.1.blocks.0.mlp.fc2.bias',
  'layers.1.blocks.1.norm2.weight', 'layers.1.blocks.1.norm2.bias', 'layers.1.blocks.1.mlp.fc1.weight', 'layers.1.blocks.1.mlp.fc1.bias', 'layers.1.blocks.1.mlp.fc2.weight', 'layers.1.blocks.1.mlp.fc2.bias',
  'layers.1.downsample.1.weight', 'layers.1.downsample.1.bias', 'layers.1.downsample.3.weight', 'layers.1.downsample.3.bias',
  'layers.2.blocks.0.norm2.weight', 'layers.2.blocks.0.norm2.bias', 'layers.2.blocks.0.mlp.fc1.weight', 'layers.2.blocks.0.mlp.fc1.bias', 'layers.2.blocks.0.mlp.fc2.weight', 'layers.2.blocks.0.mlp.fc2.bias',
  'layers.2.blocks.1.norm2.weight', 'layers.2.blocks.1.norm2.bias', 'layers.2.blocks.1.mlp.fc1.weight', 'layers.2.blocks.1.mlp.fc1.bias', 'layers.2.blocks.1.mlp.fc2.weight', 'layers.2.blocks.1.mlp.fc2.bias',
  'layers.2.blocks.2.norm2.weight', 'layers.2.blocks.2.norm2.bias', 'layers.2.blocks.2.mlp.fc1.weight', 'layers.2.blocks.2.mlp.fc1.bias', 'layers.2.blocks.2.mlp.fc2.weight', 'layers.2.blocks.2.mlp.fc2.bias',
  'layers.2.blocks.3.norm2.weight', 'layers.2.blocks.3.norm2.bias', 'layers.2.blocks.3.mlp.fc1.weight', 'layers.2.blocks.3.mlp.fc1.bias', 'layers.2.blocks.3.mlp.fc2.weight', 'layers.2.blocks.3.mlp.fc2.bias',
  'layers.2.blocks.4.norm2.weight', 'layers.2.blocks.4.norm2.bias', 'layers.2.blocks.4.mlp.fc1.weight', 'layers.2.blocks.4.mlp.fc1.bias', 'layers.2.blocks.4.mlp.fc2.weight', 'layers.2.blocks.4.mlp.fc2.bias',
  'layers.2.blocks.5.norm2.weight', 'layers.2.blocks.5.norm2.bias', 'layers.2.blocks.5.mlp.fc1.weight', 'layers.2.blocks.5.mlp.fc1.bias', 'layers.2.blocks.5.mlp.fc2.weight', 'layers.2.blocks.5.mlp.fc2.bias',
  'layers.2.blocks.6.norm2.weight', 'layers.2.blocks.6.norm2.bias', 'layers.2.blocks.6.mlp.fc1.weight', 'layers.2.blocks.6.mlp.fc1.bias', 'layers.2.blocks.6.mlp.fc2.weight', 'layers.2.blocks.6.mlp.fc2.bias',
  'layers.2.blocks.7.norm2.weight', 'layers.2.blocks.7.norm2.bias', 'layers.2.blocks.7.mlp.fc1.weight', 'layers.2.blocks.7.mlp.fc1.bias', 'layers.2.blocks.7.mlp.fc2.weight', 'layers.2.blocks.7.mlp.fc2.bias',
  'layers.2.blocks.8.norm2.weight', 'layers.2.blocks.8.norm2.bias', 'layers.2.blocks.8.mlp.fc1.weight', 'layers.2.blocks.8.mlp.fc1.bias', 'layers.2.blocks.8.mlp.fc2.weight', 'layers.2.blocks.8.mlp.fc2.bias',
  'layers.2.downsample.1.weight', 'layers.2.downsample.1.bias', 'layers.2.downsample.3.weight', 'layers.2.downsample.3.bias',
  'layers.3.blocks.0.norm2.weight', 'layers.3.blocks.0.norm2.bias', 'layers.3.blocks.0.mlp.fc1.weight', 'layers.3.blocks.0.mlp.fc1.bias', 'layers.3.blocks.0.mlp.fc2.weight', 'layers.3.blocks.0.mlp.fc2.bias',
  'layers.3.blocks.1.norm2.weight', 'layers.3.blocks.1.norm2.bias', 'layers.3.blocks.1.mlp.fc1.weight', 'layers.3.blocks.1.mlp.fc1.bias', 'layers.3.blocks.1.mlp.fc2.weight', 'layers.3.blocks.1.mlp.fc2.bias',
  'outnorm0.weight', 'outnorm0.bias', 'outnorm1.weight', 'outnorm1.bias', 'outnorm2.weight', 'outnorm2.bias', 'outnorm3.weight', 'outnorm3.bias'
],
unexpected_keys=[
  'classifier.norm.weight', 'classifier.norm.bias', 'classifier.head.weight', 'classifier.head.bias',
  'layers.0.downsample.reduction.weight', 'layers.0.downsample.norm.weight', 'layers.0.downsample.norm.bias',
  'layers.1.downsample.reduction.weight', 'layers.1.downsample.norm.weight', 'layers.1.downsample.norm.bias',
  'layers.2.blocks.9.norm.weight', 'layers.2.blocks.9.norm.bias', 'layers.2.blocks.9.op.x_proj_weight', 'layers.2.blocks.9.op.dt_projs_weight', 'layers.2.blocks.9.op.dt_projs_bias', 'layers.2.blocks.9.op.A_logs', 'layers.2.blocks.9.op.Ds', 'layers.2.blocks.9.op.in_proj.weight', 'layers.2.blocks.9.op.conv2d.weight', 'layers.2.blocks.9.op.conv2d.bias', 'layers.2.blocks.9.op.out_norm.weight', 'layers.2.blocks.9.op.out_norm.bias', 'layers.2.blocks.9.op.out_proj.weight',
  'layers.2.blocks.10.norm.weight', 'layers.2.blocks.10.norm.bias', 'layers.2.blocks.10.op.x_proj_weight', 'layers.2.blocks.10.op.dt_projs_weight', 'layers.2.blocks.10.op.dt_projs_bias', 'layers.2.blocks.10.op.A_logs', 'layers.2.blocks.10.op.Ds', 'layers.2.blocks.10.op.in_proj.weight', 'layers.2.blocks.10.op.conv2d.weight', 'layers.2.blocks.10.op.conv2d.bias', 'layers.2.blocks.10.op.out_norm.weight', 'layers.2.blocks.10.op.out_norm.bias', 'layers.2.blocks.10.op.out_proj.weight',
  'layers.2.blocks.11.norm.weight', 'layers.2.blocks.11.norm.bias', 'layers.2.blocks.11.op.x_proj_weight', 'layers.2.blocks.11.op.dt_projs_weight', 'layers.2.blocks.11.op.dt_projs_bias', 'layers.2.blocks.11.op.A_logs', 'layers.2.blocks.11.op.Ds', 'layers.2.blocks.11.op.in_proj.weight', 'layers.2.blocks.11.op.conv2d.weight', 'layers.2.blocks.11.op.conv2d.bias', 'layers.2.blocks.11.op.out_norm.weight', 'layers.2.blocks.11.op.out_norm.bias', 'layers.2.blocks.11.op.out_proj.weight',
  'layers.2.blocks.12.norm.weight', 'layers.2.blocks.12.norm.bias', 'layers.2.blocks.12.op.x_proj_weight', 'layers.2.blocks.12.op.dt_projs_weight', 'layers.2.blocks.12.op.dt_projs_bias', 'layers.2.blocks.12.op.A_logs', 'layers.2.blocks.12.op.Ds', 'layers.2.blocks.12.op.in_proj.weight', 'layers.2.blocks.12.op.conv2d.weight', 'layers.2.blocks.12.op.conv2d.bias', 'layers.2.blocks.12.op.out_norm.weight', 'layers.2.blocks.12.op.out_norm.bias', 'layers.2.blocks.12.op.out_proj.weight',
  'layers.2.blocks.13.norm.weight', 'layers.2.blocks.13.norm.bias', 'layers.2.blocks.13.op.x_proj_weight', 'layers.2.blocks.13.op.dt_projs_weight', 'layers.2.blocks.13.op.dt_projs_bias', 'layers.2.blocks.13.op.A_logs', 'layers.2.blocks.13.op.Ds', 'layers.2.blocks.13.op.in_proj.weight', 'layers.2.blocks.13.op.conv2d.weight', 'layers.2.blocks.13.op.conv2d.bias', 'layers.2.blocks.13.op.out_norm.weight', 'layers.2.blocks.13.op.out_norm.bias', 'layers.2.blocks.13.op.out_proj.weight',
  'layers.2.blocks.14.norm.weight', 'layers.2.blocks.14.norm.bias', 'layers.2.blocks.14.op.x_proj_weight', 'layers.2.blocks.14.op.dt_projs_weight', 'layers.2.blocks.14.op.dt_projs_bias', 'layers.2.blocks.14.op.A_logs', 'layers.2.blocks.14.op.Ds', 'layers.2.blocks.14.op.in_proj.weight', 'layers.2.blocks.14.op.conv2d.weight', 'layers.2.blocks.14.op.conv2d.bias', 'layers.2.blocks.14.op.out_norm.weight', 'layers.2.blocks.14.op.out_norm.bias', 'layers.2.blocks.14.op.out_proj.weight',
  'layers.2.blocks.15.norm.weight', 'layers.2.blocks.15.norm.bias', 'layers.2.blocks.15.op.x_proj_weight', 'layers.2.blocks.15.op.dt_projs_weight', 'layers.2.blocks.15.op.dt_projs_bias', 'layers.2.blocks.15.op.A_logs', 'layers.2.blocks.15.op.Ds', 'layers.2.blocks.15.op.in_proj.weight', 'layers.2.blocks.15.op.conv2d.weight', 'layers.2.blocks.15.op.conv2d.bias', 'layers.2.blocks.15.op.out_norm.weight', 'layers.2.blocks.15.op.out_norm.bias', 'layers.2.blocks.15.op.out_proj.weight',
  'layers.2.blocks.16.norm.weight', 'layers.2.blocks.16.norm.bias', 'layers.2.blocks.16.op.x_proj_weight', 'layers.2.blocks.16.op.dt_projs_weight', 'layers.2.blocks.16.op.dt_projs_bias', 'layers.2.blocks.16.op.A_logs', 'layers.2.blocks.16.op.Ds', 'layers.2.blocks.16.op.in_proj.weight', 'layers.2.blocks.16.op.conv2d.weight', 'layers.2.blocks.16.op.conv2d.bias', 'layers.2.blocks.16.op.out_norm.weight', 'layers.2.blocks.16.op.out_norm.bias', 'layers.2.blocks.16.op.out_proj.weight',
  'layers.2.blocks.17.norm.weight', 'layers.2.blocks.17.norm.bias', 'layers.2.blocks.17.op.x_proj_weight', 'layers.2.blocks.17.op.dt_projs_weight', 'layers.2.blocks.17.op.dt_projs_bias', 'layers.2.blocks.17.op.A_logs', 'layers.2.blocks.17.op.Ds', 'layers.2.blocks.17.op.in_proj.weight', 'layers.2.blocks.17.op.conv2d.weight', 'layers.2.blocks.17.op.conv2d.bias', 'layers.2.blocks.17.op.out_norm.weight', 'layers.2.blocks.17.op.out_norm.bias', 'layers.2.blocks.17.op.out_proj.weight',
  'layers.2.blocks.18.norm.weight', 'layers.2.blocks.18.norm.bias', 'layers.2.blocks.18.op.x_proj_weight', 'layers.2.blocks.18.op.dt_projs_weight', 'layers.2.blocks.18.op.dt_projs_bias', 'layers.2.blocks.18.op.A_logs', 'layers.2.blocks.18.op.Ds', 'layers.2.blocks.18.op.in_proj.weight', 'layers.2.blocks.18.op.conv2d.weight', 'layers.2.blocks.18.op.conv2d.bias', 'layers.2.blocks.18.op.out_norm.weight', 'layers.2.blocks.18.op.out_norm.bias', 'layers.2.blocks.18.op.out_proj.weight',
  'layers.2.blocks.19.norm.weight', 'layers.2.blocks.19.norm.bias', 'layers.2.blocks.19.op.x_proj_weight', 'layers.2.blocks.19.op.dt_projs_weight', 'layers.2.blocks.19.op.dt_projs_bias', 'layers.2.blocks.19.op.A_logs', 'layers.2.blocks.19.op.Ds', 'layers.2.blocks.19.op.in_proj.weight', 'layers.2.blocks.19.op.conv2d.weight', 'layers.2.blocks.19.op.conv2d.bias', 'layers.2.blocks.19.op.out_norm.weight', 'layers.2.blocks.19.op.out_norm.bias', 'layers.2.blocks.19.op.out_proj.weight',
  'layers.2.blocks.20.norm.weight', 'layers.2.blocks.20.norm.bias', 'layers.2.blocks.20.op.x_proj_weight', 'layers.2.blocks.20.op.dt_projs_weight', 'layers.2.blocks.20.op.dt_projs_bias', 'layers.2.blocks.20.op.A_logs', 'layers.2.blocks.20.op.Ds', 'layers.2.blocks.20.op.in_proj.weight', 'layers.2.blocks.20.op.conv2d.weight', 'layers.2.blocks.20.op.conv2d.bias', 'layers.2.blocks.20.op.out_norm.weight', 'layers.2.blocks.20.op.out_norm.bias', 'layers.2.blocks.20.op.out_proj.weight',
  'layers.2.blocks.21.norm.weight', 'layers.2.blocks.21.norm.bias', 'layers.2.blocks.21.op.x_proj_weight', 'layers.2.blocks.21.op.dt_projs_weight', 'layers.2.blocks.21.op.dt_projs_bias', 'layers.2.blocks.21.op.A_logs', 'layers.2.blocks.21.op.Ds', 'layers.2.blocks.21.op.in_proj.weight', 'layers.2.blocks.21.op.conv2d.weight', 'layers.2.blocks.21.op.conv2d.bias', 'layers.2.blocks.21.op.out_norm.weight', 'layers.2.blocks.21.op.out_norm.bias', 'layers.2.blocks.21.op.out_proj.weight',
  'layers.2.blocks.22.norm.weight', 'layers.2.blocks.22.norm.bias', 'layers.2.blocks.22.op.x_proj_weight', 'layers.2.blocks.22.op.dt_projs_weight', 'layers.2.blocks.22.op.dt_projs_bias', 'layers.2.blocks.22.op.A_logs', 'layers.2.blocks.22.op.Ds', 'layers.2.blocks.22.op.in_proj.weight', 'layers.2.blocks.22.op.conv2d.weight', 'layers.2.blocks.22.op.conv2d.bias', 'layers.2.blocks.22.op.out_norm.weight', 'layers.2.blocks.22.op.out_norm.bias', 'layers.2.blocks.22.op.out_proj.weight',
  'layers.2.blocks.23.norm.weight', 'layers.2.blocks.23.norm.bias', 'layers.2.blocks.23.op.x_proj_weight', 'layers.2.blocks.23.op.dt_projs_weight', 'layers.2.blocks.23.op.dt_projs_bias', 'layers.2.blocks.23.op.A_logs', 'layers.2.blocks.23.op.Ds', 'layers.2.blocks.23.op.in_proj.weight', 'layers.2.blocks.23.op.conv2d.weight', 'layers.2.blocks.23.op.conv2d.bias', 'layers.2.blocks.23.op.out_norm.weight', 'layers.2.blocks.23.op.out_norm.bias', 'layers.2.blocks.23.op.out_proj.weight',
  'layers.2.blocks.24.norm.weight', 'layers.2.blocks.24.norm.bias', 'layers.2.blocks.24.op.x_proj_weight', 'layers.2.blocks.24.op.dt_projs_weight', 'layers.2.blocks.24.op.dt_projs_bias', 'layers.2.blocks.24.op.A_logs', 'layers.2.blocks.24.op.Ds', 'layers.2.blocks.24.op.in_proj.weight', 'layers.2.blocks.24.op.conv2d.weight', 'layers.2.blocks.24.op.conv2d.bias', 'layers.2.blocks.24.op.out_norm.weight', 'layers.2.blocks.24.op.out_norm.bias', 'layers.2.blocks.24.op.out_proj.weight',
  'layers.2.blocks.25.norm.weight', 'layers.2.blocks.25.norm.bias', 'layers.2.blocks.25.op.x_proj_weight', 'layers.2.blocks.25.op.dt_projs_weight', 'layers.2.blocks.25.op.dt_projs_bias', 'layers.2.blocks.25.op.A_logs', 'layers.2.blocks.25.op.Ds', 'layers.2.blocks.25.op.in_proj.weight', 'layers.2.blocks.25.op.conv2d.weight', 'layers.2.blocks.25.op.conv2d.bias', 'layers.2.blocks.25.op.out_norm.weight', 'layers.2.blocks.25.op.out_norm.bias', 'layers.2.blocks.25.op.out_proj.weight',
  'layers.2.blocks.26.norm.weight', 'layers.2.blocks.26.norm.bias', 'layers.2.blocks.26.op.x_proj_weight', 'layers.2.blocks.26.op.dt_projs_weight', 'layers.2.blocks.26.op.dt_projs_bias', 'layers.2.blocks.26.op.A_logs', 'layers.2.blocks.26.op.Ds', 'layers.2.blocks.26.op.in_proj.weight', 'layers.2.blocks.26.op.conv2d.weight', 'layers.2.blocks.26.op.conv2d.bias', 'layers.2.blocks.26.op.out_norm.weight', 'layers.2.blocks.26.op.out_norm.bias', 'layers.2.blocks.26.op.out_proj.weight',
  'layers.2.downsample.reduction.weight', 'layers.2.downsample.norm.weight', 'layers.2.downsample.norm.bias'
])
```
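For background, `_IncompatibleKeys` is the named tuple that PyTorch's `nn.Module.load_state_dict(..., strict=False)` returns. A minimal sketch of producing and inspecting it, with the checkpoint path taken from the error above and `model` assumed to be the constructed `Backbone_VSSM`:

```python
import torch

ckpt = torch.load("net/classification/vssmtiny_dp02_ckpt_epoch_258.pth", map_location="cpu")
# strict=False tolerates missing/unexpected keys and reports them in the
# return value; note that shape mismatches still raise a RuntimeError.
result = model.load_state_dict(ckpt["model"], strict=False)
print(result.missing_keys)     # params the model defines but the checkpoint lacks
print(result.unexpected_keys)  # params the checkpoint carries but the model lacks
```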

MzeroMiko commented 7 months ago

Are you using the right config for that checkpoint? It seems that you are loading the state dict of a tiny model into a small or base model.
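One quick way to verify that the config matches the checkpoint is to diff parameter shapes directly. A small diagnostic sketch, where the path and `model` (the backbone built from your config) are assumptions:

```python
import torch

ckpt_sd = torch.load("net/classification/vssmtiny_dp02_ckpt_epoch_258.pth", map_location="cpu")["model"]
model_sd = model.state_dict()

# Print every parameter whose shape differs between checkpoint and model.
for key in sorted(ckpt_sd.keys() & model_sd.keys()):
    if ckpt_sd[key].shape != model_sd[key].shape:
        print(f"{key}: checkpoint {tuple(ckpt_sd[key].shape)} vs model {tuple(model_sd[key].shape)}")
```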

MoetezKd commented 7 months ago

This is what I am using to load it now. No mismatch message is showing anymore, but I don't know how to integrate the config file:

```python
import os
import torch
from torch import nn

# import_abspy is a helper defined in this repo.
build = import_abspy(
    "models",
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "./classification/"),
)

def MambaBlock(pretrained):
    mod = torch.load(pretrained, map_location='cpu')
    backbone: nn.Module = build.vmamba.Backbone_VSSM()
    # Note: _load_from_state_dict is PyTorch's private per-module hook; it only
    # handles a module's direct parameters and collects errors into the given
    # lists instead of raising, which is why no mismatch message is printed.
    backbone._load_from_state_dict(
        state_dict=mod['model'], prefix='', local_metadata=None, strict=True,
        missing_keys=[], unexpected_keys=[], error_msgs=[],
    )
    return backbone
```

MzeroMiko commented 7 months ago

I think the code in detection/model.py may help.
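For reference, a generic loading pattern in that spirit; the `Backbone_VSSM` constructor arguments are omitted here and would need to match the checkpoint's config, so treat this as a sketch rather than the repo's actual detection/model.py code:

```python
import torch

def build_backbone(pretrained: str):
    model = build.vmamba.Backbone_VSSM()  # must mirror the checkpoint's config (tiny/small/base, version, ...)
    ckpt = torch.load(pretrained, map_location="cpu")
    # Use the public load_state_dict rather than the private _load_from_state_dict
    # hook: the hook silently collects errors, which is why no mismatch was
    # reported above, whereas load_state_dict surfaces them.
    incompatible = model.load_state_dict(ckpt["model"], strict=False)
    print(incompatible)
    return model
```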