scott-griffiths / bitstring

A Python module to help you manage your bits
https://bitstring.readthedocs.io/en/stable/index.html
MIT License

Issue of casting f32 to e4m3mxfp #342

Open wonjeon opened 3 months ago

wonjeon commented 3 months ago

Hello. I have the following example that tests casting a float to an fp8 e4m3 number:

>>> from bitstring import Bits
>>> f32 = 232.03683398099045
>>> f8 = Bits(e4m3mxfp=f32).e4m3mxfp
>>> f8
224.0
>>> abs(f32-240)
7.963166019009549
>>> abs(f32-224)
8.036833980990451

Why is f8 224 and not 240? The distance from f32 to 240 looks shorter than the distance from f32 to 224. Am I missing anything?

taylorh140 commented 3 months ago

Well, you started here: [image]

and you truncate all but three bits of the mantissa: [image]

which leaves you at 224: [image]

You could try rounding the mantissa up before conversion.
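The screenshots are not preserved, but the same bits can be inspected with a minimal sketch. It uses only the standard library `struct` module (not bitstring) to view the value as an IEEE float32, then compares truncating the 23-bit fraction to its top 3 bits with rounding it to nearest (ties to even). The helper names are hypothetical, and the sketch assumes a positive normal value.

```python
import struct

def float32_fields(x):
    """Return (sign, biased exponent, 23-bit fraction) of x stored as an IEEE float32."""
    (bits,) = struct.unpack('>I', struct.pack('>f', x))
    return bits >> 31, (bits >> 23) & 0xFF, bits & 0x7FFFFF

def keep_3_fraction_bits(x, round_to_nearest):
    """Reduce a positive normal float32 to 3 fraction bits by truncating or rounding."""
    _, exponent, fraction = float32_fields(x)
    kept = fraction >> 20                  # top 3 of the 23 fraction bits
    if round_to_nearest:
        dropped = fraction & 0xFFFFF       # the 20 discarded bits
        half = 1 << 19
        if dropped > half or (dropped == half and kept & 1):
            kept += 1                      # may reach 8, i.e. a carry into the exponent
    return (1 + kept / 8) * 2.0 ** (exponent - 127)

x = 232.03683398099045
_, exponent, fraction = float32_fields(x)
print(f"exponent = {exponent - 127}, fraction bits = {fraction:023b}")
print("truncated:", keep_3_fraction_bits(x, round_to_nearest=False))  # 224.0
print("rounded:  ", keep_3_fraction_bits(x, round_to_nearest=True))   # 240.0
```

The fourth fraction bit of 232.0368... is set and the bits after it are not all zero, so round-to-nearest moves up to 240 while truncation stays at 224.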

wonjeon commented 3 months ago

Thanks for your response. Could you let me know how to do that rounding before conversion with bitstring?

BTW, I tried ml_dtypes (https://github.com/jax-ml/ml_dtypes) and noticed a different result:

>>> from ml_dtypes import float8_e4m3fn
>>> f8 = float8_e4m3fn(f32)
>>> f8
240

Any opinions?

taylorh140 commented 3 months ago

I'm guessing that you would have to round manually:

import ctypes

# Union that lets the same 32 bits be read as a float or as an unsigned int
class FloatUnion(ctypes.Union):
    _fields_ = [("float_value", ctypes.c_float),
                ("int_value", ctypes.c_uint32)]

def round_mantissa_to_3_bits(value):
    # Store the value as a 32-bit float so its raw bits can be manipulated
    float_union = FloatUnion()
    float_union.float_value = value

    # A float32 has 23 fraction bits; keeping the top 3 means dropping the low 20.
    # Adding half of the dropped range (bit 19) rounds to nearest (ties away from zero),
    # and a carry out of the fraction correctly increments the exponent.
    # Masking with 0xFFF00000 keeps sign, exponent and the top 3 fraction bits.
    float_union.int_value = (float_union.int_value + (1 << 19)) & 0xFFF00000

    # Return the modified float value
    return float_union.float_value

# Test the function
original_value = 232.03683398099045
modified_value = round_mantissa_to_3_bits(original_value)

print(f"Original value: {original_value}")
print(f"Modified value: {modified_value}")  # 240.0

scott-griffiths commented 3 months ago

Hi @wonjeon, it's nice to see the MXFP formats getting some use!

I think the issue is that there is an implicit conversion of the Python float to float16 before the conversion to e4m3mxfp is done. This is done for efficiency reasons: it allows a look-up table covering all possible float16 values to be used, which wouldn't be practical for float32. It does mean that edge cases like this can happen.
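The look-up-table idea can be sketched in plain Python. This is not bitstring's actual implementation: it is a minimal sketch assuming round-half-to-even onto a 4-bit significand (1 implicit + 3 stored bits), and it ignores the e4m3 exponent range, subnormals and saturation. The names `round_4bit_significand`, `TABLE`, `fast_convert` and `accurate_convert` are hypothetical. The `struct` 'e' format rounds to float16, so all 2**16 float16 bit patterns can be precomputed once, whereas a float32 table would need 2**32 entries.

```python
import math
import struct

def round_4bit_significand(x):
    """Round x to a 1.xxx binary significand, ties to even (no exponent limits)."""
    if not math.isfinite(x) or x == 0.0:
        return x
    m, e = math.frexp(x)                   # x == m * 2**e with 0.5 <= |m| < 1
    return round(m * 16) * 2.0 ** (e - 4)  # round() on a float is ties-to-even

def half_bits(x):
    """Bit pattern of x rounded to IEEE float16 (raises OverflowError if too large)."""
    return int.from_bytes(struct.pack('<e', x), 'little')

def half_from_bits(bits):
    """Decode a 16-bit pattern as an IEEE float16."""
    return struct.unpack('<e', bits.to_bytes(2, 'little'))[0]

# One entry per float16 bit pattern: 65536 entries, cheap to precompute
TABLE = [round_4bit_significand(half_from_bits(b)) for b in range(1 << 16)]

def fast_convert(x):
    # Two roundings: Python float -> float16, then float16 -> 4-bit significand
    return TABLE[half_bits(x)]

def accurate_convert(x):
    # A single rounding straight from the Python float
    return round_4bit_significand(x)

x = 232.03683398099045
print(fast_convert(x), accurate_convert(x))  # 224.0 240.0
```

The fast path pays for its table with a second rounding step, which is exactly where the 232.0368... case goes the other way.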

So effectively we have this rounding happening first:

>>> Bits(float16=232.04).float
232.0

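and then 232.0, which lies exactly halfway between 224.0 and 240.0, gets correctly rounded to 224.0 (round-half-to-even picks 224.0, whose last mantissa bit is 0). A slightly higher value rounds to a larger float16 and then becomes 240.0: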

>>> Bits(float16=232.07).float
232.125

I wasn't sure if this would be a problem, but I made a note of it in the docs: https://bitstring.readthedocs.io/en/stable/exotic_floats.html#conversion

"Note that for efficiency reasons Python floats are converted to 16-bit IEEE floats before being converted to their final destination. This can mean that in edge cases the rounding to the 16-bit float will cause the next rounding to go in the other direction. The 16-bit float has 11 bits of precision, whereas the final format has at most 4 bits of precision, so this shouldn’t be a real-world problem, but it could cause discrepancies when comparing with other methods. I could add a slower, more accurate mode if this is a problem (add a bug report)."
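To get a feel for how rare these edge cases are, a small scan can compare one-step rounding against rounding via float16. This is again a sketch under the same simplifying assumptions (round-half-to-even onto a 4-bit significand, no exponent-range or saturation handling) with hypothetical helper names, checking the range [224, 256) on a 1/1024 grid.

```python
import math
import struct

def to_float16(x):
    """Round a Python float to the nearest IEEE float16 (ties to even) via struct."""
    return struct.unpack('<e', struct.pack('<e', x))[0]

def round_4bit_significand(x):
    """Round a finite nonzero x to 4 significant bits, ties to even."""
    m, e = math.frexp(x)
    return round(m * 16) * 2.0 ** (e - 4)

grid = [224 + k / 1024 for k in range(32 * 1024)]  # 32768 points in [224, 256)
mismatches = [x for x in grid
              if round_4bit_significand(to_float16(x)) != round_4bit_significand(x)]

print(len(mismatches), "of", len(grid))  # 128 of 32768
print(min(mismatches), max(mismatches))
```

In this range the disagreements sit within half a float16 step (0.0625) of the 4-bit ties at 232 and 248, about 0.4% of the grid points.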

Would a slower, more accurate mode be useful to you?

wonjeon commented 3 months ago

Thanks for following up on this issue. I think it would be great if bitstring provided an option to choose between a fast and an accurate mode for this case.