state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

Sample code error #283

Status: Open. lovjie opened this issue 6 months ago.

lovjie commented 6 months ago

I used the provided example code and encountered this issue. How can I solve it?

```python
import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x)

print("y:", y.shape)
assert y.shape == x.shape
```
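The comment in the example gives a rule of thumb for the size of one Mamba block. As a rough sketch, assuming (from my reading of `mamba_ssm/modules/mamba_simple.py`) that the two d_model² terms come from the input projection to `2 * expand * d_model` features and the output projection back to `d_model`, the estimate for the example's settings works out as below. `approx_mamba_params` is a hypothetical helper, not part of mamba_ssm.

```python
def approx_mamba_params(d_model: int, expand: int = 2) -> int:
    """Hypothetical helper: rule-of-thumb parameter count for one Mamba block.

    in_proj:  d_model -> 2 * expand * d_model   (2 * expand * d_model^2 weights)
    out_proj: expand * d_model -> d_model       (expand * d_model^2 weights)
    Smaller terms (depthwise conv, x_proj, dt_proj, A_log, D) are ignored.
    """
    return 3 * expand * d_model ** 2


if __name__ == "__main__":
    # Settings from the example: dim = 16, expand = 2.
    print(approx_mamba_params(16, expand=2))  # 3 * 2 * 16**2 = 1536
```

For a tiny `d_model` such as 16 the ignored terms are not negligible, so `sum(p.numel() for p in model.parameters())` on the real module can come out noticeably higher than this estimate.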

```
Traceback (most recent call last):
  File "/usr/test.py", line 13, in <module>
    y = model(x)
  File "/usr/local/miniconda3/envs/umt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/miniconda3/envs/umt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/miniconda3/envs/umt/lib/python3.10/site-packages/mamba_ssm/modules/mamba_simple.py", line 146, in forward
    out = mamba_inner_fn(
  File "/usr/local/miniconda3/envs/umt/lib/python3.10/site-packages/mamba_ssm/ops/selective_scan_interface.py", line 317, in mamba_inner_fn
    return MambaInnerFn.apply(xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  File "/usr/local/miniconda3/envs/umt/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/miniconda3/envs/umt/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 113, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/usr/local/miniconda3/envs/umt/lib/python3.10/site-packages/mamba_ssm/ops/selective_scan_interface.py", line 187, in forward
    conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
TypeError: causal_conv1d_fwd(): incompatible function arguments. The following argument types are supported:
    1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: Optional[torch.Tensor], arg3: bool) -> torch.Tensor
```

BlenderWang9487 commented 6 months ago

I believe the issue is caused by an older version of the causal_conv1d interface.

Based on the function signature in the error message:

(arg0: torch.Tensor, arg1: torch.Tensor, arg2: Optional[torch.Tensor], arg3: bool)

I presume you are using causal-conv1d<=1.0.2.
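To see which signature the installed compiled extension exposes without running the model, a small sketch like the following may help. It assumes the extension is importable as `causal_conv1d_cuda` (the name that appears in the traceback) and relies on pybind11 writing the function signature into `__doc__`; `extension_signature` is a hypothetical helper.

```python
import importlib


def extension_signature(module_name: str = "causal_conv1d_cuda",
                        fn_name: str = "causal_conv1d_fwd") -> str:
    """Return the docstring of a compiled function, or a message explaining why not."""
    try:
        # Import torch first so the extension can locate torch's shared libraries.
        importlib.import_module("torch")
        module = importlib.import_module(module_name)
    except ImportError as exc:
        return f"cannot import {module_name}: {exc}"
    fn = getattr(module, fn_name, None)
    if fn is None:
        return f"{module_name} has no attribute {fn_name!r}"
    # pybind11 usually places the signature on the first line of __doc__.
    return fn.__doc__ or "(no docstring)"


if __name__ == "__main__":
    print(extension_signature())
```

If the printed signature has only the four arguments shown in the error, the installed extension predates the interface that mamba_ssm is calling.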

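The authors recommend `pip install causal-conv1d>=1.2.0` in the README.md (quote the requirement, as in `pip install "causal-conv1d>=1.2.0"`, so the shell does not read `>` as an output redirect). I recommend giving it a try.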
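Before reinstalling, it may help to confirm which causal-conv1d release is actually installed. The sketch below reads distribution metadata with the standard library's `importlib.metadata` and compares the numeric release against 1.2.0. The helper names are hypothetical, the distribution name `causal-conv1d` is assumed to resolve as written, and the comparison deliberately ignores pre-release and post-release tags (the `packaging` library's `Version` class handles those properly).

```python
import re
from importlib.metadata import PackageNotFoundError, version


def release_tuple(ver: str) -> tuple:
    """Leading numeric release components, e.g. '1.2.0.post1' -> (1, 2, 0)."""
    match = re.match(r"(\d+(?:\.\d+)*)", ver)
    if match is None:
        raise ValueError(f"unparseable version: {ver!r}")
    return tuple(int(part) for part in match.group(1).split("."))


def meets_minimum(ver: str, minimum: str) -> bool:
    """Simplified check: compare release numbers only, padding so 1.2 == 1.2.0."""
    a, b = release_tuple(ver), release_tuple(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_causal_conv1d(minimum: str = "1.2.0") -> str:
    try:
        installed = version("causal-conv1d")
    except PackageNotFoundError:
        return "causal-conv1d is not installed"
    if meets_minimum(installed, minimum):
        return f"causal-conv1d {installed} OK (>= {minimum})"
    return f"causal-conv1d {installed} is older than {minimum}; consider upgrading"


if __name__ == "__main__":
    print(check_causal_conv1d())
```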

uxhao-o commented 5 months ago

I ran it successfully in this environment: CUDA 11.8, Python 3.10.13, PyTorch 2.1.1, causal_conv1d 1.1.1, mamba-ssm 1.2.0.post1. The packages were installed with:

```shell
pip install torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1 --index-url https://download.pytorch.org/whl/cu118
pip install causal_conv1d==1.1.1
pip install mamba-ssm==1.2.0.post1
```
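To compare a local setup with the working environment reported in this comment, a sketch like the one below collects the Python version and installed package versions via `importlib.metadata` and lists any mismatches. `KNOWN_GOOD`, `installed_versions`, and `differences` are hypothetical names; local version tags such as `+cu118` (used by the PyTorch CUDA wheels) are stripped before comparing.

```python
import platform
from importlib.metadata import PackageNotFoundError, version

# Versions reported as working in the comment above.
KNOWN_GOOD = {
    "python": "3.10.13",
    "torch": "2.1.1",
    "causal-conv1d": "1.1.1",
    "mamba-ssm": "1.2.0.post1",
}


def installed_versions(packages=("torch", "causal-conv1d", "mamba-ssm")) -> dict:
    """Collect the Python version and installed package versions (None if missing)."""
    found = {"python": platform.python_version()}
    for pkg in packages:
        try:
            found[pkg] = version(pkg)
        except PackageNotFoundError:
            found[pkg] = None
    return found


def differences(reference: dict, found: dict) -> dict:
    """Map each mismatching key to (found, expected), ignoring local tags like '+cu118'."""
    diff = {}
    for key, expected in reference.items():
        actual = found.get(key)
        base = actual.split("+")[0] if actual else actual
        if base != expected:
            diff[key] = (actual, expected)
    return diff


if __name__ == "__main__":
    for key, (actual, expected) in differences(KNOWN_GOOD, installed_versions()).items():
        print(f"{key}: installed {actual}, known-good {expected}")
```

The CUDA version is not covered here; `torch.version.cuda` reports the CUDA version PyTorch was built against and can be checked alongside this.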