Greetings! Thank you for your great work! When I ran the benchmark code, I encountered the error below. Could you please share some possible solutions?
python benchmarks/benchmark_generation_mamba_simple.py --model-name "/home/x/VisionProjects/mamba/ckpts/mamba-130m" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.5
Loading model /home/x/VisionProjects/mamba/ckpts/mamba-130m
Number of parameters: 129135360
Traceback (most recent call last):
File "<string>", line 21, in _layer_norm_fwd_1pass_kernel
KeyError: ('2-.-1-.-0-+-2-c-3-2-f-4-3-9-9-9-83ca8b715a9dc5f32dc1110973485f64-45375ed7aa3bacaed5f41dca33dc8ee0-6590aa19b3e9909e5c8a7254fb3b9328-e6da1445790e1250a9b68f17efc2dd18-7f2d2fed060f2e0fa46ef4e19e20c865-e1f133f98d04093da2078dfc51c36b72-056bca445a91d3175375bc8481ed1689-0db1785b8dc43452c61ef6d926ec11bb-6aff3b6e239e435b817994e60abc8cef', (torch.float16, torch.float16, torch.float16, None, None, torch.float32, None, torch.float32, 'i32', 'i32', 'i32', 'i32', 'i32', 'fp32'), (True, 1024, False, True, False), (True, True, True, (False,), (False,), True, (False,), True, (True, False), (True, False), (True, False), (True, False), (True, False), (False,)), 1, 2, False)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "benchmarks/benchmark_generation_mamba_simple.py", line 77, in <module>
out = fn()
File "benchmarks/benchmark_generation_mamba_simple.py", line 54, in <lambda>
fn = lambda: model.generate(
File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/utils/generation.py", line 218, in generate
output = decode(
File "/opt/conda/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/utils/generation.py", line 127, in decode
model._decoding_cache = update_graph_cache(
File "/opt/conda/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/utils/generation.py", line 311, in update_graph_cache
cache.callables[batch_size, decoding_seqlen] = capture_graph(
File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/utils/generation.py", line 345, in capture_graph
logits = model(
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 221, in forward
hidden_states = self.backbone(input_ids, inference_params=inference_params)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/models/mixer_seq_simple.py", line 152, in forward
hidden_states, residual = layer(
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/modules/mamba_simple.py", line 341, in forward
hidden_states, residual = fused_add_norm_fn(
File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/ops/triton/layernorm.py", line 478, in rms_norm_fn
return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
File "/opt/conda/lib/python3.8/site-packages/torch/autograd/function.py", line 506, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/ops/triton/layernorm.py", line 411, in forward
y, mean, rstd, residual_out = _layer_norm_fwd(
File "/opt/conda/lib/python3.8/site-packages/mamba_ssm/ops/triton/layernorm.py", line 155, in _layer_norm_fwd
_layer_norm_fwd_1pass_kernel[(M,)](
File "/opt/conda/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 77, in run
timings = {config: self._bench(*args, config=config, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 77, in <dictcomp>
timings = {config: self._bench(*args, config=config, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 65, in _bench
return do_bench(kernel_call, percentiles=(0.5, 0.2, 0.8))
File "/opt/conda/lib/python3.8/site-packages/triton/testing.py", line 146, in do_bench
fn()
File "/opt/conda/lib/python3.8/site-packages/triton/runtime/autotuner.py", line 63, in kernel_call
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
File "<string>", line 41, in _layer_norm_fwd_1pass_kernel
File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 1687, in compile
return CompiledKernel(fn, so_path, metadata, asm)
File "/opt/conda/lib/python3.8/site-packages/triton/compiler.py", line 1700, in __init__
mod = importlib.util.module_from_spec(spec)
File "<frozen importlib._bootstrap>", line 556, in module_from_spec
File "<frozen importlib._bootstrap_external>", line 1101, in create_module
File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
ImportError: /root/.triton/cache/767259c163b96d4d22c0eea24dd36494/_layer_norm_fwd_1pass_kernel.so: undefined symbol: cuLaunchKernel
The dependencies and libraries in my environment are listed below: