Open wonjeon opened 3 months ago

Hello. I have the following example to test casting a float number to an fp8_e4m3 number. Why is f8 224, not 240? The distance from f32 to 240 looks shorter than the distance to 224. Am I missing anything?
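Roughly like this (the e4m3mxfp accessors here are my guess at the bitstring MXFP dtype API):

>>> from bitstring import Bits
>>> f32 = 232.03683398099045
>>> f8 = Bits(e4m3mxfp=f32).e4m3mxfp   # assumed accessor for the e4m3 MXFP dtype
>>> f8
224.0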
Well, you started here: 232.03683398099045, which is 1.1101000000… × 2^7 in binary, and you truncate all but three bits in the mantissa (1.110 × 2^7), which leaves you at 224.
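You can see the truncation directly with a small standard-library sketch of my own that pulls out the float32 fraction bits:

import struct

f32 = 232.03683398099045
bits = struct.unpack('>I', struct.pack('>f', f32))[0]   # raw float32 bit pattern
frac = bits & 0x7FFFFF                                  # the 23-bit fraction field
print(f'{frac:023b}')                    # starts 1101000... so the top three bits are 110
print((1 + (frac >> 20) / 8) * 2.0**7)   # keeping just those 3 bits: 1.110 * 2^7 = 224.0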
You could try rounding up the mantissa before conversion.
Thanks for your response. Could you let me know how rounding up can be done before conversion with bitstring?
BTW, I tried ml_dtypes (https://github.com/jax-ml/ml_dtypes) and noticed a different result:
>>> from ml_dtypes import float8_e4m3fn
>>> f32 = 232.03683398099045
>>> f8 = float8_e4m3fn(f32)
>>> f8
240
Any opinions?
I'm guessing that you would have to round manually:
import ctypes

# A union viewing the same 4 bytes as either a float32 or its raw bit pattern
class FloatUnion(ctypes.Union):
    _fields_ = [("float_value", ctypes.c_float),
                ("int_value", ctypes.c_uint)]

def round_to_3_mantissa_bits(value):
    # Create an instance of the union and assign the float value
    float_union = FloatUnion()
    float_union.float_value = value
    # The float32 fraction is 23 bits and we want to keep only the top 3.
    # Adding half of the discarded range (1 << 19) and then clearing the low
    # 20 bits rounds to nearest; if the fraction overflows, the carry
    # propagates into the exponent, which is exactly what rounding should do.
    # (Ties round away from zero here, not IEEE's ties-to-even, but that
    # doesn't matter for this example.)
    float_union.int_value = (float_union.int_value + (1 << 19)) & 0xFFF0_0000
    # Return the modified float value
    return float_union.float_value

# Test the function
original_value = 232.03683398099045
modified_value = round_to_3_mantissa_bits(original_value)
print(f"Original value: {original_value}")   # 232.03683398099045
print(f"Modified value: {modified_value}")   # 240.0
Hi @wonjeon, it's nice to see the MXFP formats getting some use!
I think the issue is that there is an implicit conversion from f32 to f16 before the conversion to e4m3mxfp is done. This is done for efficiency reasons: it allows a look-up table indexed by every possible f16 value, which wouldn't be practical for f32. It does mean that edge cases like this can happen.
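Roughly the idea behind such a table (a sketch of mine using numpy and ml_dtypes, not bitstring's actual implementation):

import numpy as np
from ml_dtypes import float8_e4m3fn

# One e4m3 byte for each of the 65536 possible f16 bit patterns.
all_f16 = np.arange(2**16, dtype=np.uint16).view(np.float16)
lut = all_f16.astype(float8_e4m3fn).view(np.uint8)

def f32_to_e4m3_byte(x):
    # Double rounding: f32 -> f16 first, then a single table lookup.
    return lut[np.float16(x).view(np.uint16)]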
So effectively we have this rounding happening first:
>>> Bits(float16=232.04).float
232.0
and then the 232.0 gets correctly rounded to 224.0 (it's exactly halfway between 224.0 and 240.0, and the tie breaks to even, which is 224.0). A slightly higher starting value rounds to something bigger in f16 and then becomes 240.0:
>>> Bits(float16=232.07).float
232.125
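The same double rounding can be reproduced with numpy and ml_dtypes (a sketch of mine, not part of bitstring):

>>> import numpy as np
>>> from ml_dtypes import float8_e4m3fn
>>> float8_e4m3fn(np.float16(232.03683398099045))  # via f16: 232.0 is an exact tie, rounds to even
224
>>> float8_e4m3fn(232.03683398099045)              # direct from f32: 240 is nearer
240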
I wasn't sure if this would be a problem, but I made a note of it in the docs: https://bitstring.readthedocs.io/en/stable/exotic_floats.html#conversion
"Note that for efficiency reasons Python floats are converted to 16-bit IEEE floats before being converted to their final destination. This can mean that in edge cases the rounding to the 16-bit float will cause the next rounding to go in the other direction. The 16-bit float has 11 bits of precision, whereas the final format has at most 4 bits of precision, so this shouldn’t be a real-world problem, but it could cause discrepancies when comparing with other methods. I could add a slower, more accurate mode if this is a problem (add a bug report)."
Would a slower, more accurate mode be useful to you?
Thanks for following up on this issue. I think it would be great if bitstring provided an option for users to select either the fast or the accurate mode for this conversion.