triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

Flash Attention test fails for bfloat16 #1245

Open giorgio-arena opened 1 year ago

giorgio-arena commented 1 year ago

Hi, I have enabled bfloat16 testing in Triton (https://github.com/openai/triton/pull/1244/), but I'm getting this error with this data type:

giorgio@giorgio:triton$ pytest python/test/unit/operators/test_flash_attention.py -s
========================================================================================== test session starts ==========================================================================================
platform linux -- Python 3.10.9, pytest-7.2.1, pluggy-1.0.0
rootdir: /usr/local/home/giorgio/triton/python
collected 2 items                                                                                                                                                                                       

python/test/unit/operators/test_flash_attention.py .error: cannot be converted to LLVM IR: missing `LLVMTranslationDialectInterface` registration for dialect for op: builtin.unrealized_conversion_cast
Failed to emit LLVM IR
Translate to LLVM IR failedLLVM ERROR: Failed to translate TritonGPU to LLVM IR.
Fatal Python error: Aborted

Thread 0x00007fbc757fe6c0 (most recent call first):
  <no Python frame>

Current thread 0x00007fbd053ae200 (most recent call first):
  File "/usr/local/home/giorgio/triton/python/triton/compiler.py", line 1018 in ttgir_to_llir
  File "/usr/local/home/giorgio/triton/python/triton/compiler.py", line 1570 in <lambda>
  File "/usr/local/home/giorgio/triton/python/triton/compiler.py", line 1637 in compile
  File "<string>", line 41 in _fwd_kernel
  File "/usr/local/home/giorgio/triton/python/triton/ops/flash_attention.py", line 214 in forward
  File "/usr/local/home/giorgio/triton/python/test/unit/operators/test_flash_attention.py", line 33 in test_op
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/python.py", line 195 in pytest_pyfunc_call
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/python.py", line 1789 in runtest
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/runner.py", line 167 in pytest_runtest_call
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/runner.py", line 260 in <lambda>
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/runner.py", line 339 in from_call
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/runner.py", line 259 in call_runtest_hook
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/runner.py", line 220 in call_and_report
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/runner.py", line 131 in runtestprotocol
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/runner.py", line 112 in pytest_runtest_protocol
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/main.py", line 349 in pytest_runtestloop
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/main.py", line 324 in _main
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/main.py", line 270 in wrap_session
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/main.py", line 317 in pytest_cmdline_main
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_callers.py", line 39 in _multicall
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_manager.py", line 80 in _hookexec
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/pluggy/_hooks.py", line 265 in __call__
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/config/__init__.py", line 167 in main
  File "/usr/local/home/giorgio/.local/lib/python3.10/site-packages/_pytest/config/__init__.py", line 190 in console_main
  File "/usr/local/home/giorgio/.local/bin/pytest", line 8 in <module>

Extension modules: torch._C, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg.lapack_lite, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, __triton_launcher, cuda_utils (total: 23)
Aborted

Could this get fixed please? Thanks

Jokeren commented 1 year ago

LLVM didn't support bfloat16, so we used our own custom type to represent it. Now that we've upgraded to LLVM head, there may be room for improvement.

giorgio-arena commented 1 year ago

Hi @Jokeren thank you for your prompt response! Are you referring to https://github.com/openai/triton/commit/9ef4b5d77315c1f977da3bc9ed528fb2c3ffaa7c? I've tried to run this on the latest commit (https://github.com/openai/triton/commit/a38d2defb8ad6f759c5f422d1a1c9848c6310355 at the time of writing), and it still gives the same result.

Jokeren commented 1 year ago

No, I meant after https://github.com/openai/triton/commit/9ef4b5d77315c1f977da3bc9ed528fb2c3ffaa7c, we can modify our previous bfloat16 code a bit. Will take a look

giorgio-arena commented 1 year ago

Oh right, sorry about the confusion. Thanks!

linxihui commented 1 year ago

@giorgio-arena I tested it just a few hours ago with the latest code. It worked with bf16.

Jokeren commented 1 year ago

@ptillet @daadaada The problem might be caused by this PR (https://github.com/openai/triton/pull/1107). We've special-cased the shared_layout->dot_layout conversion by treating 2xi16 and 4xi8 as a single i32, but haven't handled mma_layout->dot_layout yet.

It's a bit trickier because mma_layout stores two separate i16s by default. If we concatenate two i16s, we will likely have to apply or and shl to every element, causing additional overhead.
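
As a rough illustration of that overhead, packing two 16-bit values into one 32-bit word costs a shift and an or per pair. The sketch below is plain Python, not the Triton lowering code; the helper names `pack_2xi16` and `unpack_i32` are hypothetical, and placing the first element in the low half is an assumption.

```python
MASK16 = 0xFFFF

def pack_2xi16(lo: int, hi: int) -> int:
    """Concatenate two 16-bit patterns into one 32-bit word.

    `lo` goes in the low half and `hi` in the high half (assumed ordering).
    """
    return ((hi & MASK16) << 16) | (lo & MASK16)  # one shl and one or per pair

def unpack_i32(word: int) -> tuple[int, int]:
    """Split a 32-bit word back into its (low, high) 16-bit halves."""
    return word & MASK16, (word >> 16) & MASK16

# 0x3F80 and 0x4000 are the bfloat16 bit patterns of 1.0 and 2.0.
packed = pack_2xi16(0x3F80, 0x4000)
print(hex(packed))  # 0x40003f80
```

When the two halves already sit together in one 32-bit register, none of this work is needed, which is presumably why the separately stored i16s of mma_layout are the harder case.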

Jokeren commented 1 year ago

Proof-of-concept changes: https://github.com/openai/triton/commit/087ad498d8a7cde113ae0f6f59d1f1fbbc41e9ee

Still facing accuracy problems

sbodenstein commented 1 year ago

@Jokeren: just how bad is the accuracy?

Jokeren commented 1 year ago

It doesn't pass the test... I haven't looked into details, maybe I was wrong on some bit manipulation stuff

ptillet commented 1 year ago

We're putting out a few fires right now, but we'll look into this more closely once things cool down.

Jokeren commented 1 year ago

Branch has been updated. It might be just a minor precision problem now? You are welcome to verify. I have no idea how precise bf16 should be.

tensor(-0.0002, device='cuda:0', dtype=torch.bfloat16) tensor(-5.1223e-09, device='cuda:0', dtype=torch.bfloat16)

Jokeren commented 1 year ago

It's not related to https://github.com/openai/triton/pull/1267, because the problem persists even when using ptxas -O0.

giorgio-arena commented 1 year ago

Hi, I'm still getting the same error

python/test/unit/operators/test_flash_attention.py .error: cannot be converted to LLVM IR: missing 'LLVMTranslationDialectInterface' registration for dialect for op: builtin.unrealized_conversion_cast

even when doing

git fetch && git checkout origin/keren/fix-bf16

Am I doing something wrong? How did you test this?

Jokeren commented 1 year ago

I cannot reproduce the issue locally. Check you've done the following steps:

rm -rf build
pip install -e .
rm -rf ~/.triton/cache
pip uninstall pytorch-triton -y

If you still see the same error, please copy and paste the generated ttgir from ~/.triton/cache.

giorgio-arena commented 1 year ago

Hi @Jokeren, thank you for that. You were right, I wasn't rebuilding properly! :) Also, the numerical divergence that I'm getting is not too bad:

Arrays are not almost equal to 2 decimals

Mismatched elements: 203 / 12582912 (0.00161%)
Max absolute difference: 0.03125

Could I ask what the status on merging this branch to main is? Is there any plan to do so anytime soon? Thanks

Jokeren commented 1 year ago

@tridao suggested that I test against the original flash attention first, and I will have to confirm this with @ptillet before the merge.

ptillet commented 1 year ago

Honestly, I think we can just merge. The max divergence seems very reasonable considering that bfloat16 has fewer mantissa bits than float16.
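
To put a number on that: bfloat16 stores 7 explicit mantissa bits and float16 stores 10, so bfloat16 steps are 8 times coarser at the same magnitude. The sketch below uses a hypothetical helper, `ulp`, and covers the normal range only. It shows that one bfloat16 step for magnitudes in [4, 8) is 0.03125, the same size as the maximum absolute difference reported earlier in the thread. The thread does not say whether the mismatched elements fall in that range, so treat this as a plausibility check only.

```python
import math

BF16_MANTISSA_BITS = 7   # explicitly stored mantissa bits in bfloat16
FP16_MANTISSA_BITS = 10  # explicitly stored mantissa bits in float16

def ulp(x: float, stored_mantissa_bits: int) -> float:
    """Spacing between adjacent representable values near x (normal range only)."""
    _, e = math.frexp(abs(x))  # abs(x) = m * 2**e with 0.5 <= m < 1
    exponent = e - 1           # so 2**exponent <= abs(x) < 2**(exponent + 1)
    return 2.0 ** (exponent - stored_mantissa_bits)

print(ulp(4.0, BF16_MANTISSA_BITS))  # 0.03125
print(ulp(4.0, FP16_MANTISSA_BITS))  # 0.00390625
```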

Jokeren commented 1 year ago

I see, but we have to update the error limits first, right? I'll do that some time later this week.

ptillet commented 1 year ago

yeah, or we can reduce the input size. Seems like these errors are pretty unlikely

Jokeren commented 1 year ago

To make it more consistent with the existing test, I only reduced the decimal limit for v, and it passed.
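
For context, the "Arrays are not almost equal to 2 decimals" message matches the wording of NumPy's `assert_array_almost_equal`, whose documented criterion is `abs(desired - actual) < 1.5 * 10**(-decimal)`. Assuming the test relies on that check, the pure-Python sketch below (with a hypothetical helper, `within_decimal`) shows why a difference of 0.03125 fails at 2 decimals but would pass at 1. The thread does not state which decimal value was finally chosen for v, so `decimal=1` is only illustrative.

```python
def within_decimal(actual: float, desired: float, decimal: int) -> bool:
    """Mirror of NumPy's documented almost-equal criterion for a single element."""
    return abs(desired - actual) < 1.5 * 10 ** (-decimal)

max_abs_diff = 0.03125  # maximum absolute difference reported in the thread

print(within_decimal(max_abs_diff, 0.0, decimal=2))  # False: threshold is 0.015
print(within_decimal(max_abs_diff, 0.0, decimal=1))  # True: threshold is 0.15
```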

Alon-Lau commented 6 months ago

I'm just wondering whether LLVM still doesn't support the bfloat16 type? I get the same error in another test case.

Test case for the bfloat16 type: python/test/unit/language/test_core.py