microsoft / microxcaling

PyTorch emulation library for Microscaling (MX)-compatible data formats
MIT License

How is the matmul for MX format implemented? #23

Open xijiu9 opened 1 month ago

xijiu9 commented 1 month ago

Thanks for this great project! I have some questions about how you implemented matmul for two MX-format matrices.

This repo appears to provide a simulation, but does not provide an actual CUDA implementation. My current question is: how do you implement MX-format matmul on CUDA?

For instance, as stated in Section 6 of the OCP Microscaling Formats (MX) Specification, you need to perform a loop over blocks: (1) A(FP8) × B(FP8) = C(FP32), (2) C = C × scale_of_A × scale_of_B (to dequantize the MX format), and (3) add C to the accumulator. If you do it this way, can your CUDA kernel be faster than FP16? Or does current CUDA (such as PTX instructions) support fusing the scaling/dequantization step with the MMA operation?
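For reference, here is a minimal sketch of the block-wise reference loop the spec describes, not this repo's code: it assumes the FP8 elements have already been decoded to float tensors and the per-block shared scales are stored separately, and all names are hypothetical.

```python
import torch

def mx_dot_product(a_elems, a_scales, b_elems, b_scales, block_size=32):
    """Reference loop for one MX dot product (illustrative only, FP32 math).

    a_elems, b_elems: (K,) element values already decoded from FP8 to float
    a_scales, b_scales: (K // block_size,) shared per-block scales
    """
    acc = torch.zeros((), dtype=torch.float32)
    for blk in range(a_scales.numel()):
        lo, hi = blk * block_size, (blk + 1) * block_size
        # (1) multiply the FP8 elements of one block, producing an FP32 partial sum
        partial = (a_elems[lo:hi].float() * b_elems[lo:hi].float()).sum()
        # (2) apply both blocks' shared scales to dequantize the partial result
        partial = partial * a_scales[blk] * b_scales[blk]
        # (3) add to the FP32 accumulator
        acc = acc + partial
    return acc
```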

rensushan commented 1 day ago

In my opinion, this project uses "fake quantization": the quantized values are dequantized back to FP16/FP32 precision, so the standard CUDA libraries can be used to perform the GEMM/CONV operations. In fact, this library is meant to verify the numerical results of the MX formats, not to perform a clock-cycle simulation of MX arithmetic.
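A minimal sketch of what such fake quantization could look like in PyTorch (assuming a recent build with float8 dtypes; the scale choice, clamp value, and helper name are illustrative, not this library's actual API):

```python
import torch

def fake_quantize_mx(x, block_size=32):
    """Hypothetical MX-style fake quantization along the last dimension:
    pick a shared power-of-two scale per block, round-trip the scaled
    elements through an FP8 grid, and return FP32 values again."""
    orig_shape = x.shape
    xb = x.reshape(-1, block_size)
    # crude shared-exponent choice per block, derived from the block's max magnitude
    amax = xb.abs().amax(dim=-1, keepdim=True).clamp(min=1e-30)
    scale = torch.exp2(torch.floor(torch.log2(amax)) - 8)
    # round-trip through FP8 E4M3 (clamped to its finite range), then rescale
    q = (xb / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn).to(torch.float32)
    return (q * scale).reshape(orig_shape)

# "Fake quantization": the matmul itself still runs in FP32/FP16 via cuBLAS;
# only the inputs have been snapped to values representable in the MX format.
A = torch.randn(64, 128)
B = torch.randn(128, 64)
C = fake_quantize_mx(A) @ fake_quantize_mx(B)
```

The point is that no MX-specific kernel is involved: the quantize/dequantize step only changes the operand values, and the actual GEMM runs in ordinary FP16/FP32.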