giorgio-arena opened this issue 1 year ago
LLVM didn't support bfloat16 at the time, so we used a customized type to represent it. Now that we've upgraded to LLVM head, there's room for improvement.
Hi @Jokeren thank you for your prompt response! Are you referring to https://github.com/openai/triton/commit/9ef4b5d77315c1f977da3bc9ed528fb2c3ffaa7c? I've tried to run this on the latest commit (https://github.com/openai/triton/commit/a38d2defb8ad6f759c5f422d1a1c9848c6310355 at the time of writing), and it still gives the same result.
No, I meant after https://github.com/openai/triton/commit/9ef4b5d77315c1f977da3bc9ed528fb2c3ffaa7c, we can modify our previous bfloat16 code a bit. Will take a look
Oh right, sorry about the confusion. Thanks!
@giorgio-arena I tested it just a few hours ago with the latest code. It worked with bf16.
@ptillet @daadaada
The problem might be caused by this PR: https://github.com/openai/triton/pull/1107
We've specially handled the shared_layout -> dot_layout conversion by treating 2xi16 and 4xi8 as a single i32, but haven't handled mma_layout -> dot_layout yet.
It's a bit trickier because mma_layout stores two separate i16s by default. If we concatenate two i16s together, we'll likely have to apply or and shl on every element, causing additional overhead.
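For illustration, here's a tiny sketch (plain Python, not the actual Triton codegen) of the shl/or packing described above, where two 16-bit lanes are combined into a single 32-bit word:

```python
# Illustrative only -- not the actual Triton codegen. Packs two 16-bit lanes
# (e.g. two bf16 values reinterpreted as i16) into one i32 word via shl + or,
# which is the extra per-element work mentioned above.
def pack_2xi16(lo: int, hi: int) -> int:
    return ((hi & 0xFFFF) << 16) | (lo & 0xFFFF)

def unpack_2xi16(word: int) -> tuple[int, int]:
    return word & 0xFFFF, (word >> 16) & 0xFFFF

assert unpack_2xi16(pack_2xi16(0x1234, 0xABCD)) == (0x1234, 0xABCD)
```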
Proof-of-concept changes: https://github.com/openai/triton/commit/087ad498d8a7cde113ae0f6f59d1f1fbbc41e9ee
Still facing accuracy problems
@Jokeren: just how bad is the accuracy?
It doesn't pass the test... I haven't looked into details, maybe I was wrong on some bit manipulation stuff
We're putting out a few fires now, but we'll look into this more closely once things cool down.
The branch has been updated. It might just be a minor precision problem now? You're welcome to verify. I have no idea how precise bf16 should be.
tensor(-0.0002, device='cuda:0', dtype=torch.bfloat16) tensor(-5.1223e-09, device='cuda:0', dtype=torch.bfloat16)
It's not related to https://github.com/openai/triton/pull/1267 because the problem persists even with ptxas -O0.
Hi, I'm still getting the same error
python/test/unit/operators/test_flash_attention.py .error: cannot be converted to LLVM IR: missing 'LLVMTranslationDialectInterface' registration for dialect for op: builtin.unrealized_conversion_cast
Even when doing
git fetch && git checkout origin/keren/fix-bf16
Am I doing something wrong? How did you test this?
I cannot reproduce the issue locally. Check that you've done the following steps:
rm -rf build
pip install -e .
rm -rf ~/.triton/cache
pip uninstall pytorch-triton -y
If you're still observing the same error, please copy and paste the generated ttgir under ~/.triton/cache
Hi @Jokeren, thank you for that, you were right, I wasn't rebuilding properly! :) Also, the numerical divergence that I'm getting is not too bad
Arrays are not almost equal to 2 decimals
Mismatched elements: 203 / 12582912 (0.00161%)
Max absolute difference: 0.03125
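(That output is the standard numpy.testing report. As a rough sketch, assuming numpy and hypothetical array names, the test's comparison looks something like this; decimal=2 tolerates absolute differences up to about 1.5e-2, so the 0.03125 max difference above is what trips it.)

```python
import numpy as np

# Hypothetical stand-ins for the reference output and the Triton output.
reference_out = np.random.rand(64).astype(np.float32)
triton_out = reference_out + np.random.uniform(-5e-3, 5e-3, size=64).astype(np.float32)

# Element-wise check to `decimal` places; on failure it reports the mismatch
# count and max absolute difference, as in the output above.
np.testing.assert_array_almost_equal(triton_out, reference_out, decimal=2)
```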
Could I ask what the status on merging this branch into main is? Is there any plan to do so anytime soon? Thanks
So @tridao suggested testing against the original flash attention first, and I'll have to confirm with @ptillet on this before the merge.
Honestly, I think we can just merge. The max divergence seems very reasonable considering that bfloat16 has fewer mantissa bits than float16.
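(For context, a quick sketch, assuming PyTorch, of why bf16 rounds more coarsely: bfloat16 stores 7 mantissa bits versus float16's 10, so its spacing near 1.0 is 2^-8 rather than 2^-10.)

```python
import torch

x = torch.tensor(1.0009765625)      # 1 + 2**-10, exactly representable in fp32
print(x.to(torch.float16).item())   # 1.0009765625 -- fp16 spacing near 1.0 is 2**-10
print(x.to(torch.bfloat16).item())  # 1.0          -- bf16 spacing near 1.0 is 2**-8, so the value rounds down
```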
I see, but we have to update the error limits first, right? I'll do that some time later this week.
Yeah, or we can reduce the input size. These errors seem pretty unlikely.
To make it more consistent with the existing test, I only reduced the decimal limit for v, and it passed.
I'm just wondering whether LLVM still doesn't support the bfloat16 type? I get the same error in another test case.
Test case for the bfloat16 type: python/test/unit/language/test_core.py
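For reference, a minimal sketch (hypothetical kernel, assuming only the standard triton / triton.language / torch APIs) of the kind of bfloat16 kernel those tests exercise:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    # Each program instance handles one BLOCK-sized chunk of the input.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

n = 1024
x = torch.randn(n, device='cuda', dtype=torch.bfloat16)
y = torch.randn(n, device='cuda', dtype=torch.bfloat16)
out = torch.empty_like(x)
add_kernel[(triton.cdiv(n, 256),)](x, y, out, n, BLOCK=256)
torch.testing.assert_close(out, x + y)
```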
Hi, I have enabled bfloat16 testing in Triton (https://github.com/openai/triton/pull/1244/), but I'm getting this error with this data type
Could this get fixed please? Thanks