ROCm / triton

Development repository for the Triton language and compiler
MIT License

use hw for fp8 type conversion #386

Closed: scxiao closed this 11 months ago

scxiao commented 1 year ago

This PR implements type conversion between fp8 and fp16 using hardware instructions on MI300. It includes the following changes:

1) Added support for the MI300 hardware instructions v_cvt_pk_bf8_f32, v_cvt_pk_fp8_f32, v_cvt_pk_f32_fp8, and v_cvt_pk_f32_bf8 for fp8/fp16 type conversion. No single instruction converts directly between fp8 and fp16, so fp8 -> fp16 is done as fp8 --> fp32 --> fp16, and fp16 -> fp8 as fp16 --> fp32 --> fp8.
2) Added computeCapability values for the different hardware generations: 100 for MI100, 200 for MI200, and 300 for MI300.
3) Added selection between the HW instructions and the existing bit-manipulation approach based on computeCapability: 100 and 200 keep the existing approach, while 300 uses the HW instructions.
4) Refined the handling of the different fp8 data types:
   a) Mixed fp8/fp16 inputs: fp8 is converted to fp16 on all CDNA hardware, and fp16 mfma is used for the matmul.
   b) Both inputs are fp8 and both are AMD fp8 types (fp8e4b4, fp8e5b16): fp8 mfma is used on MI300, while on MI100 and MI200 both inputs are converted to fp16 and fp16 mfma is used.
   c) Both inputs are fp8, but at least one is not an AMD fp8 type (it can only be the standard fp8 type fp8e5): both inputs are converted to fp16, and fp16 mfma is used on all CDNA GPUs.
   d) RDNA hardware: all inputs are converted to fp32, and fma is used for the matmul.
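The selection rules in items 3 and 4 can be summarized as a small decision function. This is a hypothetical Python sketch of that logic only, not the PR's compiler code: the function names `fp8_cvt_method` and `dot_plan`, the `arch` strings "cdna"/"rdna", and the instruction labels are illustrative assumptions, while the type names and computeCapability values are taken from the description.

```python
# Hypothetical sketch of the selection rules described in this PR.
# Type names and computeCapability values follow the PR text;
# function names, arch strings, and instruction labels are illustrative.

AMD_FP8 = {"fp8e4b4", "fp8e5b16"}  # AMD fp8 types, as named in the PR
FP8 = AMD_FP8 | {"fp8e5"}          # plus the standard fp8e5 type


def fp8_cvt_method(compute_capability: int) -> str:
    """Item 3: MI300 (300) uses HW cvt instructions; MI100/MI200 keep bit manipulation."""
    if compute_capability == 300:
        return "hw"
    if compute_capability in (100, 200):
        return "bit_manipulation"
    raise ValueError(f"unknown computeCapability {compute_capability}")


def dot_plan(a: str, b: str, arch: str, compute_capability: int = 0):
    """Item 4: return ((operand type of a, operand type of b), instruction)."""
    if arch == "rdna":
        # d) RDNA: every input is converted to fp32 and fma is used.
        return ("fp32", "fp32"), "fma"
    if arch != "cdna":
        raise ValueError(f"unknown arch {arch!r}")
    a_fp8, b_fp8 = a in FP8, b in FP8
    if a_fp8 and b_fp8:
        if a in AMD_FP8 and b in AMD_FP8 and compute_capability == 300:
            # b) both AMD fp8 types on MI300: fp8 mfma on the original types.
            return (a, b), "mfma_fp8"
        # b) AMD fp8 on MI100/MI200, or c) at least one standard fp8e5 input.
        return ("fp16", "fp16"), "mfma_fp16"
    if (a_fp8 or b_fp8) and {a, b} <= FP8 | {"fp16"}:
        # a) mixed fp8/fp16: fp8 is converted to fp16 on all CDNA hardware.
        return ("fp16", "fp16"), "mfma_fp16"
    raise ValueError("input combination not covered by the PR description")


if __name__ == "__main__":
    print(dot_plan("fp8e4b4", "fp8e5b16", "cdna", 300))
    print(dot_plan("fp8e4b4", "fp8e5", "cdna", 300))
    print(fp8_cvt_method(200))
```

Keeping the conversion method (item 3) separate from the matmul plan (item 4) mirrors the description: any fp8 -> fp16 conversion required by cases a) to c) would use the HW instructions on MI300 and the bit-manipulation path on MI100/MI200.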

Note: the test script test_fp8_vecadd.py will be removed later.