state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

Error with Mamba2 #413

Open Adele0108 opened 5 months ago

Adele0108 commented 5 months ago

Hi, I just tried the Mamba-2 test code:

```python
from mamba_ssm import Mamba2
import torch

batch, length, dim = 2, 64, 1024
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=64,   # SSM state expansion factor, typically 64 or 128
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
    headdim=128,
).to("cuda")
y = model(x)
assert y.shape == x.shape
print("Mamba2 model parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))
print('x.shape:', x.shape, 'y.shape:', y.shape)
```

But I get the following error:

```
  File "/opt/anaconda3/envs/medfusion-2d/lib/python3.8/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 761, in forward
    causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
TypeError: causal_conv1d_fwd(): incompatible function arguments. The following argument types are supported:
    1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: Optional[torch.Tensor], arg3: Optional[torch.Tensor], arg4: bool) -> torch.Tensor

Invoked with: tensor([[[ 0.6263, -0.1259,  0.6615,  ...,  0.1121,  0.1023, -0.2840],
         [-0.5732,  1.5656,  0.5829,  ...,  0.6564,  0.7546,  0.1331],
         [ 0.4265, -0.1785,  0.1311,  ...,  0.6014, -1.0048,  0.0453],
         ...,
         [ 0.1693, -0.7641, -0.0408,  ..., -0.3669, -0.2489, -0.2052],
         [ 0.8796, -0.5051,  0.3856,  ...,  0.6248,  0.2461, -0.6594],
         [-0.6611,  0.2886,  0.4760,  ..., -0.0319,  0.6962, -1.1070]],

        [[ 0.3243,  0.7392, -0.6660,  ..., -0.2669, -0.3460,  0.1921],
         [-0.1172,  0.2228, -0.1020,  ...,  1.1721,  2.1293,  0.4847],
         [ 0.0962,  0.2899, -0.6043,  ..., -0.6814,  0.4837,  0.0075],
         ...,
         [ 0.1357, -1.0081,  0.3166,  ..., -0.4532,  0.9043, -0.1286],
         [ 0.6356,  0.1391, -0.3242,  ...,  0.3308,  0.3722, -0.5956],
         [ 0.7242, -0.3001,  0.8165,  ...,  0.5277,  1.1039, -0.9327]]],
       device='cuda:0', requires_grad=True), tensor([[-0.1640,  0.4310, -0.2341,  0.2770],
        [ 0.1296, -0.1512,  0.0115,  0.1537],
        [-0.0655,  0.3352,  0.2952, -0.3224],
        ...,
        [-0.2745,  0.0135,  0.3997, -0.2371],
        [ 0.4181, -0.0019,  0.1142,  0.1713],
        [-0.3888,  0.3710,  0.4792,  0.2264]], device='cuda:0', grad_fn=<...>), Parameter containing:
tensor([-0.3444, -0.2064, -0.3750,  ...,  0.2153, -0.1905, -0.0108],
       device='cuda:0', requires_grad=True), None, None, None, True
```

tridao commented 5 months ago

Please update causal_conv1d.
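The `TypeError` above points to a version mismatch: `mamba_ssm` calls `causal_conv1d_fwd` with seven arguments (three tensors, three `None`s and `True`), while the installed extension only accepts the five-argument signature it lists. Before and after upgrading (for example with `pip install -U causal-conv1d`), a stdlib-only check such as the sketch below can confirm which versions the active environment actually sees. The distribution names `mamba-ssm` and `causal-conv1d` are assumptions about how the packages were installed, and `installed_versions` is a hypothetical helper.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed pip distribution names for the two packages discussed in this issue.
PACKAGES = ("mamba-ssm", "causal-conv1d")


def installed_versions(names=PACKAGES):
    """Map each distribution name to its installed version string, or None if missing."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```

Run it with the same interpreter that raises the error, since a mismatch often comes from upgrading one environment and running another.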

Adele0108 commented 5 months ago

Thanks for your prompt answer. After updating, I get a new error:

```
  File "/opt/anaconda3/envs/medfusion-2d/lib/python3.8/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 761, in forward
    causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
RuntimeError: causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8
```
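The `x` in this message is the tensor the kernel receives, which per the traceback is `rearrange(xBC, "b s d -> b d s")`, not the tensor you pass to `Mamba2`. The sketch below assumes that `xBC` is a channel slice of Mamba2's contiguous `in_proj` output of width `d_in_proj = 2*d_inner + 2*ngroups*d_state + nheads`, and that the defaults are `expand=2`, `headdim=64`, `d_state=128`, `ngroups=1`; under those assumptions the view's strides follow from the config alone. `conv_input_strides` and `passes_channel_last_check` are hypothetical helpers, not part of `mamba_ssm`.

```python
# Hypothetical helpers (not part of mamba_ssm). Assumed layout:
# d_in_proj = 2*d_inner + 2*ngroups*d_state + nheads, with xBC sliced from the
# contiguous in_proj output and viewed as (batch, channels, seqlen).

def conv_input_strides(seqlen, d_model, expand=2, headdim=64, d_state=128, ngroups=1):
    """Element strides of rearrange(xBC, "b s d -> b d s") under the assumed layout."""
    d_inner = expand * d_model
    if d_inner % headdim != 0:
        raise ValueError("expand * d_model must be divisible by headdim")
    nheads = d_inner // headdim
    d_in_proj = 2 * d_inner + 2 * ngroups * d_state + nheads
    # in_proj output is contiguous (b, s, d_in_proj): strides (s*d_in_proj, d_in_proj, 1).
    # Slicing channels keeps these strides; swapping the last two dims gives a
    # channel-last view whose stride(1) is 1 and stride(2) is d_in_proj.
    return (seqlen * d_in_proj, 1, d_in_proj)


def passes_channel_last_check(strides):
    """The condition quoted in the RuntimeError: stride(0) and stride(2) multiples of 8."""
    return strides[0] % 8 == 0 and strides[2] % 8 == 0


if __name__ == "__main__":
    # Hypothetical small config with the assumed defaults: nheads = 128 // 64 = 2.
    s = conv_input_strides(seqlen=100, d_model=64)
    print(s, passes_channel_last_check(s))  # (51400, 1, 514) False
    # headdim=16 gives nheads = 8, so the in_proj width becomes 520.
    s = conv_input_strides(seqlen=100, d_model=64, headdim=16)
    print(s, passes_channel_last_check(s))  # (52000, 1, 520) True
```

Under this assumption the check reduces to `d_in_proj % 8 == 0`. Because `2*d_inner + 2*ngroups*d_state` is already a multiple of 8 in common configurations, the head count `expand * d_model // headdim` is usually the term worth adjusting, for example by choosing `headdim` so that the number of heads is a multiple of 8.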

Peilin-FF commented 3 months ago

Have you solved this new problem? I'm hitting it too: `causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8`

buAUFDkru commented 2 weeks ago

Have you solved this new problem? I'm hitting it too: `causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8`

buAUFDkru commented 2 weeks ago

> Please update causal_conv1d.

```
Traceback (most recent call last):
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/mnt/e/project2/accelerated_features-main/modules/training/train.py", line 330, in <module>
    trainer.train()
  File "/mnt/e/project2/accelerated_features-main/modules/training/train.py", line 236, in train
    feats1, kpts1, hmap1 = self.net(p1)
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/e/project2/accelerated_features-main/modules/model5.py", line 154, in forward
    x3 = self.block3(x2)
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/e/project2/accelerated_features-main/modules/LightManbaXfeatNet.py", line 136, in forward
    x = self.conv1(x)
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/e/project2/accelerated_features-main/modules/LightManbaXfeatNet.py", line 35, in forward
    x_mamba = self.mamba(x_norm) + self.skip_scale * x_flat
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/mamba_ssm/modules/mamba2.py", line 183, in forward
    out = mamba_split_conv1d_scan_combined(
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 930, in mamba_split_conv1d_scan_combined
    return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 115, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/home/hu/anaconda3/envs/mamba2/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 779, in forward
    causal_conv1d_cuda.causal_conv1d_fwd(rearrange(xBC, "b s d -> b d s"),
RuntimeError: causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8
```

What is a stride, and what can I do? Here are the shape and strides of `x`:

```
Shape of x: torch.Size([3, 7600, 64])
Strides of x: (486400, 64, 1)
```
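A stride is the number of elements you step over in memory to advance one position along a dimension. For a contiguous (row-major) tensor it follows from the shape alone, which is why a tensor of shape `(3, 7600, 64)` reports `(486400, 64, 1)`: one step along dim 0 skips `7600 * 64 = 486400` elements. A transpose or `rearrange` view only reorders the strides without copying data. The plain-Python sketch below mirrors what `Tensor.stride()` reports; `contiguous_strides` and `permuted_strides` are hypothetical helpers for illustration.

```python
def contiguous_strides(shape):
    """Row-major (C-contiguous) strides, in elements, for a given shape."""
    strides = []
    step = 1
    # Walk from the last dimension inward: each dimension's stride is the
    # product of the sizes of all dimensions after it.
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))


def permuted_strides(strides, order):
    """Strides of a permuted view (like tensor.permute(*order)); no data moves."""
    return tuple(strides[i] for i in order)


if __name__ == "__main__":
    s = contiguous_strides((3, 7600, 64))
    print(s)  # (486400, 64, 1)
    # "b s d -> b d s" swaps the last two dims: stride(1) becomes 1 (channel last).
    print(permuted_strides(s, (0, 2, 1)))  # (486400, 1, 64)
```

In PyTorch, `x.stride()` and `x.transpose(1, 2).stride()` return these values directly. Note that a transposed view of this `x` would already satisfy the multiple-of-8 rule (486400 and 64), so if `x` is the tensor fed to `Mamba2`, the stride that fails is more likely that of the internal tensor built inside `ssd_combined.py` than of `x` itself.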