state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

Error When Running Benchmark Command #23

Status: Open. Unrealluver opened this issue 10 months ago

Unrealluver commented 10 months ago

Greetings, and thanks for your great work! When I ran the benchmark script, I hit the error below. Could you suggest some possible solutions?

python benchmarks/benchmark_generation_mamba_simple.py --model-name "/home/x/VisionProjects/mamba/ckpts/mamba-130m" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5
Loading model /home/x/VisionProjects/mamba/ckpts/mamba-130m
Number of parameters: 129135360
Traceback (most recent call last):
  File "<string>", line 21, in _layer_norm_fwd_1pass_kernel
KeyError: ('2-.-1-.-0-+-2-c-3-2-f-4-3-9-9-9-83ca8b715a9dc5f32dc1110973485f64-45375ed7aa3bacaed5f41dca33dc8ee0-6590aa19b3e9909e5c8a7254fb3b9328-e6da1445790e1250a9b68f17efc2dd18-7f2d2fed060f2e0fa46ef4e19e20c865-e1f133f98d04093da2078dfc51c36b72-056bca445a91d3175375bc8481ed1689-0db1785b8dc43452c61ef6d926ec11bb-6aff3b6e239e435b817994e60abc8cef', (torch.float16, torch.float16, torch.float16, None, None, torch.float32, None, torch.float32, 'i32', 'i32', 'i32', 'i32', 'i32', 'fp32'), (True, 1024, False, True, False), (True, True, True, (False,), (False,), True, (False,), True, (True, False), (True, False), (True, False), (True, False), (True, False), (False,)), 1, 2, False)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "benchmarks/benchmark_generation_mamba_simple.py", line 77, in <module>
    out = fn()
  File "benchmarks/benchmark_generation_mamba_simple.py", line 54, in <lambda>
    fn = lambda: model.generate(
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/utils/generation.py", line 218, in generate
    output = decode(
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/utils/generation.py", line 127, in decode
    model._decoding_cache = update_graph_cache(
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/utils/generation.py", line 311, in update_graph_cache
    cache.callables[batch_size, decoding_seqlen] = capture_graph(
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/utils/generation.py", line 345, in capture_graph
    logits = model(
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 221, in forward
    hidden_states = self.backbone(input_ids, inference_params=inference_params)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 152, in forward
    hidden_states, residual = layer(
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/modules/mamba_simple.py", line 341, in forward
    hidden_states, residual = fused_add_norm_fn(
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/ops/triton/layernorm.py", line 478, in rms_norm_fn
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/ops/triton/layernorm.py", line 411, in forward
    y, mean, rstd, residual_out = _layer_norm_fwd(
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/ops/triton/layernorm.py", line 155, in _layer_norm_fwd
    _layer_norm_fwd_1pass_kernel[(M,)](
  File "/opt/conda/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 77, in run
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 77, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 65, in _bench
    return do_bench(kernel_call, percentiles=(0.5, 0.2, 0.8))
  File "/opt/conda/lib/python3.8/site-packages/triton/testing.py", line 146, in do_bench
    fn()
  File "/opt/conda/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 63, in kernel_call
    self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
  File "<string>", line 41, in _layer_norm_fwd_1pass_kernel
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 1687, in compile
    return CompiledKernel(fn, so_path, metadata, asm)
  File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 1700, in __init__
    mod = importlib.util.module_from_spec(spec)
  File "<frozen importlib._bootstrap>", line 556, in module_from_spec
  File "<frozen importlib._bootstrap_external>", line 1101, in create_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
ImportError: /root/.triton/cache/767259c163b96d4d22c0eea24dd36494/_layer_norm_fwd_1pass_kernel.so: undefined symbol: cuLaunchKernel

My installed Python packages are listed below:

Package                      Version
---------------------------- ------------
absl-py                      1.4.0
accelerate                   0.21.0
aiofiles                     23.1.0
aiohttp                      3.8.4
aiosignal                    1.3.1
alembic                      1.10.3
altair                       4.2.2
anthropic                    0.3.2
anyio                        3.6.2
appdirs                      1.4.4
argon2-cffi                  23.1.0
argon2-cffi-bindings         21.2.0
arrow                        1.2.3
asttokens                    2.4.0
astunparse                   1.6.3
async-lru                    2.0.4
async-timeout                4.0.2
attrs                        22.2.0
Babel                        2.12.1
backcall                     0.2.0
beautifulsoup4               4.12.2
bert-score                   0.3.13
bleach                       6.0.0
blessed                      1.20.0
BLEURT                       0.0.2
cachetools                   5.3.1
causal-conv1d                1.0.0
certifi                      2022.12.7
cffi                         1.15.1
charset-normalizer           3.1.0
click                        8.1.3
cloudpickle                  2.2.1
cmake                        3.26.1
colorama                     0.4.6
comm                         0.1.4
contourpy                    1.0.7
cycler                       0.11.0
databricks-cli               0.17.6
datasets                     2.11.0
debugpy                      1.8.0
decorator                    5.1.1
deepspeed                    0.9.5
defusedxml                   0.7.1
dill                         0.3.6
distlib                      0.3.6
distro                       1.8.0
docker                       6.0.1
docker-pycreds               0.4.0
einops                       0.6.0
entrypoints                  0.4
exceptiongroup               1.1.3
executing                    1.2.0
fastapi                      0.95.0
fastjsonschema               2.18.0
ffmpy                        0.3.0
filelock                     3.10.7
fire                         0.5.0
flash-attn                   2.0.4
Flask                        2.2.3
flatbuffers                  23.5.26
fonttools                    4.39.3
fqdn                         1.5.1
frozenlist                   1.3.3
fsspec                       2023.10.0
gast                         0.5.4
gitdb                        4.0.10
GitPython                    3.1.31
google-auth                  2.23.1
google-auth-oauthlib         1.0.0
google-pasta                 0.2.0
gpustat                      1.1.1
gradio                       3.50.2
gradio_client                0.6.1
greenlet                     2.0.2
grpcio                       1.53.0
gunicorn                     20.1.0
h11                          0.14.0
h5py                         3.9.0
hjson                        3.1.0
httpcore                     0.17.0
httpx                        0.24.0
huggingface-hub              0.19.4
idna                         3.4
importlib-metadata           6.4.1
importlib-resources          6.1.0
ipdb                         0.13.13
ipykernel                    6.25.2
ipython                      8.15.0
ipython-genutils             0.2.0
ipywidgets                   8.1.1
isoduration                  20.11.0
itsdangerous                 2.1.2
jedi                         0.19.0
Jinja2                       3.1.2
joblib                       1.2.0
json5                        0.9.14
jsonpointer                  2.4
jsonschema                   4.19.0
jsonschema-specifications    2023.7.1
jupyter                      1.0.0
jupyter_client               8.3.1
jupyter-console              6.6.3
jupyter_core                 5.3.1
jupyter-events               0.7.0
jupyter-lsp                  2.2.0
jupyter_server               2.7.3
jupyter_server_terminals     0.4.4
jupyterlab                   4.0.6
jupyterlab-pygments          0.2.2
jupyterlab_server            2.25.0
jupyterlab-widgets           3.0.9
keras                        2.14.0
kiwisolver                   1.4.4
libclang                     16.0.6
libretranslatepy             2.1.1
linkify-it-py                2.0.0
lit                          16.0.0
llvmlite                     0.39.1
lxml                         4.9.2
Mako                         1.2.4
mamba-ssm                    1.0.1
Markdown                     3.4.3
markdown-it-py               2.2.0
markdown2                    2.4.8
MarkupSafe                   2.1.2
matplotlib                   3.7.1
matplotlib-inline            0.1.6
mdit-py-plugins              0.3.3
mdurl                        0.1.2
mistune                      3.0.1
ml-dtypes                    0.2.0
mlflow                       2.2.2
mpmath                       1.3.0
msgpack                      1.0.5
multidict                    6.0.4
multiprocess                 0.70.14
nbclient                     0.8.0
nbconvert                    7.8.0
nbformat                     5.9.2
nest-asyncio                 1.5.8
networkx                     3.1
ninja                        1.11.1
nltk                         3.8.1
notebook                     7.0.3
notebook_shim                0.2.3
numba                        0.56.4
numpy                        1.23.5
nvidia-cublas-cu11           11.10.3.66
nvidia-cuda-cupti-cu11       11.7.101
nvidia-cuda-nvrtc-cu11       11.7.99
nvidia-cuda-runtime-cu11     11.7.99
nvidia-cudnn-cu11            8.5.0.96
nvidia-cufft-cu11            10.9.0.58
nvidia-curand-cu11           10.2.10.91
nvidia-cusolver-cu11         11.4.0.1
nvidia-cusparse-cu11         11.7.4.91
nvidia-ml-py                 12.535.108
nvidia-nccl-cu11             2.14.3
nvidia-nvtx-cu11             11.7.91
oauthlib                     3.2.2
openai                       0.27.4
opt-einsum                   3.3.0
orjson                       3.8.10
overrides                    7.4.0
packaging                    23.0
pandas                       2.0.0
pandocfilters                1.5.0
parso                        0.8.3
pathtools                    0.1.2
pexpect                      4.8.0
pickleshare                  0.7.5
Pillow                       9.3.0
pip                          23.3.1
platformdirs                 3.2.0
portalocker                  2.8.2
prometheus-client            0.17.0
prompt-toolkit               3.0.39
protobuf                     4.24.3
psutil                       5.9.4
ptyprocess                   0.7.0
pure-eval                    0.2.2
py-cpuinfo                   9.0.0
pyarrow                      11.0.0
pyasn1                       0.5.0
pyasn1-modules               0.3.0
pycparser                    2.21
pydantic                     1.10.7
pydub                        0.25.1
Pygments                     2.15.0
PyJWT                        2.6.0
pyparsing                    3.0.9
pyrsistent                   0.19.3
python-dateutil              2.8.2
python-json-logger           2.0.7
python-multipart             0.0.6
pytz                         2022.7.1
PyYAML                       6.0
pyzmq                        25.1.1
qtconsole                    5.4.4
QtPy                         2.4.0
querystring-parser           1.2.4
ray                          2.3.1
referencing                  0.30.2
regex                        2023.3.23
requests                     2.31.0
requests-oauthlib            1.3.1
responses                    0.18.0
rfc3339-validator            0.1.4
rfc3986-validator            0.1.1
rouge-score                  0.1.2
rpds-py                      0.10.3
rsa                          4.9
sacrebleu                    2.3.1
safetensors                  0.3.1
scikit-learn                 1.2.2
scipy                        1.10.1
seaborn                      0.12.2
semantic-version             2.10.0
Send2Trash                   1.8.2
sentencepiece                0.1.97
sentry-sdk                   1.19.1
setproctitle                 1.3.2
setuptools                   65.6.3
shap                         0.41.0
shortuuid                    1.0.11
six                          1.16.0
slicer                       0.0.7
smmap                        5.0.0
sniffio                      1.3.0
soupsieve                    2.5
SQLAlchemy                   2.0.9
sqlparse                     0.4.3
stack-data                   0.6.2
starlette                    0.26.1
svgwrite                     1.4.3
sympy                        1.11.1
tabulate                     0.9.0
tensor-parallel              1.2.0
tensorboard                  2.14.0
tensorboard-data-server      0.7.1
tensorboardX                 2.6
tensorflow                   2.14.0
tensorflow-estimator         2.14.0
tensorflow-io-gcs-filesystem 0.34.0
termcolor                    2.2.0
terminado                    0.17.1
tf-slim                      1.1.0
threadpoolctl                3.1.0
tinycss2                     1.2.1
tokenizers                   0.15.0
tomli                        2.0.1
toolz                        0.12.0
torch                        2.0.1+cu118
torchaudio                   2.0.1+cu118
torchvision                  0.15.1+cu118
tornado                      6.3.3
tqdm                         4.65.0
traitlets                    5.10.0
transformers                 4.35.2
translate                    3.6.1
triton                       2.0.0
typing_extensions            4.5.0
tzdata                       2023.3
uc-micro-py                  1.0.1
uri-template                 1.3.0
urllib3                      2.0.5
uvicorn                      0.21.1
virtualenv                   20.21.0
wandb                        0.14.2
wavedrom                     2.0.3.post3
wcwidth                      0.2.6
webcolors                    1.13
webencodings                 0.5.1
websocket-client             1.5.1
websockets                   11.0.1
Werkzeug                     2.2.3
wheel                        0.38.4
widgetsnbextension           4.0.9
wrapt                        1.14.1
xxhash                       3.2.0
yarl                         1.8.2
zipp                         3.15.0
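
The shared object that fails to import lives in Triton's on-disk kernel cache (/root/.triton/cache/... in the traceback). One low-cost step to rule out a stale cached binary is to delete that directory so Triton recompiles the kernels on the next run. The sketch below assumes the default cache location under the home directory, which matches the path above when running as root. It is not a confirmed fix for this error, and `clear_triton_cache` is a hypothetical helper, not part of Triton or mamba_ssm.

```python
import shutil
from pathlib import Path

# Assumption: Triton's default cache location, matching the
# /root/.triton/cache/... path in the traceback (running as root).
DEFAULT_TRITON_CACHE = Path.home() / ".triton" / "cache"


def clear_triton_cache(cache_dir: Path = DEFAULT_TRITON_CACHE) -> bool:
    """Hypothetical helper: delete the Triton kernel cache directory.

    Returns True if a directory was removed, False if none existed.
    Kernels are recompiled the next time they are launched.
    """
    cache_dir = Path(cache_dir)
    if not cache_dir.is_dir():
        return False
    shutil.rmtree(cache_dir)
    return True


if __name__ == "__main__":
    removed = clear_triton_cache()
    print(f"Removed {DEFAULT_TRITON_CACHE}" if removed else "No Triton cache found")
```

If the same ImportError reappears after a fresh compile, the cached file was not the problem and the cause lies in how the kernel is built or linked.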

albertfgu commented 10 months ago

Try upgrading triton to 2.1.0
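
To confirm which versions are actually installed before and after upgrading, a small check like the one below may help. It is a sketch: the simplified parser keeps only the leading numeric release (so "2.1.1+cu118" becomes (2, 1, 1)) and ignores pre-release and local segments; `packaging.version` handles full PEP 440 comparisons. Note that torch wheels depend on a specific triton release (the torch 2.0.1 in the list above pairs with triton 2.0.0), so upgrading triton alone may make pip report a dependency conflict.

```python
import re
from importlib import metadata
from typing import Optional, Tuple


def installed_version(package: str) -> Optional[str]:
    """Return the installed distribution version, or None if not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def release_tuple(version: str) -> Tuple[int, ...]:
    """Parse the leading numeric release, e.g. '2.1.1+cu118' -> (2, 1, 1)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unparseable version: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def at_least(version: str, minimum: str) -> bool:
    """Compare release tuples, padding with zeros so '2.1' equals '2.1.0'."""
    a, b = release_tuple(version), release_tuple(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


if __name__ == "__main__":
    for pkg in ("torch", "triton", "mamba-ssm", "causal-conv1d"):
        print(f"{pkg}: {installed_version(pkg)}")
    triton_version = installed_version("triton")
    if triton_version is not None and not at_least(triton_version, "2.1.0"):
        print(f"triton {triton_version} is older than 2.1.0; try: pip install -U 'triton>=2.1.0'")
```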

SarthakYadav commented 7 months ago

I'm also facing this issue, and upgrading to triton 2.1.0 doesn't fix it either. I'm using torch==2.1.1+cu118 and triton==2.1.0.
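
`cuLaunchKernel` is a CUDA driver API function exported by libcuda, the library installed with the NVIDIA driver. An "undefined symbol" error when importing the Triton-compiled .so suggests that libcuda was not resolved for that module at load time. One plausible cause, not confirmed in this thread, is that Triton could not locate libcuda when it built the kernel, which can happen inside containers. The sketch below only diagnoses the situation: it asks the system loader whether it can find libcuda and whether `libcuda.so.1` (the soname the driver installs on Linux) exports `cuLaunchKernel`. The helper names are hypothetical.

```python
import ctypes
import ctypes.util
from typing import Optional


def find_cuda_driver() -> Optional[str]:
    """Ask the system loader for the CUDA driver library (libcuda), or None."""
    return ctypes.util.find_library("cuda")


def exports_symbol(libname: str, symbol: str = "cuLaunchKernel") -> bool:
    """Return True if `libname` can be loaded and exports `symbol`."""
    try:
        lib = ctypes.CDLL(libname)
    except OSError:
        # Library not found or not loadable on this machine.
        return False
    # ctypes raises AttributeError for a missing symbol; hasattr turns it into False.
    return hasattr(lib, symbol)


if __name__ == "__main__":
    print("libcuda found by loader:", find_cuda_driver())
    print("libcuda.so.1 exports cuLaunchKernel:", exports_symbol("libcuda.so.1"))
```

If the loader cannot find libcuda at all, the driver library is not visible inside the environment, which would need fixing before any Triton version can link its kernels against it.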