triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

H100 fp8 GEMM with fp16-to-fp8 casting after the load makes performance bad #2781

Open stephen-youn opened 9 months ago

stephen-youn commented 9 months ago

Hi, when I tried the fp8 GEMM code in matmul.py with input "a" kept in float16 and cast to fp8 just before the dot-product op (by setting AB_DTYPE to tl.float8e4nv; link: https://github.com/openai/triton/blob/addd94e4a8d4dc0beefd5df6d97d57d18065436a/python/triton/ops/matmul.py#L128C1-L130C31), the throughput got much worse.

Why does this happen, and what would be a workaround? Thanks.
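Besides changing the operand type, the cast described above (fp16 "a" loaded, then converted to tl.float8e4nv right before the dot) also changes each element's value. The stdlib-only Python sketch below emulates that per-element conversion on the host, so the rounding can be inspected without a GPU. It assumes tl.float8e4nv is the FP8 E4M3 "fn" format (4 exponent bits with bias 7, 3 mantissa bits, no infinities, largest finite value 448). It also assumes round-to-nearest-even with saturation. The helper `to_e4m3_sat` is hypothetical and not part of Triton, and it does not explain the throughput drop reported here.

```python
import math

# Assumed parameters of FP8 E4M3 "fn" (what tl.float8e4nv is taken to be):
# 4 exponent bits (bias 7), 3 mantissa bits, no infinities.
E4M3_MANTISSA_BITS = 3
E4M3_MIN_NORMAL_EXP = -6   # smallest normal magnitude is 2**-6
E4M3_MAX = 448.0           # largest finite magnitude, 2**8 * 1.75


def to_e4m3_sat(x: float) -> float:
    """Round x to the nearest E4M3 value (ties to even), saturating to +/-448.

    Hypothetical host-side emulation for illustration only; it is not
    Triton's actual lowering of the fp16 -> fp8 cast.
    """
    if math.isnan(x):
        return math.nan
    if x == 0.0:
        return x  # keep the sign of zero
    sign = -1.0 if x < 0 else 1.0
    mag = abs(x)
    if mag >= E4M3_MAX:
        return sign * E4M3_MAX  # assumed saturating behaviour (also maps inf)
    # Binade exponent: mag lies in [2**exp, 2**(exp + 1)).
    exp = math.frexp(mag)[1] - 1
    # Below the normal range the spacing stays fixed (subnormal values).
    exp = max(exp, E4M3_MIN_NORMAL_EXP)
    # Spacing between adjacent E4M3 values in this binade.
    quantum = 2.0 ** (exp - E4M3_MANTISSA_BITS)
    # Python's round() on a float rounds half to even, i.e. RNE.
    rounded = round(mag / quantum) * quantum
    return sign * min(rounded, E4M3_MAX)
```

For example, `to_e4m3_sat(0.1)` returns 0.1015625 and `to_e4m3_sat(1000.0)` returns 448.0. An fp16 input that is cast in-kernel is therefore quantized the same way as one stored as fp8 ahead of time, assuming the same rounding mode.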

stephen-youn commented 9 months ago

When taking A in fp16 and B in fp8, and casting only A to fp8 before the dot product, I got this error:

unimplemented code path UNREACHABLE executed at /home/xxx/triton/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp:369!

snarayan21 commented 4 months ago

With the latest Triton nightly, I'm also running into this issue when casting bf16 inputs to fp8 right before tl.dot. I set AB_DTYPE to tl.float8e4nv before calling the matmul kernel and hit the same error as above. There seems to be another, somewhat similar issue here; is this related? I see that you, @Jokeren, addressed that issue?