triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

RuntimeError: Triton Error [CUDA]: device kernel image is invalid #4390

Open MstarLioning opened 2 months ago

MstarLioning commented 2 months ago

Hello everyone,

I encountered an error message (as shown below) while trying to run the Mamba model (code below).

Experimental environment: CUDA 11.8 + PyTorch 2.0.0 + Triton 2.2.0

What should I do? I ran the same code on other servers with the same software configuration and it worked fine, so the problem does not seem to be related to the software.
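Since the same versions reportedly work on other servers, one way to narrow this down might be to dump the same facts on both machines and diff the output. The sketch below is only an assumption about what is worth comparing (package versions, the CUDA runtime PyTorch was built against, and each GPU's name and compute capability). `collect_env_report` is a hypothetical helper name, not part of Triton or mamba_ssm.

```python
import importlib
import json
import platform


def collect_env_report():
    """Gather version and GPU facts worth diffing between two servers."""
    report = {"python": platform.python_version()}

    # Record package versions; a package that fails to import is stored as None.
    for name in ("torch", "triton"):
        try:
            module = importlib.import_module(name)
            report[name] = str(getattr(module, "__version__", "unknown"))
        except Exception:
            report[name] = None

    report["cuda_runtime"] = None
    report["gpus"] = []
    if report["torch"] is not None:
        import torch

        # CUDA version PyTorch was built against (None for CPU-only builds).
        report["cuda_runtime"] = torch.version.cuda
        if torch.cuda.is_available():
            for index in range(torch.cuda.device_count()):
                major, minor = torch.cuda.get_device_capability(index)
                report["gpus"].append(
                    {
                        "index": index,
                        "name": torch.cuda.get_device_name(index),
                        "compute_capability": f"{major}.{minor}",
                    }
                )
    return report


if __name__ == "__main__":
    print(json.dumps(collect_env_report(), indent=2))
```

Running this on the failing server and on a working one, then comparing the two JSON outputs, would show whether the machines differ in GPU model or compute capability even though the Python packages match.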

```python
import os

import torch
from mamba_ssm import Mamba

cuda_version = torch.version.cuda

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda:0")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,  # Local convolution width
    expand=2,  # Block expansion factor
).to("cuda:0")
print(model)
y = model(x)
print("Mamba result", y.shape)
assert y.shape == x.shape

import torch
from mamba_ssm import Mamba2

batch, length, dim = 2, 64, 1024
x = torch.randn(batch, length, dim).to("cuda:0")
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    # make sure d_model * expand / headdim = multiple of 8
    d_model=dim,  # Model dimension d_model
    d_state=64,  # SSM state expansion factor, typically 64 or 128
    d_conv=4,  # Local convolution width
    expand=2,  # Block expansion factor
    headdim=64,  # default 64
).to("cuda:0")
print(model)
y = model(x)
print("Mamba2 result", y.shape)
assert y.shape == x.shape
```

```
/home/zcy/anaconda3/envs/Mamba/bin/python3 /tmp/pycharm_project_313/try.py
Mamba(
  (in_proj): Linear(in_features=16, out_features=64, bias=False)
  (conv1d): Conv1d(32, 32, kernel_size=(4,), stride=(1,), padding=(3,), groups=32)
  (act): SiLU()
  (x_proj): Linear(in_features=32, out_features=33, bias=False)
  (dt_proj): Linear(in_features=1, out_features=32, bias=True)
  (out_proj): Linear(in_features=32, out_features=16, bias=False)
)
Mamba result torch.Size([2, 64, 16])
Mamba2(
  (in_proj): Linear(in_features=1024, out_features=4256, bias=False)
  (conv1d): Conv1d(2176, 2176, kernel_size=(4,), stride=(1,), padding=(3,), groups=2176)
  (act): SiLU()
  (norm): RMSNorm()
  (out_proj): Linear(in_features=2048, out_features=1024, bias=False)
)
Traceback (most recent call last):
  File "/tmp/pycharm_project_313/try.py", line 37, in <module>
    y = model(x)
  File "/home/zcy/anaconda3/envs/Mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/tmp/pycharm_project_313/mamba_ssm/modules/mamba2.py", line 185, in forward
    out = mamba_split_conv1d_scan_combined(
  File "/tmp/pycharm_project_313/mamba_ssm/ops/triton/ssd_combined.py", line 930, in mamba_split_conv1d_scan_combined
    return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
  File "/home/zcy/anaconda3/envs/Mamba/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/zcy/anaconda3/envs/Mamba/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 98, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/tmp/pycharm_project_313/mamba_ssm/ops/triton/ssd_combined.py", line 795, in forward
    out_x, _, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit)
  File "/tmp/pycharm_project_313/mamba_ssm/ops/triton/ssd_combined.py", line 312, in _mamba_chunk_scan_combined_fwd
    dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
  File "/tmp/pycharm_project_313/mamba_ssm/ops/triton/ssd_chunk_state.py", line 675, in _chunk_cumsum_fwd
    _chunk_cumsum_fwd_kernel[grid_chunk_cs](
  File "/home/zcy/anaconda3/envs/Mamba/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 143, in run
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/home/zcy/anaconda3/envs/Mamba/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 143, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/home/zcy/anaconda3/envs/Mamba/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 122, in _bench
    return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
  File "/home/zcy/anaconda3/envs/Mamba/lib/python3.10/site-packages/triton/testing.py", line 102, in do_bench
    fn()
  File "/home/zcy/anaconda3/envs/Mamba/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 110, in kernel_call
    self.fn.run(
  File "/home/zcy/anaconda3/envs/Mamba/lib/python3.10/site-packages/triton/runtime/jit.py", line 550, in run
    bin.c_wrapper(
  File "/home/zcy/anaconda3/envs/Mamba/lib/python3.10/site-packages/triton/compiler/compiler.py", line 692, in __getattribute__
    self._init_handles()
  File "/home/zcy/anaconda3/envs/Mamba/lib/python3.10/site-packages/triton/compiler/compiler.py", line 683, in _init_handles
    mod, func, n_regs, n_spills = fn_load_binary(self.metadata["name"], self.asm[bin_path], self.shared, device)
RuntimeError: Triton Error [CUDA]: device kernel image is invalid
```

ispobock commented 2 months ago

Maybe you can refer to https://github.com/triton-lang/triton/issues/4172 and https://github.com/InternLM/lmdeploy/pull/1621#issuecomment-2179731554