akakakakakaa closed this issue 1 year ago
Same problem here. Has anyone solved it?
Got the same problem. The line number reported by the Triton compiler may be unrelated to this error, since I don't even pass bias to this kernel.
I looked into verifySameEncoding
https://github.com/openai/triton/blob/6dee55c912ad8320ebc63e69a77bb83c19d1f19e/lib/Dialect/Triton/IR/Traits.cpp#L8
but I didn't figure out what the encoding of a RankedTensorType is.
Is this still happening? I suspect this has been fixed. I'm not able to run the example you sent; could you please include the full example, or the IR before the pass causing the problem?
@ThomasRaoux Can you repro the issue from this PR: https://github.com/Dao-AILab/flash-attention/pull/458 ?
What command line do you run? Running `python flash_attn_triton.py`, nothing happens.
@ThomasRaoux `pytest -q -s tests/test_flash_attn.py` will run it through the GPT-2 model tests. The benchmark will run the external Triton implementation, if that is what you are asking:
https://github.com/Dao-AILab/flash-attention/blob/866a9d33f9bcab0742d007c720ade1e1b79d1d79/benchmarks/benchmark_flash_attention.py#L81
Let me ask around.
The problem is most likely fixed by https://github.com/openai/triton/commit/d4644d6cb3ae674e1f15932cac1f28104795744f. Can you check whether it reproduces at top of tree? If you are able to repro, sharing the IR might be the simplest way for me or someone else to look at it.
I tested with triton_nightly-2.1.0.dev20230822000928-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl, and the `triton_gpu.cmpf` error is fixed.
I found that if I comment out these two lines, the code runs and the results are correct:

```python
tl.store(t_ptrs, acc_o_scale)
acc_o_scale = tl.load(t_ptrs)
```
I tried to run the FlashAttention (not version 2) Triton source code, changing only
`tl.dot(q, k, trans_b=True)`
to `tl.dot(q, tl.trans(k))`,
but I received an error.
This is the full source code for the forward pass (only `tl.dot(q, k, trans_b=True)` is changed to `tl.dot(q, tl.trans(k))` from the flash-attention source code); the stack trace is here.
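For reference, the two calls are meant to be mathematically equivalent: in older Triton, `tl.dot(q, k, trans_b=True)` computed `q @ k^T`, and after the `trans_b` argument was removed, the same product is spelled `tl.dot(q, tl.trans(k))`. A minimal NumPy sketch of the intended semantics (NumPy is only a stand-in here; it does not reproduce the Triton compiler error):

```python
import numpy as np

# Two small matrices standing in for the q and k tiles inside the kernel.
q = np.arange(6, dtype=np.float32).reshape(2, 3)
k = np.arange(6, dtype=np.float32).reshape(2, 3)

# Semantics of the old API: tl.dot(q, k, trans_b=True) -> q @ k^T
old_style = q @ k.T

# Semantics of the new API: tl.dot(q, tl.trans(k)) -> q @ transpose(k)
new_style = q @ np.transpose(k)

# Both spellings should produce the same (2, 2) result.
assert np.array_equal(old_style, new_style)
```

So the rewrite itself should not change the math; if the kernel fails after this one-line change, the problem is likely in how the compiler handles the explicit transpose, not in the attention logic.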
Can you advise me how to debug Triton code? I can't find any problem in the 222-line kernel source.
Thanks.