ROCm / triton

Development repository for the Triton language and compiler
MIT License

fp8 type support #357

Closed: scxiao closed this pull request 10 months ago.

scxiao commented 11 months ago

This PR adds AMD fp8 data types to Triton. It includes the following changes:
1) Adds two AMD fp8 data formats (fp8e4b8 and fp8e5b16) to Triton, so that Triton treats these types as legal.
2) Provides type conversions from fp8 to fp16 and vice versa.
3) Changes FlashAttention so that it can switch between fp16 and fp8 through an input-type switch (fp8 is supported for q and k only).
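The two formats and the conversions in items 1 and 2 can be illustrated with a small pure-Python codec. This is a sketch, not the PR's implementation. It assumes that fp8e4b8 and fp8e5b16 follow the "fnuz" layout used by MLIR's Float8E4M3FNUZ and Float8E5M2FNUZ types: 1 sign bit, 4 or 5 exponent bits with a bias of 8 or 16, 3 or 2 mantissa bits, no infinities, no negative zero, and 0x80 as the only NaN encoding. The names `decode`, `encode`, `fp8_to_fp16_bits` and `fp16_bits_to_fp8` are hypothetical. Saturating out-of-range values to the largest finite code is also an assumption, and Triton's own conversions may handle overflow differently.

```python
import bisect
import math
import struct

# Assumed layouts: (exponent bits, mantissa bits, exponent bias).
# Both are treated as "fnuz" formats: no infinities, no negative zero,
# and 0x80 is the single NaN encoding.
FORMATS = {
    "fp8e4b8": (4, 3, 8),
    "fp8e5b16": (5, 2, 16),
}
NAN_CODE = 0x80


def decode(code: int, fmt: str) -> float:
    """Decode one fp8 byte into a Python float."""
    exp_bits, man_bits, bias = FORMATS[fmt]
    if code == NAN_CODE:
        return math.nan
    sign = -1.0 if code & 0x80 else 1.0
    exp = (code >> man_bits) & ((1 << exp_bits) - 1)
    frac = (code & ((1 << man_bits) - 1)) / (1 << man_bits)
    if exp == 0:
        # Subnormal: no implicit leading 1, exponent fixed at 1 - bias.
        return sign * frac * 2.0 ** (1 - bias)
    return sign * (1.0 + frac) * 2.0 ** (exp - bias)


# Values of the non-negative codes 0x00..0x7F; they increase with the code.
TABLES = {f: [decode(c, f) for c in range(0x80)] for f in FORMATS}


def encode(x: float, fmt: str) -> int:
    """Round a float to the nearest fp8 code (ties to even), saturating on overflow."""
    if math.isnan(x):
        return NAN_CODE
    table = TABLES[fmt]
    mag = abs(x)
    if mag >= table[-1]:
        code = 0x7F  # assumption: saturate to the largest finite magnitude
    else:
        i = bisect.bisect_left(table, mag)  # first code whose value is >= mag
        if table[i] == mag:
            code = i
        else:
            below, above = mag - table[i - 1], table[i] - mag
            # On a tie, pick the code with an even mantissa LSB (even code).
            if below < above or (below == above and (i - 1) % 2 == 0):
                code = i - 1
            else:
                code = i
    if x < 0 and code != 0:
        code |= 0x80  # no negative zero: negatives that round to 0 become 0x00
    return code


def fp8_to_fp16_bits(code: int, fmt: str) -> int:
    """fp8 -> fp16 raw bits. Exact, since every fp8 value here fits in fp16."""
    return struct.unpack("<H", struct.pack("<e", decode(code, fmt)))[0]


def fp16_bits_to_fp8(bits: int, fmt: str) -> int:
    """fp16 raw bits -> fp8 code, using round-to-nearest-even."""
    x = struct.unpack("<e", struct.pack("<H", bits))[0]
    return encode(x, fmt)


if __name__ == "__main__":
    for fmt in FORMATS:
        print(fmt, "max =", decode(0x7F, fmt), "1.0 ->", hex(encode(1.0, fmt)))
```

The nearest-code search over a 128-entry table keeps the rounding rule easy to audit. A GPU kernel would do the same rounding with bit manipulation on the fp16 exponent and mantissa fields instead.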

Changes to GEMM to support different types are in a separate PR: https://github.com/ROCmSoftwarePlatform/triton/pull/373.

scxiao commented 10 months ago

Hi all, when you get a chance, could you please review this PR? I verified execution on both MI200 and MI300, with and without PyTorch fp8 support, and everything passes.

zhanglx13 commented 10 months ago

Sure. I'll take a look some time today.

scxiao commented 10 months ago

Quoting the review: "LGTM. Can you add a description to this PR? In particular, make it clear that FA now supports fp8 for q and k only."

Thanks. Done.