microsoft / microxcaling

PyTorch emulation library for Microscaling (MX)-compatible data formats
MIT License

957 is quantized as 896 #4

Closed zhuango closed 8 months ago

zhuango commented 8 months ago

I was testing MXFP8_E4M3 and found that 957 is quantized as 896 by the following API call:

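    import numpy as np
    import torch
    from mx.mx_ops import _quantize_mx  # assumes the microxcaling repo's `mx` package is importable

    # One block of 32 values (the tensor was named `d` before but passed as `x1`)
    x1 = torch.from_numpy(np.array([[ 957.0000,  957.0000,  902.4000,  960.0000,  832.0000,  291.7882,
         -124.8256,  783.5460,  927.3255, -233.1170,  583.4501,   57.7898,
          136.0891,  851.1933, -857.9279, -825.7414, -959.5632,  665.2397,
          556.3135,  740.0243,  957.2367,  598.3171,  -77.0413,  561.0583,
         -763.4512,  279.8420, -713.2934,  889.3378,   43.6966, -170.6761,
         -470.8888,  548.4674]], dtype=np.float32))
    # 8-bit shared scale, FP8 E4M3 elements, blocks of 32 along the last axis
    y1 = _quantize_mx(x1, 8, 'fp8_e4m3',
                      block_size=32,
                      axes=[-1],
                      round='even',
                      flush_fp32_subnorms=False,
                      custom_cuda=False)
    print(y1[0, 0].item())  # this report observes 896.0 here instead of the expected 960.0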

In fp32, 957 has an exponent of 9, a fractional mantissa of 0.869140625, and a hidden leading 1: 2**9 * (1 + 0.869140625) = 957. To convert it to MXFP8_E4M3, whose values have the form v = X * (-1)**S * 2**(E-bias) * (1 + 2**(-m)*M), we would have

    X = 2**9
    S = 0
    E = 7
    bias = 7
    M = 6+0.953125
    m = 3

where

- X: power-of-two shared scale
- S: sign bit
- E: exponent field
- bias: FP8 exponent bias
- m: mantissa bit width of the FP8 element
- M: mantissa field
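The fp32 decomposition and the format equation above can be double-checked in a few lines of Python. This is a minimal sketch: `normalized_parts` and `mx_element_value` are hypothetical helpers written for illustration, not part of microxcaling. `math.frexp` returns a mantissa in [0.5, 1), so it is rescaled to the 1.xxx form used in this report, and M is allowed to stay fractional (i.e. before rounding).

```python
import math

def normalized_parts(v: float) -> tuple[int, float]:
    """Return (exponent, fraction) with v == 2**exponent * (1 + fraction), for v > 0."""
    mant, exp = math.frexp(v)       # v == mant * 2**exp with 0.5 <= mant < 1
    return exp - 1, mant * 2 - 1    # rescale so that 1 <= 1 + fraction < 2

def mx_element_value(X: float, S: int, E: int, M: float, bias: int = 7, m: int = 3) -> float:
    """v = X * (-1)**S * 2**(E - bias) * (1 + 2**(-m) * M); M may still be fractional."""
    return X * (-1) ** S * 2.0 ** (E - bias) * (1 + 2.0 ** (-m) * M)

print(normalized_parts(957.0))                             # (9, 0.869140625)
print(mx_element_value(X=2**9, S=0, E=7, M=6 + 0.953125))  # 957.0
```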

Plugging these values back in recovers the fp32 number: 2**9 * (-1)**0 * 2**(7-7) * (1 + 2**(-3)*(6+0.953125)) = 957. The M part, 6+0.953125, has to be rounded (round half to even) by the MXFP8_E4M3 quantization, so we should get 2**9 * (-1)**0 * 2**(7-7) * (1 + 2**(-3)*7) = 960. But _quantize_mx returns 2**9 * (-1)**0 * 2**(7-7) * (1 + 2**(-3)*6) = 896.
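The rounding this report expects can be sketched as follows, under its assumption that the shared scale stays fixed at X = 2**9. `naive_quantize` is a hypothetical helper; it reproduces the expected 960 for 957 (and 896 for 902.4), which is not what `_quantize_mx` actually does.

```python
def naive_quantize(v: float, X: float = 2.0**9, m: int = 3) -> float:
    """Quantize v the way this report expects, with the shared scale FIXED at X.

    Assumes X <= v < 2 * X, so that E == bias and 2**(E - bias) == 1.
    """
    assert X <= v < 2 * X, "sketch only covers the binade [X, 2X)"
    M = (v / X - 1) * 2**m        # fractional mantissa field, 6.953125 for v = 957
    M_rounded = round(M)          # Python's round() rounds half to even
    return X * (1 + M_rounded / 2**m)

print(naive_quantize(957.0))  # 960.0
print(naive_quantize(902.4))  # 896.0
```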

I also tested 902.4 = 2**9 * (-1)**0 * 2**(7-7) * (1 + 2**(-3)*(6+0.1)), whose M part should simply round down to 6. This time _quantize_mx gives the expected result, 2**9 * (-1)**0 * 2**(7-7) * (1 + 2**(-3)*6) = 896.

It seems that the M part of MXFP8_E4M3 can never be b111. I understand that FP8_E4M3's M cannot be b111 when the exponent is b1111, because S.1111.111 is reserved as the single NaN encoding. But 957 is apparently not such a case: its E is 7 (b0111), not b1111.
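To check which bit patterns are valid, one can decode every FP8 E4M3 byte. The sketch below assumes the OCP FP8 E4M3 encoding described in this report (bias 7, no infinities, S.1111.111 as NaN); `fp8_e4m3_value` is a hypothetical helper, not a microxcaling function. It shows that M = b111 is an ordinary value whenever E is not b1111, and that the largest finite magnitude is 2**8 * 1.75 = 448.

```python
import math

def fp8_e4m3_value(bits: int) -> float:
    """Decode an 8-bit pattern S.EEEE.MMM as FP8 E4M3 (bias 7, no infinities)."""
    sign = -1.0 if (bits >> 7) & 0x1 else 1.0
    e = (bits >> 3) & 0xF
    mnt = bits & 0x7
    if e == 0xF and mnt == 0x7:
        return float("nan")                       # S.1111.111 is the only NaN pattern
    if e == 0:
        return sign * 2.0 ** (1 - 7) * (mnt / 8)  # subnormals
    return sign * 2.0 ** (e - 7) * (1 + mnt / 8)  # normals

finite = [fp8_e4m3_value(b) for b in range(256) if not math.isnan(fp8_e4m3_value(b))]
print(max(finite))                   # 448.0 (0.1111.110)
print(fp8_e4m3_value(0b0_0111_111))  # 1.875 (E = b0111, M = b111 is an ordinary value)
```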

Is there any detail of the MXFP8_E4M3 format that I missed? Please help me out. Thanks.

zhuango commented 8 months ago

I think I found the detail I missed: the shared scale is chosen so that the elements make full use of the element data type's range. The largest magnitude in the block (960, with 957 close behind) lies in [2**9, 2**10), so the shared scale is X = 2**9 / 2**8 = 2 (2**8 is the largest power of two in FP8_E4M3) rather than 2**9. With X = 2, 957 becomes 478.5 = 2**8 * 1.869140625, so E is 15 (b1111). Since FP8_E4M3 uses S.1111.111 to represent NaN, the largest mantissa at E = b1111 is b110, i.e. the largest finite value is 2**8 * 1.75 = 448. Rounding M = 6.953125 up to 7 would overflow, so the element is clamped to 448 instead, giving 2 * 448 = 896.
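The explanation above can be reproduced with a small pure-Python sketch. It assumes the shared scale is X = 2**(floor(log2(max|v|)) - 8), as described in the comment, and that out-of-range elements saturate to +/-448 rather than becoming NaN. `quantize_e4m3` and `mx_quantize_block` are hypothetical names, not the library's code, and the sketch ignores the limits of the 8-bit scale encoding.

```python
import math

E4M3_EMAX = 8                   # exponent of the largest power of two in FP8 E4M3 (2**8)
E4M3_MBITS = 3                  # explicit mantissa bits
E4M3_MAX_NORM = 2.0**8 * 1.75   # 448.0, i.e. S.1111.110

def quantize_e4m3(v: float) -> float:
    """Round v to the nearest FP8 E4M3 value (ties to even), saturating at +/-448."""
    if v == 0.0:
        return 0.0
    exp = max(math.floor(math.log2(abs(v))), 1 - 7)    # below 2**-6, use subnormal spacing
    step = 2.0 ** (exp - E4M3_MBITS)                    # spacing of representable values
    q = round(v / step) * step                          # round half to even
    return max(-E4M3_MAX_NORM, min(E4M3_MAX_NORM, q))   # saturate instead of producing NaN

def mx_quantize_block(block: list[float]) -> list[float]:
    """Quantize one block with a single shared power-of-two scale X."""
    amax = max(abs(x) for x in block)
    if amax == 0.0:
        return [0.0] * len(block)
    shared_exp = math.floor(math.log2(amax)) - E4M3_EMAX  # 9 - 8 = 1 when amax is 957 or 960
    X = 2.0 ** shared_exp
    return [X * quantize_e4m3(x / X) for x in block]

print(mx_quantize_block([957.0, 902.4, 960.0, 832.0]))  # [896.0, 896.0, 896.0, 832.0]
```

With X = 2, 957 / 2 = 478.5 rounds to 480, which exceeds 448 and is clamped, so every element at or above 2 * 448 = 896 collapses to 896.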

Closing this issue.