Closed: umadevimcw closed this 2 months ago
@muthutt Not sure whom to assign this issue to. Please update the assignees to the respective team.
Great job on pinpointing this issue! @umadevimcw , the excellent debugging made it obvious where to look.
The fix has been merged onto main. (I missed adding your test! ugh).
@eyonland I think it's not fixed yet. I re-ran the test with the latest main, and the values still do not match the torch output. Please find the values in the red box (in the image below).
Terminal output
py::test_arange[300] Metal | INFO | Initializing device 0
Metal | INFO | AI CLK for device 0 is: 1202 MHz
==============================
********** torch ****
tensor([250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263,
264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277,
278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291,
292, 293, 294, 295, 296, 297, 298, 299])
********** TT ****
tensor([250., 251., 252., 253., 254., 255., 256., 256., 258., 258., 260., 260.,
262., 262., 264., 264., 266., 266., 268., 268., 270., 270., 272., 272.,
274., 274., 276., 276., 278., 278., 280., 280., 282., 282., 284., 284.,
286., 286., 288., 288., 290., 290., 292., 292., 294., 294., 296., 296.,
298., 298.], dtype=torch.bfloat16)
Having looked into this, I don't believe this to be a mathematically solvable issue.
The problem is that std::bit_cast<uint32_t>(289.0f) == 0x4390'8000. The most significant 16 bits of this are 0x4390 if you round down, or 0x4391 if you round up, either of which could be reasonably stored in the bfloat16 format depending on how you choose to round when truncating. These represent 288.0f and 290.0f respectively. bfloat16 does not have a representation for the value 289.0f, no matter how you adjust the heuristics for rounding when truncating or extending. Shifting by 15 bits side-steps the issue by not using the bfloat16 format at all, but rather by dropping the sign bit from the representation and storing the most significant bit of the mantissa that is not normally preserved by the bfloat16 format, which is how it's able to recover the value 289.0f.
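As a quick sanity check of the bit arithmetic above (a standalone Python sketch that reinterprets float bits via struct, not the actual bfloat16.hpp code), truncating to the top 16 bits loses 289.0, while the 15-bit shift recovers it by trading the sign bit for one extra mantissa bit:

```python
import struct

def f32_to_bits(x: float) -> int:
    """Reinterpret a float32 as its raw 32-bit pattern."""
    return struct.unpack(">I", struct.pack(">f", x))[0]

def bits_to_f32(b: int) -> float:
    """Reinterpret a 32-bit pattern as a float32."""
    return struct.unpack(">f", struct.pack(">I", b))[0]

bits = f32_to_bits(289.0)
assert bits == 0x43908000

# bfloat16-style truncation: keep only the top 16 bits
truncated = bits_to_f32((bits >> 16) << 16)   # the 0x8000 mantissa bit is lost
assert truncated == 288.0

# shift by 15: drop the sign bit, keep one extra mantissa bit
stored = (bits >> 15) & 0xFFFF
assert bits_to_f32(stored << 15) == 289.0
```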
Another way to say this is that 255.0f is the "max safe integer" in bfloat16 format. Notice how the representation in float32 prevents there from being a valid representation for odd numbers after 256.0f in bfloat16 format:
static_assert(std::bit_cast<uint32_t>(253.f) == 0x437d'0000);
static_assert(std::bit_cast<uint32_t>(254.f) == 0x437e'0000);
static_assert(std::bit_cast<uint32_t>(255.f) == 0x437f'0000);
static_assert(std::bit_cast<uint32_t>(256.f) == 0x4380'0000);
static_assert(std::bit_cast<uint32_t>(257.f) == 0x4380'8000); // no valid bfloat16
static_assert(std::bit_cast<uint32_t>(258.f) == 0x4381'0000);
static_assert(std::bit_cast<uint32_t>(259.f) == 0x4381'8000); // no valid bfloat16
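The same pattern can be checked in Python (a small sketch emulating bfloat16 by zeroing the low 16 bits of the float32 pattern, not tied to the TT code): under straight truncation, every odd integer above 256 collapses to the even integer below it:

```python
import struct

def bf16_truncate(x: float) -> float:
    """Emulate bfloat16 by zeroing the low 16 bits of the float32 pattern."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]

assert bf16_truncate(255.0) == 255.0   # exact: integers up to 256 are safe
assert bf16_truncate(256.0) == 256.0
assert bf16_truncate(257.0) == 256.0   # no bfloat16 representation for 257
assert bf16_truncate(258.0) == 258.0
assert bf16_truncate(259.0) == 258.0   # truncation rounds 259 down to 258
```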
Just to add some more context here:
import torch
import pytest
import ttnn

@pytest.mark.parametrize("end", (256, 300, 390, 1024))
def test_arange(end, device):
    torch.manual_seed(0)
    torch_data = torch.arange(250, end, 1, dtype=torch.bfloat16)
    tt_output_tensor_on_device = ttnn.arange(250, end, 1, device)
    tt_output = tt_output_tensor_on_device.cpu().to(ttnn.Layout.ROW_MAJOR).to_torch()
    tt_output = tt_output.flatten()
    print("==============================")
    torch.set_printoptions(sci_mode=False, threshold=10000)
    print("********** torch ****")
    print(torch_data)
    print("********** TT ****")
    print(tt_output)
    print(torch.eq(torch_data, tt_output))
gives the output:
********** torch ****
tensor([250., 251., 252., 253., 254., 255., 256., 256., 258., 260., 260., 260.,
262., 264., 264., 264., 266., 268., 268., 268., 270., 272., 272., 272.,
274., 276., 276., 276., 278., 280., 280., 280., 282., 284., 284., 284.,
286., 288., 288., 288., 290., 292., 292., 292., 294., 296., 296., 296.,
298., 300.], dtype=torch.bfloat16)
********** TT ****
tensor([250., 251., 252., 253., 254., 255., 256., 256., 258., 258., 260., 260.,
262., 262., 264., 264., 266., 266., 268., 268., 270., 270., 272., 272.,
274., 274., 276., 276., 278., 278., 280., 280., 282., 282., 284., 284.,
286., 286., 288., 288., 290., 290., 292., 292., 294., 294., 296., 296.,
298., 298.], dtype=torch.bfloat16)
While the rounding in TT does not use the same heuristics as torch, it still exhibits the same loss of information due to the bfloat16 format. TT does straight truncation, whereas torch appears to round up in half of the cases after 256.
Whether we want to try to exactly match the rounding heuristics is a different question. Torch uses a convention called "round to nearest even".
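A sketch of that convention (assuming the standard round-half-to-even tie-break on the dropped 16 bits, which appears consistent with the torch output above; this is illustrative, not torch's actual implementation):

```python
import struct

def bf16_round_nearest_even(x: float) -> float:
    """float32 -> bfloat16 with round-to-nearest-even, decoded back to float32."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    lsb = (bits >> 16) & 1                      # LSB of the 16 bits we keep
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000   # bias breaks ties toward even
    return struct.unpack(">f", struct.pack(">I", bits))[0]

assert bf16_round_nearest_even(257.0) == 256.0  # tie: 256 has the even mantissa
assert bf16_round_nearest_even(259.0) == 260.0  # tie: 260 has the even mantissa
assert bf16_round_nearest_even(258.0) == 258.0  # exactly representable
```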
Describe the bug
static_cast returns a wrong value after a certain range. This can be observed well in the arange function: when arange is used to generate numbers from 0 to 1024, wrong values are observed in the output.

To Reproduce
Steps to reproduce the behavior:
umadevimcw/bfloat16_conversion_issue
pytest tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_arange_not_working.py

Expected behavior
This code compares the results of torch arange and tt arange and prints a boolean tensor; False represents a mismatched value.

Screenshots
Please find the highlighted lines in the first image, and the second image generated with a start value of 250.

Additional context
functions.hpp used for arange causes the issue, specifically the static cast: owned_buffer[index++] = static_cast<T>(static_cast<float>(value));
In bfloat16.hpp there is a conversion module that converts the given value to bfloat16. Replicating the same code separately showed that the conversion creates the wrong value. Please find the same in the attached images.
Shift with 16 - wrong value
Shift with 15 - correct value
Note: with the shift-15 change, the value stored in the owned buffer returns 0; not sure how to update this globally.
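To illustrate the shift-16 vs shift-15 observation above (a standalone Python sketch, not the bfloat16.hpp code itself): over the failing range 250-299, a 16-bit shift corrupts every odd value above 256, while a 15-bit shift round-trips them all, at the cost of dropping the sign bit:

```python
import struct

def f32_bits(x: float) -> int:
    return struct.unpack(">I", struct.pack(">f", x))[0]

def from_bits(b: int) -> float:
    return struct.unpack(">f", struct.pack(">I", b))[0]

vals = [float(v) for v in range(250, 300)]

# shift by 16 (bfloat16 truncation): odd integers above 256 are corrupted
wrong_16 = [v for v in vals if from_bits((f32_bits(v) >> 16) << 16) != v]
assert wrong_16 == [float(v) for v in range(257, 300, 2)]

# shift by 15 (drops the sign bit, keeps one extra mantissa bit): all round-trip
wrong_15 = [v for v in vals if from_bits(((f32_bits(v) >> 15) & 0xFFFF) << 15) != v]
assert wrong_15 == []
```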