state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

RuntimeError: CUDA error: no kernel image is available for execution on the device on 3xP40 #85

Open DewEfresh opened 9 months ago

DewEfresh commented 9 months ago

I'm having an issue trying to get mamba running on 3x P40. The model loads into VRAM but then crashes with "RuntimeError: CUDA error: no kernel image is available for execution on the device". I tried the fix from https://github.com/state-spaces/mamba/issues/40 and it didn't work for me.

    (textgen) user@dev01:~/mamba$ python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba-2.8b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
    Loading model state-spaces/mamba-2.8b
    Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
    Number of parameters: 2768345600
    Traceback (most recent call last):
      File "/home/user/mamba/benchmarks/benchmark_generation_mamba_simple.py", line 80, in <module>
        out = fn()
              ^^^^
      File "/home/user/mamba/benchmarks/benchmark_generation_mamba_simple.py", line 55, in <lambda>
        fn = lambda: model.generate(
                     ^^^^^^^^^^^^^^^
      File "/home/user/miniconda3/envs/textgen/lib/python3.11/site-packages/mamba_ssm/utils/generation.py", line 244, in generate
        output = decode(
                 ^^^^^^^
      File "/home/user/miniconda3/envs/textgen/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
        return func(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^
      File "/home/user/miniconda3/envs/textgen/lib/python3.11/site-packages/mamba_ssm/utils/generation.py", line 145, in decode
        model._decoding_cache = update_graph_cache(
                                ^^^^^^^^^^^^^^^^^^^
      File "/home/user/miniconda3/envs/textgen/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
        return func(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^
      File "/home/user/miniconda3/envs/textgen/lib/python3.11/site-packages/mamba_ssm/utils/generation.py", line 305, in update_graph_cache
        cache.callables[batch_size, decoding_seqlen] = capture_graph(
                                                       ^^^^^^^^^^^^^^
      File "/home/user/miniconda3/envs/textgen/lib/python3.11/site-packages/mamba_ssm/utils/generation.py", line 339, in capture_graph
        logits = model(
                 ^^^^^^
      File "/home/user/miniconda3/envs/textgen/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/home/user/miniconda3/envs/textgen/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
        return forward_call(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/home/user/miniconda3/envs/textgen/lib/python3.11/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 233, in forward
        hidden_states = self.backbone(input_ids, inference_params=inference_params)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/home/user/miniconda3/envs/textgen/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/home/user/miniconda3/envs/textgen/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
        return forward_call(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/home/user/miniconda3/envs/textgen/lib/python3.11/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 155, in forward
        hidden_states, residual = layer(
                                  ^^^^^^
      File "/home/user/miniconda3/envs/textgen/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/home/user/miniconda3/envs/textgen/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
        return forward_call(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/home/user/miniconda3/envs/textgen/lib/python3.11/site-packages/mamba_ssm/modules/mamba_simple.py", line 349, in forward
        hidden_states = self.mixer(hidden_states, inference_params=inference_params)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/home/user/miniconda3/envs/textgen/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/home/user/miniconda3/envs/textgen/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
        return forward_call(*args, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/home/user/miniconda3/envs/textgen/lib/python3.11/site-packages/mamba_ssm/modules/mamba_simple.py", line 131, in forward
        out, _, _ = self.step(hidden_states, conv_state, ssm_state)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/home/user/miniconda3/envs/textgen/lib/python3.11/site-packages/mamba_ssm/modules/mamba_simple.py", line 223, in step
        x = causal_conv1d_update(
            ^^^^^^^^^^^^^^^^^^^^^
      File "/home/user/miniconda3/envs/textgen/lib/python3.11/site-packages/causal_conv1d/causal_conv1d_interface.py", line 83, in causal_conv1d_update
        return causal_conv1d_cuda.causal_conv1d_update(x, conv_state, weight, bias, activation)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    RuntimeError: CUDA error: no kernel image is available for execution on the device
    CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
    For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
    Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
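This error usually means the compiled CUDA extension contains no binary for the GPU's compute capability; the Tesla P40 is Pascal, compute capability 6.1 (sm_61). A minimal sketch for checking what the environment reports (note `torch.cuda.get_arch_list()` describes the PyTorch binary itself, not the separately compiled mamba_ssm/causal_conv1d extensions, so treat it only as a hint):

    import torch

    # Compute capability of each visible GPU; a Tesla P40 reports (6, 1).
    for i in range(torch.cuda.device_count()):
        print(i, torch.cuda.get_device_name(i), torch.cuda.get_device_capability(i))

    # CUDA architectures the installed PyTorch binary was compiled for,
    # e.g. ['sm_50', 'sm_60', 'sm_70', ...]. The extensions are built with
    # their own (narrower) list, which is what matters for this crash.
    print(torch.cuda.get_arch_list())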

The minimal usage example from the README fails the same way:

    import torch
    from mamba_ssm import Mamba

    batch, length, dim = 2, 64, 16
    x = torch.randn(batch, length, dim).to("cuda")
    model = Mamba(
        # This module uses roughly 3 * expand * d_model^2 parameters
        d_model=dim,  # Model dimension d_model
        d_state=16,   # SSM state expansion factor
        d_conv=4,     # Local convolution width
        expand=2,     # Block expansion factor
    ).to("cuda")
    y = model(x)
    assert y.shape == x.shape


    ---------------------------------------------------------------------------
    RuntimeError                              Traceback (most recent call last)
    Cell In[2], line 14
          6 x = torch.randn(batch, length, dim).to("cuda")
          7 model = Mamba(
          8     # This module uses roughly 3 * expand * d_model^2 parameters
          9     d_model=dim,  # Model dimension d_model
       (...)
         12     expand=2,  # Block expansion factor
         13 ).to("cuda")
    ---> 14 y = model(x)
         15 assert y.shape == x.shape

    File ~/miniconda3/envs/textgen/lib/python3.11/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
       1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
       1517 else:
    -> 1518     return self._call_impl(*args, **kwargs)

    File ~/miniconda3/envs/textgen/lib/python3.11/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
       1522 # If we don't have any hooks, we want to skip the rest of the logic in
       1523 # this function, and just call forward.
       1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
       1525         or _global_backward_pre_hooks or _global_backward_hooks
       1526         or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1527     return forward_call(*args, **kwargs)
    ...
    RuntimeError: CUDA error: no kernel image is available for execution on the device
    CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
    For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
    Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.
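The error message itself suggests rerunning with `CUDA_LAUNCH_BLOCKING=1` so the failing call is reported synchronously at the correct frame; a minimal sketch of doing that from a notebook (the variable must be set before CUDA is first initialized, so restart the kernel first):

    import os
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # must be set before the first CUDA call

    import torch  # import torch only after setting the variable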

    import torch; print(torch.version.cuda)
    !nvcc -V

    12.1
    nvcc: NVIDIA (R) Cuda compiler driver
    Copyright (c) 2005-2023 NVIDIA Corporation
    Built on Mon_Apr__3_17:16:06_PDT_2023
    Cuda compilation tools, release 12.1, V12.1.105
    Build cuda_12.1.r12.1/compiler.32688072_0

    +---------------------------------------------------------------------------------------+
    | NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
    |-----------------------------------------+----------------------+----------------------+
    | GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
    | Fan  Temp  Perf           Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
    |                                         |                      |               MIG M. |
    |=========================================+======================+======================|
    |   0  Tesla P40                     Off  | 00000000:01:00.0 Off |                  Off |
    | N/A   62C    P0              55W / 250W |    196MiB / 24576MiB |      0%      Default |
    |                                         |                      |                  N/A |
    +-----------------------------------------+----------------------+----------------------+
    |   1  Tesla P40                     Off  | 00000000:06:11.0 Off |                  Off |
    | N/A   35C    P8              10W / 250W |      2MiB / 24576MiB |      0%      Default |
    |                                         |                      |                  N/A |
    +-----------------------------------------+----------------------+----------------------+
    |   2  Tesla P40                     Off  | 00000000:06:1B.0 Off |                  Off |
    | N/A   32C    P8              10W / 250W |      2MiB / 24576MiB |      0%      Default |
    |                                         |                      |                  N/A |
    +-----------------------------------------+----------------------+----------------------+

Package versions (`pip list`, reflowed several entries per line):

    absl-py 2.0.0  accelerate 0.25.0  aiofiles 23.2.1  aiohttp 3.9.1  aiosignal 1.3.1
    altair 5.2.0  annotated-types 0.6.0  anyio 3.7.1  appdirs 1.4.4  asttokens 2.4.1
    attributedict 0.3.0  attrs 23.1.0  auto-gptq 0.6.0+cu121  autoawq 0.1.7  Automat 22.10.0
    bitsandbytes 0.41.1  blessings 1.7  Brotli 1.0.9  buildtools 1.0.6  cachetools 5.3.2
    causal-conv1d 1.1.1  certifi 2022.12.7  cffi 1.16.0  chardet 5.2.0  charset-normalizer 2.1.1
    click 8.1.7  codecov 2.1.13  colorama 0.4.6  coloredlogs 15.0.1  colour-runner 0.1.1
    comm 0.1.4  constantly 23.10.4  contourpy 1.2.0  coverage 7.3.2  cramjam 2.7.0
    cryptography 41.0.7  ctransformers 0.2.27+cu121  cycler 0.12.1  DataProperty 1.0.1  datasets 2.16.0
    debugpy 1.6.7  decorator 5.1.1  deepdiff 6.7.1  dill 0.3.7  diskcache 5.6.3
    distlib 0.3.7  distro 1.8.0  docker-pycreds 0.4.0  docopt 0.6.2  einops 0.7.0
    exceptiongroup 1.2.0  executing 2.0.1  exllama 0.0.18+cu121  exllamav2 0.0.11+cu121  fastapi 0.104.1
    fastparquet 2023.10.1  ffmpy 0.3.1  filelock 3.13.1  flash-attn 2.3.4  fonttools 4.46.0
    frozenlist 1.4.0  fsspec 2023.10.0  furl 2.1.3  gekko 1.0.6  gitdb 4.0.11
    GitPython 3.1.40  gmpy2 2.1.2  google-auth 2.25.1  google-auth-oauthlib 1.1.0  gptq-for-llama 0.1.1+cu121
    gradio 3.50.2  gradio_client 0.6.1  greenlet 3.0.3  grpcio 1.59.3  h11 0.14.0
    httpcore 1.0.2  httpx 0.25.2  huggingface-hub 0.20.1  humanfriendly 10.0  hyperlink 21.0.0
    idna 3.4  importlib-metadata 7.0.0  importlib-resources 6.1.1  incremental 22.10.0  inspecta 0.1.3
    ipykernel 6.26.0  ipython 8.18.1  ipywidgets 8.1.1  jedi 0.19.1  Jinja2 3.1.2
    joblib 1.3.2  jsonlines 4.0.0  jsonschema 4.20.0  jsonschema-specifications 2023.11.2  jupyter_client 8.6.0
    jupyter_core 5.5.0  jupyterlab-widgets 3.0.9  kiwisolver 1.4.5  llama_cpp_python 0.2.24+cpuavx2  llama_cpp_python_cuda 0.2.24+cu121
    llama_cpp_python_cuda_tensorcores 0.2.24+cu121  lm-eval 0.3.0  mamba-ssm 1.1.1  Markdown 3.5.1  markdown-it-py 3.0.0
    MarkupSafe 2.1.3  matplotlib 3.8.2  matplotlib-inline 0.1.6  mbstrdecoder 1.1.3  mdurl 0.1.2
    mkl-fft 1.3.8  mkl-random 1.2.4  mkl-service 2.4.0  mpmath 1.3.0  multidict 6.0.4
    multiprocess 0.70.15  nest-asyncio 1.5.8  networkx 3.0  ninja 1.11.1.1  nltk 3.8.1
    numexpr 2.8.7  numpy 1.24.4  oauthlib 3.2.2  openai 1.3.7  optimum 1.16.1
    ordered-set 4.1.0  orderedmultidict 1.0.1  orjson 3.9.10  packaging 23.2  pandas 2.1.4
    parso 0.8.3  pathvalidate 3.2.0  peft 0.7.1  pexpect 4.8.0  pickleshare 0.7.5
    Pillow 10.1.0  pip 23.3.1  platformdirs 4.1.0  pluggy 1.3.0  portalocker 2.8.2
    prompt-toolkit 3.0.42  protobuf 4.23.4  psutil 5.9.5  ptyprocess 0.7.0  pure-eval 0.2.2
    py-cpuinfo 9.0.0  pyarrow 14.0.1  pyarrow-hotfix 0.6  pyasn1 0.5.1  pyasn1-modules 0.3.0
    pybind11 2.11.1  pycountry 22.3.5  pycparser 2.21  pydantic 2.5.2  pydantic_core 2.14.5
    pydub 0.25.1  Pygments 2.17.2  pyOpenSSL 23.2.0  pyparsing 3.1.1  pyproject-api 1.6.1
    PySocks 1.7.1  pytablewriter 1.2.0  python-dateutil 2.8.2  python-multipart 0.0.6  pytz 2023.3.post1
    PyYAML 6.0.1  pyzmq 25.1.0  redo 2.0.4  referencing 0.31.1  regex 2023.10.3
    requests 2.31.0  requests-oauthlib 1.3.1  rich 13.7.0  rootpath 0.1.1  rouge 1.0.1
    rouge-score 0.1.2  rpds-py 0.13.2  rsa 4.9  sacrebleu 1.5.0  safetensors 0.4.1
    scikit-learn 1.3.2  scipy 1.11.4  semantic-version 2.10.0  sentencepiece 0.1.99  sentry-sdk 1.38.0
    setproctitle 1.3.3  setuptools 68.0.0  simplejson 3.19.2  six 1.16.0  smmap 5.0.1
    sniffio 1.3.0  SQLAlchemy 2.0.23  sqlitedict 2.1.0  stack-data 0.6.2  starlette 0.27.0
    sympy 1.12  tabledata 1.3.3  tabulate 0.9.0  tcolorpy 0.1.4  tensorboard 2.15.1
    tensorboard-data-server 0.7.2  termcolor 2.4.0  texttable 1.7.0  threadpoolctl 3.2.0  tokenizers 0.15.0
    toml 0.10.2  toolz 0.12.0  torch 2.1.2  torch-grammar 0.3.3  torchaudio 2.1.2
    torchvision 0.16.2  tornado 6.3.3  tox 4.11.4  tqdm 4.66.1  tqdm-multiprocess 0.0.11
    traitlets 5.14.0  transformers 4.36.2  triton 2.1.0  Twisted 23.10.0  typepy 1.3.2
    typing_extensions 4.9.0  tzdata 2023.3  urllib3 1.26.18  uvicorn 0.24.0.post1  virtualenv 20.25.0
    wandb 0.16.1  wcwidth 0.2.12  websockets 11.0.3  Werkzeug 3.0.1  wheel 0.41.2
    widgetsnbextension 4.0.9  xxhash 3.4.1  yarl 1.9.4  zipp 3.17.0  zope.interface 6.1
    zstandard 0.22.

masc-it commented 9 months ago

Have you tried running it on a single GPU (maybe with a smaller model)? If multi-GPU is the trigger, I think there are some issues between the custom operators and DDP.

Or it could be that the P40 is missing support for some of the kernels the optimized Mamba path uses. On my Ampere-series card, no problems whatsoever.
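A minimal sketch for testing the single-GPU suggestion without touching any code is to hide all but one device before CUDA initializes (assuming the script does not pin devices itself):

    import os
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # expose only the first P40 to this process

    import torch  # import torch only after setting the variable
    print(torch.cuda.device_count())  # should now print 1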

Tworan commented 8 months ago

I ran into the same issue with a GTX 1080 Ti, which is Pascal (compute capability 6.1; sm_60 binaries also run on it, since cubins are forward-compatible within a major architecture version). In the setup.py files, both mamba_ssm and causal_conv1d compile only for compute_70, compute_80, and compute_90, with nothing for compute_60 (sm_60): https://github.com/state-spaces/mamba/blob/86a3a902ca4189689aabf1c09174235024c7aede/setup.py#L108-L114

So I added two lines to the setup.py of both mamba_ssm and causal_conv1d, reinstalled both packages (a sketch of the rebuild commands follows the patch below), and it works for me.

    # added: also emit Pascal (sm_60) binaries
    cc_flag.append("-gencode")
    cc_flag.append("arch=compute_60,code=sm_60")
    # end of addition; the stock flags follow
    cc_flag.append("-gencode")
    cc_flag.append("arch=compute_70,code=sm_70")
    cc_flag.append("-gencode")
    cc_flag.append("arch=compute_80,code=sm_80")
    if bare_metal_version >= Version("11.8"):
        cc_flag.append("-gencode")
        cc_flag.append("arch=compute_90,code=sm_90")

I think this may be because Pascal GPUs (compute_60) do not support the bf16 data type, which would make the functions that use bf16 unusable.
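Whether a device supports bf16 can be checked directly; a minimal sketch (with torch 2.1, `torch.cuda.is_bf16_supported()` returns False for anything below compute capability 8.0):

    import torch

    print(torch.cuda.get_device_capability())  # (6, 1) on a Tesla P40
    # False on Pascal; True on Ampere (sm_80) and newer.
    print(torch.cuda.is_bf16_supported())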

SamsongB commented 8 months ago

I also see the same issue with a GM200 [GeForce GTX TITAN X]. I tried the above solution with compute_60, but it still cannot run (GM200 is Maxwell, compute capability 5.2, so compute_60 binaries presumably would not cover it). Is there an update regarding this?

hhhhpaaa commented 8 months ago

I built a wheel under Python 3.10, PyTorch 2.1, and CUDA 11.8 that supports compute_60. For details, please refer to the link. @SamsongB @DewEfresh

a987042035 commented 6 months ago

@DewEfresh Hello, have you solved your problem? I'm running on a single P40 with CUDA 11.8 and hit the same issue. Please help.