ROCm / triton

Development repository for the Triton language and compiler
MIT License

Remove redundant fp32->fp16 conversion in FA #349

Closed: oplavsic closed this 1 year ago

oplavsic commented 1 year ago

Reverts the change introduced by upstream commit https://github.com/openai/triton/commit/5162871c6cae01a8508a309cf21a8e6b68a4c091. That commit converts the result of the first dot from fp32 to fp16, but the result is converted back to fp32 immediately afterwards for the softmax. The fp16 intermediate is therefore redundant, and the truncation loses precision. In the generated Triton IR below, the fp32 dot result %93 is truncated to fp16 (%94), added to an fp16 zero constant (%95), and extended straight back to fp32 (%96):

    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #mfma>
    %94 = arith.truncf %93 : tensor<128x64xf32, #mfma> to tensor<128x64xf16, #mfma>
    %95 = arith.addf %94, %cst_0 : tensor<128x64xf16, #mfma>
    %96 = arith.extf %95 : tensor<128x64xf16, #mfma> to tensor<128x64xf32, #mfma>
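A minimal sketch, in plain Python with only the standard library, of what the truncf / addf / extf sequence does to a single element numerically. It assumes IEEE round-to-nearest for `arith.truncf` and emulates fp32 and fp16 with the `struct` module's `f` and `e` formats. The helper names `to_f32`, `to_f16`, `upstream_path` and `reverted_path` are hypothetical and are not Triton APIs.

```python
import struct


def to_f32(x: float) -> float:
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_f16(x: float) -> float:
    """Round to the nearest fp16 (IEEE binary16) value, emulating truncf.

    Note: struct raises OverflowError for values beyond the fp16 range
    (max finite value 65504), whereas a hardware conversion would
    typically produce inf instead.
    """
    return struct.unpack("<e", struct.pack("<e", x))[0]


def upstream_path(acc_f32: float) -> float:
    """Emulate %94..%96: truncf to fp16, add an fp16 zero, extf to fp32."""
    t = to_f16(acc_f32)   # %94: drops mantissa bits beyond fp16's 10
    s = to_f16(t + 0.0)   # %95: adding zero is exact, so this is a no-op
    return to_f32(s)      # %96: exact, every fp16 value fits in fp32


def reverted_path(acc_f32: float) -> float:
    """Keep the dot result in fp32 for the softmax, as this PR does."""
    return acc_f32


x = to_f32(1.0 / 3.0)  # stand-in for one element of the fp32 dot result
print(x, upstream_path(x), reverted_path(x))
```

The round trip only changes values that are not exactly representable in fp16, and it can never recover the dropped bits. Removing the truncation is therefore strictly more accurate and saves two conversions per element.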

TODO: Investigate whether we can perform softmax in fp16.
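One possible starting point for this investigation, sketched under stated assumptions: run a numerically stable softmax on one row twice, once rounding every intermediate to fp32 and once to fp16, and measure the difference. The logits below are made-up illustrative values rather than data from a real attention kernel, and `softmax` with its `rnd` parameter is a hypothetical helper, not a Triton function.

```python
import math
import struct


def to_f32(x: float) -> float:
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_f16(x: float) -> float:
    """Round a Python float (fp64) to the nearest fp16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def softmax(row, rnd):
    """Max-subtracted softmax where rnd rounds every intermediate result."""
    m = max(rnd(v) for v in row)
    exps = [rnd(math.exp(rnd(rnd(v) - m))) for v in row]
    total = 0.0
    for e in exps:
        total = rnd(total + e)  # sequential accumulation in target precision
    return [rnd(e / total) for e in exps]


row = [0.1 * i for i in range(64)]  # hypothetical logits for one row of QK^T
ref = softmax(row, to_f32)
half = softmax(row, to_f16)
max_err = max(abs(a - b) for a, b in zip(ref, half))
print(f"max abs difference, fp32 vs fp16 softmax: {max_err:.3e}")
```

A real evaluation would also need to cover the fp16 overflow limit (65504) for large unscaled logits and the kernel's actual online-softmax rescaling, which this per-row sketch does not model.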