Closed xijiu9 closed 3 months ago
In my opinion, this project uses "fake quantization": the quantized values are dequantized back to FP16/FP32 precision, so it can rely on standard CUDA libraries to perform the GEMM/CONV operations. In other words, this library verifies the numerical behavior of the MX formats; it does not perform a clock-cycle simulation of MX arithmetic.
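For illustration, here is a minimal sketch of what "fake quantization" for one MX-style block could look like, assuming PyTorch; the block size, the rounding grid, and the scale derivation are assumptions for clarity, not the repo's actual implementation:

```python
import torch

def fake_quantize_mx(x: torch.Tensor, block_size: int = 32) -> torch.Tensor:
    """Quantize-dequantize x along its last dimension in blocks of `block_size`."""
    orig_shape = x.shape
    x = x.reshape(-1, block_size).float()

    # Shared power-of-two scale per block, derived from the block max
    # (the exact exponent offset is an illustrative choice).
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-30)
    scale = torch.exp2(torch.floor(torch.log2(amax)) - 2)

    # Round elements on a coarse grid as a stand-in for low-precision
    # (e.g. FP8) rounding, then dequantize straight back to FP32.
    q = torch.clamp(torch.round(x / scale), -448, 448)
    return (q * scale).reshape(orig_shape)

# After fake quantization, the matmul itself is an ordinary FP32/FP16 GEMM,
# so it is dispatched to cuBLAS and no custom MX kernel is needed.
device = "cuda" if torch.cuda.is_available() else "cpu"
A = fake_quantize_mx(torch.randn(128, 256, device=device))
B = fake_quantize_mx(torch.randn(256, 64, device=device).T).T  # quantize along K
C = A @ B
```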
Thanks for this great project! I have some questions about how you implemented matmul for two MX-format matrices.
This repo appears to provide a simulation, but not an actual CUDA implementation. My main question is: how do you implement an MX-format matmul on CUDA?
For instance, as stated in Section 6 of the OCP Microscaling Formats (MX) Specification, you need a loop that, for each block, (1) computes A(FP8) × B(FP8) = C(FP32), (2) scales C by scale_of_A * scale_of_B (to dequantize the MX format), and (3) adds C to the accumulator. If you do it this way, can your CUDA kernel be faster than FP16? Or does current CUDA (e.g., PTX instructions) support fusing the scaling/dequantization step with the MMA operation?
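For reference, a minimal sketch of those three steps, assuming PyTorch, a block size of 32 along K, and element tensors stored in FP32 containers (the tensor layout and names are assumptions for illustration, not the spec's or the repo's exact interface):

```python
import torch

BLOCK = 32

def mx_matmul_reference(A_elems, A_scales, B_elems, B_scales):
    """
    A_elems:  (M, K) low-precision element values held in FP32 containers
    A_scales: (M, K // BLOCK) shared scales, one per block of K
    B_elems:  (K, N); B_scales: (K // BLOCK, N)
    """
    M, K = A_elems.shape
    N = B_elems.shape[1]
    C = torch.zeros(M, N, dtype=torch.float32)

    for kb in range(K // BLOCK):
        ks = slice(kb * BLOCK, (kb + 1) * BLOCK)
        # (1) low-precision GEMM for this K-block, accumulated in FP32
        partial = A_elems[:, ks].float() @ B_elems[ks, :].float()
        # (2) apply the two per-block scales to dequantize the partial product
        partial = partial * A_scales[:, kb:kb + 1] * B_scales[kb:kb + 1, :]
        # (3) add into the FP32 accumulator
        C += partial
    return C
```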