triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License
13.35k stars 1.64k forks source link

It raises an error when I run 06-fused-attention.py #4200

Status: Open. Godlovecui opened this issue 4 months ago.

Godlovecui commented 4 months ago

When I run 06-fused-attention.py on an RTX 4090, it raises the error below. How can I fix it? Thank you!

Triton version: 2.3.0
CUDA version: 12.4

```
root@GPU-RTX4090-4-8:/workspaces/triton/python/tutorials# python 06-fused-attention.py
Traceback (most recent call last):
  File "/workspaces/triton/python/tutorials/06-fused-attention.py", line 81, in <module>
    configs = [
  File "/workspaces/triton/python/tutorials/06-fused-attention.py", line 85, in <listcomp>
    for s in ([1] if is_hip() else [3, 4, 7])\
  File "/workspaces/triton/python/tutorials/06-fused-attention.py", line 22, in is_hip
    return triton.runtime.driver.active.get_current_target().backend == "hip"
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/driver.py", line 210, in __getattr__
    return getattr(self._obj, name)
AttributeError: 'CudaDriver' object has no attribute 'active'
```

```
root@GPU-RTX4090-4-8:/workspaces/triton/python/tutorials# python 06-fused-attention.py
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py", line 1222, in ast_to_ttir
    generator.visit(fn.parse())
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py", line 303, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/usr/lib/python3.10/ast.py", line 426, in generic_visit
    self.visit(item)
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py", line 376, in visit_FunctionDef
    self.visit_compound_statement(node.body)
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py", line 298, in visit_compound_statement
    ret_type = self.visit(stmt)
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py", line 428, in visit_Assign
    values = self.visit(node.value)
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py", line 1027, in visit_Call
    return fn(*args, **extra_kwargs, **kws)
  File "/usr/local/lib/python3.10/dist-packages/triton/language/core.py", line 27, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/language/core.py", line 1133, in make_block_ptr
    return semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order, _builder)
  File "/usr/local/lib/python3.10/dist-packages/triton/language/semantic.py", line 1531, in make_block_ptr
    assert sorted(order) == list(range(len(order))), "Expected a permutation of (0, 1, ..., len(order)-1) in order"
AssertionError: Expected a permutation of (0, 1, ..., len(order)-1) in order
```

The above exception was the direct cause of the following exception:

```
Traceback (most recent call last):
  File "/workspaces/triton/python/tutorials/06-fused-attention.py", line 642, in <module>
    bench_flash_attention.run(save_path=".", print_data=True)
  File "/usr/local/lib/python3.10/dist-packages/triton/testing.py", line 341, in run
    result_dfs.append(self._run(bench, save_path, show_plots, print_data, **kwargs))
  File "/usr/local/lib/python3.10/dist-packages/triton/testing.py", line 287, in _run
    ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwrags)
  File "/workspaces/triton/python/tutorials/06-fused-attention.py", line 622, in bench_flash_attention
    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
  File "/usr/local/lib/python3.10/dist-packages/triton/testing.py", line 102, in do_bench
    fn()
  File "/workspaces/triton/python/tutorials/06-fused-attention.py", line 617, in <lambda>
    fn = lambda: attention(q, k, v, causal, sm_scale)
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/workspaces/triton/python/tutorials/06-fused-attention.py", line 459, in forward
    _attn_fwd[grid](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 167, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 143, in run
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 143, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 122, in _bench
    return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
  File "/usr/local/lib/python3.10/dist-packages/triton/testing.py", line 102, in do_bench
    fn()
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 110, in kernel_call
    self.fn.run(
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 416, in run
    self.cache[device][key] = compile(
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 191, in compile
    module = src.make_ir(options)
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 117, in make_ir
    return ast_to_ttir(self.fn, self, options=options)
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/code_generator.py", line 1231, in ast_to_ttir
    raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 35:14:
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, HEAD_DIM),
        order=(1, 0),
    )
    v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)
    V_block_ptr = tl.make_block_ptr(
        base=V + qvk_offset,
        shape=(N_CTX, HEAD_DIM),
        strides=(stride_vk, stride_vn),
        offsets=(0, 0),
        block_shape=(BLOCK_N, HEAD_DIM),
        order=v_order,
              ^
AssertionError('Expected a permutation of (0, 1, ..., len(order)-1) in order')
```

@kashif @ezyang @nelhage @ashay

agshar96 commented 4 months ago

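Did you install the Triton release with `pip install triton`? My errors were fixed when I switched to the nightly build. The tutorials in the repository track the development version of Triton. They can therefore use APIs that the 2.3.0 release does not have, such as `triton.runtime.driver.active`, which is exactly what the first `AttributeError` above reports.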