Hprairie / Bi-Mamba2

A Triton Kernel for incorporating Bi-Directionality in Mamba2

About Triton Error #4

Open: lewandofskee opened this issue 2 months ago

lewandofskee commented 2 months ago

Thank you for your excellent work. However, I run into the following problem when trying to reproduce your results:

File "/opt/anaconda3/lib/python3.10/site-packages/ssd/bi/ssd_combined.py", line 561, in forward
out, out_x, dt_out, dA_cumsum_f, dA_cumsum_b, states_f, states_b, final_states_f, final_states_b = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
File "/opt/anaconda3/lib/python3.10/site-packages/ssd/bi/ssd_combined.py", line 323, in _mamba_chunk_scan_combined_fwd
dA_cumsum_f, dA_cumsum_b, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
File "/opt/anaconda3/lib/python3.10/site-packages/ssd/bi/ssd_chunk_state.py", line 710, in _chunk_cumsum_fwd
_chunk_cumsum_fwd_kernel[grid_chunk_cs](
File "/opt/anaconda3/lib/python3.10/site-packages/triton/runtime/jit.py", line 345, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
File "/opt/anaconda3/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 156, in run
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
File "/opt/anaconda3/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 156, in <dictcomp>
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
File "/opt/anaconda3/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 133, in _bench
return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
File "/opt/anaconda3/lib/python3.10/site-packages/triton/testing.py", line 103, in do_bench
fn()
File "/opt/anaconda3/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 114, in kernel_call
self.fn.run(
File "/opt/anaconda3/lib/python3.10/site-packages/triton/runtime/jit.py", line 662, in run
kernel = self.compile(
File "/opt/anaconda3/lib/python3.10/site-packages/triton/compiler/compiler.py", line 276, in compile
module = src.make_ir(options, codegen_fns, context)
File "/opt/anaconda3/lib/python3.10/site-packages/triton/compiler/compiler.py", line 113, in make_ir
return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)
triton.compiler.errors.CompilationError: at 34:23:
dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk
dA_cumsum_f_ptr += pid_b * stride_dA_cs_f_batch + pid_c * stride_dA_cs_f_chunk
dA_cumsum_b_ptr += pid_b * stride_dA_cs_b_batch + pid_c * stride_dA_cs_b_chunk

offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)
A_ptrs = A_ptr + offs_h * stride_A_head
dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize)
dA_cs_f_ptrs = dA_cumsum_f_ptr + (offs_h[:, None] * stride_dA_cs_f_head + offs_c[None, :] * stride_dA_cs_f_csize)
dA_cs_b_ptrs = dA_cumsum_b_ptr + (offs_h[:, None] * stride_dA_cs_b_head + offs_c[None, :] * stride_dA_cs_b_csize)
chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)
                   ^

Traceback (most recent call last):
File "/opt/anaconda3/lib/python3.10/site-packages/triton/language/core.py", line 35, in wrapper
return fn(*args, **kwargs)
File "/opt/anaconda3/lib/python3.10/site-packages/triton/language/core.py", line 1875, in minimum
return semantic.minimum(x, y, propagate_nan, _builder)
File "/opt/anaconda3/lib/python3.10/site-packages/triton/language/semantic.py", line 258, in minimum
x, y = binary_op_type_checking_impl(x, y, builder)
File "/opt/anaconda3/lib/python3.10/site-packages/triton/language/semantic.py", line 120, in binary_op_type_checking_impl
check_ptr_type_impl(rhs_sca_ty, lhs_sca_ty, allow_rhs_ptr)
File "/opt/anaconda3/lib/python3.10/site-packages/triton/language/semantic.py", line 102, in check_ptr_type_impl
raise IncompatibleTypeErrorImpl(type_a, type_b)
triton.language.semantic.IncompatibleTypeErrorImpl: invalid operands of type pointer and triton.language.int32
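The failing line computes how many positions of the current chunk are valid: every chunk holds `chunk_size` positions except possibly the last one. The traceback shows that the kernel's `min` is lowered to `triton.language.minimum`, and the type error says one of its operands arrived as a pointer (a tensor argument) rather than an integer. The plain-Python sketch below is not the kernel itself; `chunk_size_limit` and `chunk_size_limits` are hypothetical helpers that reproduce the arithmetic and reject non-integer operands, mirroring the requirement that both sides of that `min` be integers.

```python
def chunk_size_limit(chunk_size: int, seqlen: int, pid_c: int) -> int:
    """Valid positions in chunk `pid_c` (hypothetical helper mirroring
    `min(chunk_size, seqlen - pid_c * chunk_size)` from the kernel)."""
    # Reject anything that is not a plain int, e.g. a tensor passed where a
    # scalar was expected. bool is a subclass of int, so exclude it explicitly.
    for name, value in (("chunk_size", chunk_size), ("seqlen", seqlen), ("pid_c", pid_c)):
        if isinstance(value, bool) or not isinstance(value, int):
            raise TypeError(f"{name} must be an int, got {type(value).__name__}")
    return min(chunk_size, seqlen - pid_c * chunk_size)


def chunk_size_limits(chunk_size: int, seqlen: int) -> list[int]:
    # One program per chunk along the sequence: ceil(seqlen / chunk_size) chunks.
    nchunks = -(-seqlen // chunk_size)
    return [chunk_size_limit(chunk_size, seqlen, c) for c in range(nchunks)]


print(chunk_size_limits(4, 10))  # [4, 4, 2]
```

With `seqlen=10` and `chunk_size=4` there are three chunks, and only the last is truncated to two positions.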

lewandofskee commented 2 months ago

I also get the following output when running the Bi-Directional Kernel instance:

loc("/opt/anaconda3/lib/python3.10/site-packages/ssd/bi/ssd_chunk_scan.py":138:11): error: operation scheduled before its operands
(repeated many times)
loc("/opt/anaconda3/lib/python3.10/site-packages/ssd/bi/ssd_combined.py":179:11): error: operation scheduled before its operands
(repeated many times)
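Output like this is usually a small number of distinct diagnostics repeated many times. To check that from a captured log, a small helper can group the `loc(...)` messages by file, line, column, and message. This is a sketch written for this thread: `summarize_diagnostics` is a hypothetical name, and it only parses text; it does not change what Triton prints.

```python
import re
from collections import Counter

# Matches MLIR-style diagnostics such as
#   loc("/path/ssd_chunk_scan.py":138:11): error: operation scheduled before its operands
# Several diagnostics may be concatenated on one line, so each message ends
# just before the next "loc(" or at the end of the line.
_LOC_RE = re.compile(
    r'loc\("([^"]+)":(\d+):(\d+)\): error: (.*?)\s*(?=loc\(|$)',
    re.MULTILINE,
)


def summarize_diagnostics(text: str) -> Counter:
    """Count each distinct (path, line, column, message) diagnostic in `text`."""
    return Counter(
        (path, int(line), int(col), msg)
        for path, line, col, msg in _LOC_RE.findall(text)
    )


sample = (
    'loc("a/ssd_chunk_scan.py":138:11): error: operation scheduled before its operands '
    'loc("a/ssd_chunk_scan.py":138:11): error: operation scheduled before its operands '
    'loc("a/ssd_combined.py":179:11): error: operation scheduled before its operands\n'
)
for (path, line, col, msg), n in summarize_diagnostics(sample).items():
    print(f"{n}x {path}:{line}:{col}: {msg}")
```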

lewandofskee commented 2 months ago

My environment: Triton 3.0.0, PyTorch 2.2.2, Python 3.10, CUDA 11.8.
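When reporting versions, it can help to print them from the same interpreter that raises the error, since the `pip` on the PATH and the running Python can disagree, and PyTorch's Linux wheels typically pin a specific Triton release that a separately installed Triton may replace. A minimal sketch using the standard library plus an optional `torch` import; `environment_report` is a hypothetical helper name.

```python
import importlib.metadata as md
import platform


def environment_report(packages=("torch", "triton")) -> dict:
    """Collect versions as seen by *this* interpreter (hypothetical helper)."""
    report = {"python": platform.python_version()}
    for name in packages:
        try:
            report[name] = md.version(name)
        except md.PackageNotFoundError:
            report[name] = None  # not installed in this interpreter
    try:
        import torch  # optional: only used for CUDA/GPU details

        report["cuda (torch build)"] = torch.version.cuda
        report["gpu"] = (
            torch.cuda.get_device_name(0) if torch.cuda.is_available() else None
        )
    except ImportError:
        pass
    return report


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```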

I would appreciate it if you could reply as soon as possible.

Hprairie commented 2 months ago

For your first error, my guess is that you need a more recent version of Triton. I developed this with 3.0.0, so try a fresh installation to get the newest version.
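If the fix is a newer Triton, a guard at import time can fail fast with a readable message instead of a compilation error deep inside the autotuner. This is a sketch: the 3.0.0 floor comes from this reply, `parse_version`, `version_at_least`, and `require_version` are hypothetical names, and `packaging.version.Version` is the more complete choice if the `packaging` package is available.

```python
import importlib.metadata as md
import re


def parse_version(text: str) -> tuple:
    # Keep only the leading numeric components, e.g. "3.0.0+git1234" -> (3, 0, 0).
    # Pre-release and local tags are ignored (a deliberate simplification).
    match = re.match(r"(\d+(?:\.\d+)*)", text)
    if match is None:
        raise ValueError(f"unrecognized version string: {text!r}")
    return tuple(int(part) for part in match.group(1).split("."))


def version_at_least(installed: str, minimum: str) -> bool:
    a, b = parse_version(installed), parse_version(minimum)
    # Pad with zeros so that "3.0" and "3.0.0" compare as equal.
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def require_version(package: str, minimum: str) -> str:
    installed = md.version(package)  # raises PackageNotFoundError if missing
    if not version_at_least(installed, minimum):
        raise RuntimeError(f"{package} {installed} is older than required {minimum}")
    return installed
```

For example, calling `require_version("triton", "3.0.0")` before launching any kernels returns the installed version string or raises with both versions in the message.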

The second error only appears the first time a Triton kernel is compiled. As far as I know it had no effect in my runs, so it should be safe to ignore.