MzeroMiko / VMamba

VMamba: Visual State Space Models, code is based on mamba
MIT License

Assertion Error #331

Open Karn3003 opened 8 hours ago

Karn3003 commented 8 hours ago

Hi @MzeroMiko, thanks for the good work. I have a question about the following example:

    import torch
    import torch.nn as nn

    # Instantiate the model (SS2D must already be defined or imported from the VMamba source)
    model = SS2D(
        d_model=96,         # Input feature dimension
        d_state=16,         # State dimension
        ssm_ratio=2.0,      # Scaling ratio
        dt_rank="auto",     # Automatic rank selection
        d_conv=3,           # Kernel size for convolution
        dropout=0.1,        # Dropout rate
        forward_type="v2",  # Forward mode: can be "v0", "v2", etc.
        channel_first=True, # Input tensor format
    ).cuda()

    # Dummy input tensor in (B, C, H, W) layout, matching channel_first=True
    batch_size = 4
    height, width = 64, 64
    channels = 96
    x = torch.randn(batch_size, channels, height, width).cuda()  # (B, C, H, W)

    # Forward pass
    output = model(x)

    # Print output shape
    print("Output shape:", output.shape)

I am getting the following error:

    assert selective_scan_backend in [None, "oflex", "mamba", "torch"]
        348     _scan_mode = dict(cross2d=0, unidi=1, bidi=2, cascade2d=-1).get(scan_mode, None) if isinstance(scan_mode, str) else scan_mode # for debug
        349     assert isinstance(_scan_mode, int)

    AssertionError:
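
The traceback shows two assertions back to back, and the bare `AssertionError:` does not say which one fired. One way to narrow it down is to re-run the same two conditions on plain values. The following is a minimal sketch, assuming these two checks are the only candidates; `check_scan_args` is a hypothetical helper written for illustration and is not part of VMamba, and `"some_backend"` is a made-up value.

```python
def check_scan_args(selective_scan_backend=None, scan_mode="cross2d"):
    """Hypothetical helper: mirrors the two assertions from the traceback
    and reports which condition would fail, instead of raising."""
    problems = []

    # First assertion: the backend must be one of the listed values.
    if selective_scan_backend not in [None, "oflex", "mamba", "torch"]:
        problems.append(
            f"unsupported selective_scan_backend: {selective_scan_backend!r}"
        )

    # Second assertion: string scan modes are mapped to integers;
    # an unknown string maps to None and fails the isinstance check.
    _scan_mode = (
        dict(cross2d=0, unidi=1, bidi=2, cascade2d=-1).get(scan_mode, None)
        if isinstance(scan_mode, str)
        else scan_mode
    )
    if not isinstance(_scan_mode, int):
        problems.append(f"unsupported scan_mode: {scan_mode!r}")

    return problems


# Defaults taken from the values visible in the traceback: no problems expected.
print(check_scan_args())

# A made-up backend name, used only to show how the first check reports.
print(check_scan_args(selective_scan_backend="some_backend"))
```

Printing the actual `selective_scan_backend` and `scan_mode` values just before the assertions in the notebook cell, then passing them to this helper, should show which of the two conditions is violated.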