Greetings! Thanks for your great work! When I ran the benchmark script, I got the error below. Could you suggest possible solutions?

python benchmarks/ --model-name "/home/x/VisionProjects/mamba/ckpts/mamba-130m" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5
Loading model /home/x/VisionProjects/mamba/ckpts/mamba-130m
Number of parameters: 129135360
Traceback (most recent call last):
  File "<string>", line 21, in _layer_norm_fwd_1pass_kernel
KeyError: ('2-.-1-.-0-+-2-c-3-2-f-4-3-9-9-9-83ca8b715a9dc5f32dc1110973485f64-45375ed7aa3bacaed5f41dca33dc8ee0-6590aa19b3e9909e5c8a7254fb3b9328-e6da1445790e1250a9b68f17efc2dd18-7f2d2fed060f2e0fa46ef4e19e20c865-e1f133f98d04093da2078dfc51c36b72-056bca445a91d3175375bc8481ed1689-0db1785b8dc43452c61ef6d926ec11bb-6aff3b6e239e435b817994e60abc8cef', (torch.float16, torch.float16, torch.float16, None, None, torch.float32, None, torch.float32, 'i32', 'i32', 'i32', 'i32', 'i32', 'fp32'), (True, 1024, False, True, False), (True, True, True, (False,), (False,), True, (False,), True, (True, False), (True, False), (True, False), (True, False), (True, False), (False,)), 1, 2, False)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "benchmarks/", line 77, in <module>
    out = fn()
  File "benchmarks/", line 54, in <lambda>
    fn = lambda: model.generate(
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/utils/", line 218, in generate
    output = decode(
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/utils/", line 127, in decode
    model._decoding_cache = update_graph_cache(
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/utils/", line 311, in update_graph_cache
    cache.callables[batch_size, decoding_seqlen] = capture_graph(
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/utils/", line 345, in capture_graph
    logits = model(
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/models/", line 221, in forward
    hidden_states = self.backbone(input_ids, inference_params=inference_params)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/models/", line 152, in forward
    hidden_states, residual = layer(
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/modules/", line 341, in forward
    hidden_states, residual = fused_add_norm_fn(
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/ops/triton/", line 478, in rms_norm_fn
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
  File "/opt/conda/lib/python3.8/site-packages/torch/autograd/", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/ops/triton/", line 411, in forward
    y, mean, rstd, residual_out = _layer_norm_fwd(
  File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/ops/triton/", line 155, in _layer_norm_fwd
  File "/opt/conda/lib/python3.8/site-packages/triton/runtime/", line 77, in run
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/triton/runtime/", line 77, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/triton/runtime/", line 65, in _bench
    return do_bench(kernel_call, percentiles=(0.5, 0.2, 0.8))
  File "/opt/conda/lib/python3.8/site-packages/triton/", line 146, in do_bench
  File "/opt/conda/lib/python3.8/site-packages/triton/runtime/", line 63, in kernel_call
    self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
  File "<string>", line 41, in _layer_norm_fwd_1pass_kernel
  File "/opt/conda/lib/python3.8/site-packages/triton/", line 1687, in compile
    return CompiledKernel(fn, so_path, metadata, asm)
  File "/opt/conda/lib/python3.8/site-packages/triton/", line 1700, in __init__
    mod = importlib.util.module_from_spec(spec)
  File "<frozen importlib._bootstrap>", line 556, in module_from_spec
  File "<frozen importlib._bootstrap_external>", line 1101, in create_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
ImportError: /root/.triton/cache/767259c163b96d4d22c0eea24dd36494/ undefined symbol: cuLaunchKernel

The dependencies and libraries are shown below:

Try upgrading triton to 2.1.0
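
Before and after upgrading, it may help to confirm which versions are actually installed in the environment that runs the benchmark. The snippet below is a minimal sketch using only the standard library; the helper name `installed_versions` is hypothetical, and the distribution names `torch`, `triton`, and `mamba-ssm` are assumed to match what pip installed.

```python
from importlib import metadata


def installed_versions(packages=("torch", "triton", "mamba-ssm")):
    """Return {distribution name: version string, or None if not installed}.

    Reads package metadata only, so nothing is imported and no GPU is needed.
    The default distribution names are assumptions; adjust them if needed.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Posting this output together with the traceback makes it easier to tell whether the upgrade took effect in the same interpreter (`/opt/conda/lib/python3.8` in the traceback above).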

I'm also facing this issue, and upgrading triton doesn't fix it. I'm using torch==2.1.1+cu118 and triton==2.1.0.
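
Since the final error is `undefined symbol: cuLaunchKernel`, and `cuLaunchKernel` is a CUDA driver API function, one thing worth checking is whether the driver library is visible to the process at all. The following is a rough diagnostic sketch, not a confirmed fix: the helper name `symbol_status` and the candidate library names are assumptions, and a `True` result only shows that the driver exports the symbol, not that Triton's compiled launcher is linked against it.

```python
import ctypes
import ctypes.util


def symbol_status(symbol="cuLaunchKernel", candidates=None):
    """Check whether a shared library exports `symbol`.

    Returns True or False for the first candidate library that loads,
    or None if none of the candidates could be loaded.
    """
    if candidates is None:
        # Assumed names for the CUDA driver library on Linux.
        found = ctypes.util.find_library("cuda")
        candidates = [found] if found else []
        candidates += ["libcuda.so.1", "libcuda.so"]
    for name in candidates:
        try:
            lib = ctypes.CDLL(name)
        except OSError:
            continue  # this candidate is not loadable, try the next one
        # Attribute lookup on a CDLL raises AttributeError for a missing symbol.
        return hasattr(lib, symbol)
    return None


if __name__ == "__main__":
    status = symbol_status()
    if status is None:
        print("CUDA driver library could not be loaded")
    else:
        print(f"cuLaunchKernel exported by driver library: {status}")
```

If this prints that the library could not be loaded, the problem may lie with the driver installation or library search path rather than with the triton version.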