state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

Triton Error [CUDA]: device kernel image is invalid #386

Open rationalspark opened 1 month ago

rationalspark commented 1 month ago

Thanks for the wonderful work.

When running Mamba2, I encountered the error "Triton Error [CUDA]: device kernel image is invalid".

Could you offer some advice?

My environment is torch 2.3.0+cu118 and triton 2.3.0. The GPU is an RTX 3090.
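Because this error is often traced to mismatched versions, it can help to collect the relevant versions programmatically so that different setups are easy to compare. A minimal sketch, assuming only that `torch` and `triton` may or may not be importable; `collect_env` is a hypothetical helper, not part of mamba_ssm or Triton. PyTorch also ships `python -m torch.utils.collect_env` for a fuller report.

```python
import importlib
import platform


def collect_env():
    """Hypothetical helper: gather versions relevant to Triton kernel-loading errors."""
    info = {"python": platform.python_version(), "torch": None,
            "torch_cuda": None, "triton": None, "gpus": []}
    try:
        torch = importlib.import_module("torch")
    except ImportError:
        return info  # without torch there is nothing else to query
    info["torch"] = torch.__version__
    info["torch_cuda"] = torch.version.cuda  # CUDA version torch was built with, e.g. "11.8"
    try:
        info["triton"] = importlib.import_module("triton").__version__
    except ImportError:
        pass
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            info["gpus"].append({
                "index": i,
                "name": torch.cuda.get_device_name(i),
                # (major, minor) compute capability; an RTX 3090 reports (8, 6)
                "capability": torch.cuda.get_device_capability(i),
            })
    return info


if __name__ == "__main__":
    for key, value in collect_env().items():
        print(f"{key}: {value}")
```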

The code is:

```python
import torch
from mamba_ssm import Mamba2

batch, length, dim = 2, 512, 256
x = torch.randn(batch, length, dim).to("cuda:2")
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=64,   # SSM state expansion factor, typically 64 or 128
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
    headdim=128,
).to("cuda:2")
y = model(x)
assert y.shape == x.shape
```

The error message is:

```
RuntimeError                              Traceback (most recent call last)
Cell In[3], line 13
      4 x = torch.randn(batch, length, dim).to("cuda:2")
      5 model = Mamba2(
      6     # This module uses roughly 3 * expand * d_model^2 parameters
      7     d_model=dim, # Model dimension d_model
   (...)
     11     headdim=128
     12 ).to("cuda:2")
---> 13 y = model(x)
     14 assert y.shape == x.shape

File ~/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File ~/anaconda3/lib/python3.9/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File ~/work/ts/mamba/mamba_ssm/modules/mamba2.py:176, in Mamba2.forward(self, u, seqlen, seq_idx, inference_params)
    174 dt_limit_kwargs = {} if self.dt_limit == (0.0, float("inf")) else dict(dt_limit=self.dt_limit)
    175 if self.use_mem_eff_path and inference_params is None:
--> 176     out = mamba_split_conv1d_scan_combined(
    177         zxbcdt,
    178         rearrange(self.conv1d.weight, "d 1 w -> d w"),
    179         self.conv1d.bias,
    180         self.dt_bias,
    181         A,
    182         D=rearrange(self.D, "(h p) -> h p", p=self.headdim) if self.D_has_hdim else self.D,
    183         chunk_size=self.chunk_size,
    184         seq_idx=seq_idx,
    185         activation=self.activation,
    186         rmsnorm_weight=self.norm.weight if self.rmsnorm else None,
    187         rmsnorm_eps=self.norm.eps if self.rmsnorm else 1e-6,
    188         outproj_weight=self.out_proj.weight,
    189         outproj_bias=self.out_proj.bias,
    190         headdim=None if self.D_has_hdim else self.headdim,
    191         ngroups=self.ngroups,
    192         norm_before_gate=self.norm_before_gate,
    193         **dt_limit_kwargs,
    194     )
    195     if seqlen_og is not None:
    196         out = rearrange(out, "b l d -> (b l) d")

File ~/work/ts/mamba/mamba_ssm/ops/triton/ssd_combined.py:908, in mamba_split_conv1d_scan_combined(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
    889 def mamba_split_conv1d_scan_combined(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu", rmsnorm_weight=None, rmsnorm_eps=1e-6, outproj_weight=None, outproj_bias=None, headdim=None, ngroups=1, norm_before_gate=True):
    890     """
    891     Argument:
    892         zxbcdt: (batch, seqlen, 2 * dim + 2 * ngroups * dstate + nheads) where dim == nheads * headdim
   (...)
    906         out: (batch, seqlen, dim)
    907     """
--> 908     return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)

File ~/anaconda3/lib/python3.9/site-packages/torch/autograd/function.py:598, in Function.apply(cls, *args, **kwargs)
    595 if not torch._C._are_functorch_transforms_active():
    596     # See NOTE: [functorch vjp and autograd interaction]
    597     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 598     return super().apply(*args, **kwargs)  # type: ignore[misc]
    600 if not is_setup_ctx_defined:
    601     raise RuntimeError(
    602         "In order to use an autograd.Function with functorch transforms "
    603         "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
    604         "staticmethod. For more details, please see "
    605         "https://pytorch.org/docs/master/notes/extending.func.html"
    606     )

File ~/anaconda3/lib/python3.9/site-packages/torch/cuda/amp/autocast_mode.py:115, in custom_fwd.<locals>.decorate_fwd(*args, **kwargs)
    113 if cast_inputs is None:
    114     args[0]._fwd_used_autocast = torch.is_autocast_enabled()
--> 115     return fwd(*args, **kwargs)
    116 else:
    117     autocast_context = torch.is_autocast_enabled()

File ~/work/ts/mamba/mamba_ssm/ops/triton/ssd_combined.py:773, in MambaSplitConv1dScanCombinedFn.forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
    771         out = torch.cat([_swiglu_fwd(zx0), out], dim=-1)
    772 else:
--> 773     out_x, _, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit)
    774     # reshape input data into 2D tensor
    775     x_rms = rearrange(out_x, "b s h p -> (b s) (h p)")

File ~/work/ts/mamba/mamba_ssm/ops/triton/ssd_combined.py:307, in _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, dt_softplus, dt_limit)
    302     assert initial_states.shape == (batch, nheads, headdim, dstate)
    303 # # (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, nheads, chunk_size, chunk_size)
    304 # dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
    305 # dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
    306 # dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
--> 307 dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
    308 states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
    309 # states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True)
    310 # states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True)
    311 # states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True)

File ~/work/ts/mamba/mamba_ssm/ops/triton/ssd_chunk_state.py:582, in _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias, dt_softplus, dt_limit)
    580 grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))
    581 with torch.cuda.device(dt.device.index):
--> 582     _chunk_cumsum_fwd_kernel[grid_chunk_cs](
    583         dt, A, dt_bias, dt_out, dA_cumsum,
    584         batch, seqlen, nheads, chunk_size,
    585         dt_limit[0], dt_limit[1],
    586         dt.stride(0), dt.stride(1), dt.stride(2),
    587         A.stride(0),
    588         dt_bias.stride(0) if dt_bias is not None else 0,
    589         dt_out.stride(0), dt_out.stride(2), dt_out.stride(1), dt_out.stride(3),
    590         dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),
    591         dt_softplus,
    592         HAS_DT_BIAS=dt_bias is not None,
    593         BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
    594     )
    595 return dA_cumsum, dt_out

File ~/anaconda3/lib/python3.9/site-packages/triton/runtime/jit.py:167, in KernelInterface.__getitem__.<locals>.<lambda>(*args, **kwargs)
    161 def __getitem__(self, grid) -> T:
    162     """
    163     A JIT function is launched with: fn[grid](*args, **kwargs).
    164     Hence JITFunction.__getitem__ returns a callable proxy that
    165     memorizes the grid.
    166     """
--> 167     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)

File ~/anaconda3/lib/python3.9/site-packages/triton/runtime/autotuner.py:143, in Autotuner.run(self, *args, **kwargs)
    141 pruned_configs = self.prune_configs(kwargs)
    142 bench_start = time.time()
--> 143 timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
    144 bench_end = time.time()
    145 self.bench_time = bench_end - bench_start

File ~/anaconda3/lib/python3.9/site-packages/triton/runtime/autotuner.py:143, in <dictcomp>(.0)
    141 pruned_configs = self.prune_configs(kwargs)
    142 bench_start = time.time()
--> 143 timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
    144 bench_end = time.time()
    145 self.bench_time = bench_end - bench_start

File ~/anaconda3/lib/python3.9/site-packages/triton/runtime/autotuner.py:122, in Autotuner._bench(self, config, *args, **meta)
    119     self.post_hook(args)
    121 try:
--> 122     return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
    123 except OutOfResources:
    124     return [float("inf"), float("inf"), float("inf")]

File ~/anaconda3/lib/python3.9/site-packages/triton/testing.py:102, in do_bench(fn, warmup, rep, grad_to_none, quantiles, fast_flush, return_mode)
     83 import torch
     84 """
     85 Benchmark the runtime of the provided function. By default, return the median runtime of :code:`fn` along with
     86 the 20-th and 80-th performance percentile.
   (...)
     99 :type fast_flush: bool
    100 """
--> 102 fn()
    103 torch.cuda.synchronize()
    105 # We maintain a buffer of 256 MB that we clear
    106 # before each kernel call to make sure that the L2
    107 # doesn't contain any input data before the run

File ~/anaconda3/lib/python3.9/site-packages/triton/runtime/autotuner.py:110, in Autotuner._bench.<locals>.kernel_call()
    108     config.pre_hook(full_nargs)
    109 self.pre_hook(args)
--> 110 self.fn.run(
    111     *args,
    112     num_warps=config.num_warps,
    113     num_stages=config.num_stages,
    114     num_ctas=config.num_ctas,
    115     enable_warp_specialization=config.enable_warp_specialization,
    116     # enable_persistent=False,
    117     **current,
    118 )
    119 self.post_hook(args)

File ~/anaconda3/lib/python3.9/site-packages/triton/runtime/jit.py:425, in JITFunction.run(self, grid, warmup, *args, **kwargs)
    423 if not warmup:
    424     args = [arg.value for arg in args if not arg.param.is_constexpr]
--> 425     kernel.run(grid_0, grid_1, grid_2, kernel.num_warps, kernel.num_ctas,  # number of warps/ctas per instance
    426                kernel.cluster_dims[0], kernel.cluster_dims[1], kernel.cluster_dims[2],  # cluster
    427                kernel.shared, stream, kernel.function, CompiledKernel.launch_enter_hook,
    428                CompiledKernel.launch_exit_hook, kernel,
    429                driver.assemble_tensormap_to_arg(kernel.metadata["tensormaps_info"], args))
    430 return kernel

File ~/anaconda3/lib/python3.9/site-packages/triton/compiler/compiler.py:255, in CompiledKernel.__getattribute__(self, name)
    253 def __getattribute__(self, name):
    254     if name == 'run':
--> 255         self._init_handles()
    256     return super().__getattribute__(name)

File ~/anaconda3/lib/python3.9/site-packages/triton/compiler/compiler.py:250, in CompiledKernel._init_handles(self)
    248     raise OutOfResources(self.shared, max_shared, "shared memory")
    249 # TODO: n_regs, n_spills should be metadata generated when calling ptxas
--> 250 self.module, self.function, self.n_regs, self.n_spills = driver.utils.load_binary(
    251     self.name, self.kernel, self.shared, device)

RuntimeError: Triton Error [CUDA]: device kernel image is invalid
```

Thank you very much for all your assistance.

tridao commented 1 month ago

It's a Triton error. I don't know how to fix it, but you can search the issues in the Triton repo.
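One way to confirm that the failure lies in Triton itself rather than in the mamba kernels is to launch a trivial Triton kernel on the same GPU. The sketch below adapts the vector-addition pattern from the Triton tutorials; `triton_smoke_test` and `_add_kernel` are hypothetical names, and the `"cuda:2"` device mirrors the report above. If even this tiny kernel fails with "device kernel image is invalid", the problem is likely in the Triton/CUDA setup rather than in mamba_ssm.

```python
try:
    import torch
    import triton
    import triton.language as tl
except ImportError:  # torch or triton missing: the smoke test cannot run
    torch = triton = tl = None

if triton is not None:
    @triton.jit
    def _add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        # Each program instance handles one contiguous block of BLOCK_SIZE elements.
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements  # guard the final, partial block
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        tl.store(out_ptr + offsets, x + y, mask=mask)


def triton_smoke_test(device="cuda:0", n=4096, block=1024):
    """True if a trivial Triton kernel runs correctly, False if it raises,
    None if torch/triton/CUDA are unavailable."""
    if triton is None or not torch.cuda.is_available():
        return None
    try:
        x = torch.randn(n, device=device)
        y = torch.randn(n, device=device)
        out = torch.empty_like(x)
        grid = (triton.cdiv(n, block),)
        # Make the tensors' GPU current, as mamba_ssm does with torch.cuda.device(...).
        with torch.cuda.device(x.device):
            _add_kernel[grid](x, y, out, n, BLOCK_SIZE=block)
        torch.cuda.synchronize(x.device)
    except RuntimeError as err:
        print(f"Triton launch failed on {device}: {err}")
        return False
    return bool(torch.allclose(out, x + y))


if __name__ == "__main__":
    print(triton_smoke_test("cuda:2"))
```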

catalpaaa commented 1 month ago

Try building mamba and causal-conv1d yourself with `pip install -e`, if I recall correctly.

Or maybe build Triton from source.
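A minimal sketch of the source rebuild suggested above, driven from Python so it uses the same interpreter as the failing notebook. It assumes local clones of the causal-conv1d and mamba repositories at hypothetical paths, and assumes the `CAUSAL_CONV1D_FORCE_BUILD` / `MAMBA_FORCE_BUILD` environment variables make each `setup.py` compile locally instead of fetching a prebuilt wheel; check each repo's `setup.py` before relying on them. By default it only prints the commands.

```python
import os
import subprocess
import sys

# Hypothetical checkout locations and (assumed) force-build variables.
# causal-conv1d comes first because mamba depends on it.
REPOS = {
    "causal-conv1d": ("~/src/causal-conv1d", "CAUSAL_CONV1D_FORCE_BUILD"),
    "mamba": ("~/src/mamba", "MAMBA_FORCE_BUILD"),
}


def build_commands(repos=REPOS):
    """Return (argv, extra_env, cwd) triples for editable source installs."""
    cmds = []
    for path, force_var in repos.values():
        # --no-build-isolation builds against the torch already installed in this env.
        argv = [sys.executable, "-m", "pip", "install", "-e", ".", "--no-build-isolation"]
        cmds.append((argv, {force_var: "TRUE"}, os.path.expanduser(path)))
    return cmds


def run(cmds, dry_run=True):
    for argv, extra_env, cwd in cmds:
        env_str = " ".join(f"{k}={v}" for k, v in extra_env.items())
        print(f"(cd {cwd} && {env_str} {' '.join(argv)})")
        if not dry_run:
            subprocess.run(argv, cwd=cwd, env={**os.environ, **extra_env}, check=True)


if __name__ == "__main__":
    run(build_commands())  # pass dry_run=False to actually build
```

Note that the kernels in the traceback are compiled at run time by Triton's JIT (`triton/runtime/jit.py`), so rebuilding these packages changes their CUDA extensions but not how Triton compiles and loads its own kernels.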

jsie7 commented 1 month ago

I encountered this before as well. It seems as if mamba/triton is using binaries for a different cuda version, hence the invalidity error. Either build from source to ensure the correct version or manually select the correct binaries.
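A rough way to check for the mismatch described here is to compare the CUDA version PyTorch was built with (`torch.version.cuda`) against the toolkit reported by the `nvcc` on PATH. This is a sketch with hypothetical helper names; as far as we know, Triton may also bundle its own compilation tools, so a match here does not rule out every mismatch.

```python
import re
import shutil
import subprocess


def parse_nvcc_release(text):
    """Extract 'major.minor' from `nvcc --version` output, e.g. '11.8'."""
    match = re.search(r"release (\d+)\.(\d+)", text)
    return f"{match.group(1)}.{match.group(2)}" if match else None


def nvcc_release(nvcc="nvcc"):
    """Run the nvcc found on PATH and return its release, or None if absent."""
    path = shutil.which(nvcc)
    if path is None:
        return None
    out = subprocess.run([path, "--version"], capture_output=True, text=True).stdout
    return parse_nvcc_release(out)


def same_major_minor(a, b):
    """True if two 'X.Y[.Z]' version strings agree on major and minor."""
    return a is not None and b is not None and a.split(".")[:2] == b.split(".")[:2]


if __name__ == "__main__":
    try:
        import torch
        torch_cuda = torch.version.cuda
    except ImportError:
        torch_cuda = None
    toolkit = nvcc_release()
    print(f"torch built with CUDA {torch_cuda}; nvcc on PATH reports {toolkit}")
    if not same_major_minor(torch_cuda, toolkit):
        print("Mismatch or missing toolkit: builds may target a different CUDA version.")
```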

catalpaaa commented 1 month ago

> I encountered this before as well. It seems as if mamba/triton is using binaries for a different cuda version, hence the invalidity error. Either build from source to ensure the correct version or manually select the correct binaries.

Yes, check whether your CUDA_HOME points to a different CUDA installation.
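To see which CUDA installation each tool would pick up, the sketch below lists `CUDA_HOME`, `CUDA_PATH`, the `nvcc` found on PATH, and `torch.utils.cpp_extension.CUDA_HOME` (the location PyTorch's extension builder resolves, which, as far as we can tell, is what source builds of mamba and causal-conv1d compile against). The helper names are hypothetical.

```python
import os
import shutil


def cuda_home_candidates():
    """List where different tools would look for CUDA, so conflicting installs stand out."""
    found = {
        "CUDA_HOME env": os.environ.get("CUDA_HOME"),
        "CUDA_PATH env": os.environ.get("CUDA_PATH"),
        "nvcc on PATH": shutil.which("nvcc"),
    }
    try:
        # The toolkit root PyTorch's C++/CUDA extension builder resolved at import time.
        from torch.utils.cpp_extension import CUDA_HOME
        found["torch cpp_extension"] = CUDA_HOME
    except ImportError:
        found["torch cpp_extension"] = None
    return found


def toolkit_root(nvcc_path):
    """Map e.g. /usr/local/cuda-11.8/bin/nvcc to /usr/local/cuda-11.8."""
    return os.path.dirname(os.path.dirname(nvcc_path)) if nvcc_path else None


if __name__ == "__main__":
    found = cuda_home_candidates()
    for source, location in found.items():
        print(f"{source:22s} {location}")
    candidates = (found["CUDA_HOME env"], toolkit_root(found["nvcc on PATH"]),
                  found["torch cpp_extension"])
    roots = {os.path.realpath(p) for p in candidates if p}
    if len(roots) > 1:
        print("Warning: these point at different CUDA installations:", sorted(roots))
```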

JHChen1 commented 1 month ago

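> I encountered this before as well. It seems as if mamba/triton is using binaries for a different CUDA version, hence the invalidity error. Either build from source to ensure the correct version, or manually select the correct binaries.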

Hello, I am currently using torch==2.0.1, triton==2.3.0, and mamba_ssm==2.0.3 with CUDA 11.8 on V100 GPUs, and I have this problem: the code runs correctly on cuda:0 but reports an error on cuda:1: "RuntimeError: Triton Error [CUDA]: context is destroyed". Can you give me some advice?
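A workaround commonly tried for errors that appear only on a non-default GPU (not suggested in this thread, so treat it as an assumption) is to expose just the target GPU through `CUDA_VISIBLE_DEVICES`, so that it becomes `cuda:0` inside the process. The variable must be set before CUDA is initialized, so it is safest to set it before importing torch; `pin_visible_gpu` is a hypothetical helper.

```python
import os
import sys
import warnings


def pin_visible_gpu(physical_index):
    """Expose only one physical GPU; it then appears as cuda:0 to PyTorch and Triton.

    Must run before CUDA is initialized (safest: before importing torch).
    """
    if "torch" in sys.modules:
        warnings.warn("torch is already imported; CUDA_VISIBLE_DEVICES is ignored "
                      "if CUDA has already been initialized.")
    os.environ["CUDA_VISIBLE_DEVICES"] = str(physical_index)
    return "cuda:0"  # device string to use from now on


if __name__ == "__main__":
    device = pin_visible_gpu(1)  # the GPU that failed as cuda:1
    print(device, os.environ["CUDA_VISIBLE_DEVICES"])
```

After calling it, import torch and mamba_ssm and use `"cuda:0"` wherever the failing run used `"cuda:1"`. Setting the variable in the shell when launching the script has the same effect.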

tridao commented 1 month ago

You can try upgrading PyTorch, though I don't think Triton supports V100 very well in general.
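Whether a GPU belongs to the older generation mentioned here can be read from its compute capability: a V100 reports (7, 0), while the RTX 3090 from the original report is (8, 6). The sketch below is a hypothetical helper; its Ampere (8, 0) cutoff is an assumption chosen for illustration, not a documented Triton requirement.

```python
# Compute capability -> architecture name (subset; unlisted values map to "unknown").
ARCH_NAMES = {
    (7, 0): "Volta",   # V100
    (7, 5): "Turing",
    (8, 0): "Ampere",  # A100
    (8, 6): "Ampere",  # RTX 3090
    (8, 9): "Ada Lovelace",
    (9, 0): "Hopper",
}
ASSUMED_MIN_CAPABILITY = (8, 0)  # illustrative cutoff, not an official Triton requirement


def describe_gpu_generation(capability):
    """Return (architecture name, is_below_assumed_cutoff) for a (major, minor) pair."""
    cap = tuple(capability)
    return ARCH_NAMES.get(cap, "unknown"), cap < ASSUMED_MIN_CAPABILITY


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        torch = None
    if torch is not None and torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            cap = torch.cuda.get_device_capability(i)
            name, older = describe_gpu_generation(cap)
            note = " (below the assumed cutoff)" if older else ""
            print(f"cuda:{i} {torch.cuda.get_device_name(i)} sm_{cap[0]}{cap[1]} {name}{note}")
```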

rationalspark commented 1 month ago

Thanks for all the replies. I tried building Mamba from source and installing the latest Triton (2.3.0), but the "Triton Error [CUDA]: device kernel image is invalid" error persists. I also tried to build Triton from source, but the compilation failed.