microsoft / microxcaling

PyTorch emulation library for Microscaling (MX)-compatible data formats
MIT License

How is the matmul for MX format implemented? #23

Closed xijiu9 closed 3 months ago

xijiu9 commented 6 months ago

Thanks for this great project! I have some questions about how you implemented matmul for two MX-format matrices.

This repo appears to provide a simulation, but it does not provide an actual CUDA implementation. My current question is: how do you implement MX-format matmul in CUDA?

For instance, as stated in Section 6 of the OCP Microscaling Formats (MX) Specification, you need to loop over blocks: (1) A(FP8) × B(FP8) = C(FP32); (2) C = C × scale_of_A × scale_of_B (to dequantize the MX format); (3) add C to the accumulator. If you do it this way, can your CUDA kernel be faster than FP16? Or does current CUDA (such as PTX instructions) support fusing the scaling/dequantization step with the MMA operation?
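For reference, a minimal (unoptimized) sketch of the three steps described above, written in PyTorch: it loops over shared-exponent blocks along the K dimension, does the FP32 partial matmul of the decoded FP8 element values, applies the two shared block scales, and adds into an FP32 accumulator. All tensor names, shapes, and the `block_size` default are illustrative assumptions, not this library's API.

```python
import torch

def mx_block_matmul_reference(A_elems, A_scales, B_elems, B_scales, block_size=32):
    """Illustrative MX matmul per the spec's block-wise dot product.

    A_elems : (M, K) FP8 element values, already decoded to a float tensor.
    A_scales: (M, K // block_size) shared per-block scales (e.g. decoded E8M0).
    B_elems : (K, N) and B_scales: (K // block_size, N), analogously.
    """
    M, K = A_elems.shape
    _, N = B_elems.shape
    C = torch.zeros(M, N, dtype=torch.float32)
    for kb in range(K // block_size):
        ks = slice(kb * block_size, (kb + 1) * block_size)
        # (1) partial product of one block pair, accumulated in FP32
        partial = A_elems[:, ks].float() @ B_elems[ks, :].float()
        # (2) apply the two shared block scales to dequantize
        partial = partial * A_scales[:, kb:kb + 1] * B_scales[kb:kb + 1, :]
        # (3) add into the FP32 accumulator
        C += partial
    return C
```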

rensushan commented 5 months ago

In my opinion, this project uses "fake quantization": the quantized values are dequantized back to FP16/FP32 precision, so it can use standard CUDA libraries to perform the GEMM/CONV operations. In fact, this library is meant to verify the numerical results of MX formats, not to perform a clock-cycle simulation of MX arithmetic.
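In other words, something along these lines, where `quantize_dequantize_mx` stands for whatever quantize-then-dequantize routine the library provides (a hypothetical name here, shown only to illustrate the fake-quantization pattern):

```python
import torch

def fake_quant_matmul(A, B, quantize_dequantize_mx):
    # Snap values to representable MX points, but keep them stored in FP16/FP32.
    A_q = quantize_dequantize_mx(A)
    B_q = quantize_dequantize_mx(B)
    # Plain cuBLAS GEMM on the dequantized tensors; only the numerics reflect MX,
    # so there is no speedup over an ordinary FP16/FP32 matmul.
    return A_q @ B_q
```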