tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.
Apache License 2.0

Static cast <bfloat16> returns wrong value after certain range #4787

Closed umadevimcw closed 2 months ago

umadevimcw commented 8 months ago

Describe the bug

static_cast<bfloat16> returns the wrong value beyond a certain range. This is easy to observe with the arange function: when arange is used to generate the numbers from 0 to 1024, some of the output values are wrong.

To Reproduce

Steps to reproduce the behavior:

  1. Go to branch umadevimcw/bfloat16_conversion_issue
  2. Run pytest tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_arange_not_working.py

Expected behavior

The test compares the results of torch arange and tt arange and prints a boolean tensor in which False marks a mismatched value.

Screenshots

Please see the highlighted lines in the first image, and in the second image, which was generated with a start value of 250.

Screenshot 2024-01-18 at 7 56 02 PM Screenshot 2024-01-18 at 8 01 37 PM

Additional context

In bfloat16.hpp there is a conversion module that converts a given value to bfloat16. I replicated the same code separately and observed that the conversion produces the wrong value, as shown in the attached images.

Shift with 16 - Wrong value

Screenshot 2024-01-18 at 8 25 22 PM

Shift with 15 - Correct value

Screenshot 2024-01-18 at 8 25 52 PM

Note: with the shift changed to 15, the value stored in the owned buffer comes back as 0, and I am not sure how to apply the change globally.
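
As a minimal sketch of the conversion described above (not the actual code in bfloat16.hpp; the helper names are hypothetical), truncating a float32 to bfloat16 amounts to keeping the upper 16 bits of its bit pattern. Emulated in Python with struct, that shift by 16 is already enough to reproduce the duplicated values above 256:

```python
import struct

def float_to_bits(x: float) -> int:
    """Return the IEEE 754 float32 bit pattern of x as an unsigned 32-bit int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def bits_to_float(bits: int) -> float:
    """Interpret an unsigned 32-bit int as an IEEE 754 float32."""
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def to_bfloat16_truncate(x: float) -> int:
    # Hypothetical helper: keep the upper 16 bits (shift by 16), drop the rest.
    return float_to_bits(x) >> 16

def from_bfloat16(b: int) -> float:
    # Widen back to float32 by padding the low 16 bits with zeros.
    return bits_to_float((b & 0xFFFF) << 16)

truncated = [from_bfloat16(to_bfloat16_truncate(float(n))) for n in range(254, 262)]
print(truncated)  # [254.0, 255.0, 256.0, 256.0, 258.0, 258.0, 260.0, 260.0]
```

Every odd integer above 256 collapses onto the even value just below it, which matches the TT output pattern.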

umadevimcw commented 8 months ago

@muthutt Not sure whom to assign this issue. Please update the assignees to the respective team

eyonland commented 7 months ago

Great job on pinpointing this issue, @umadevimcw! The excellent debugging made it obvious where to look.

The fix has been merged onto main. (I missed adding your test! ugh).

umadevimcw commented 7 months ago

@eyonland I think it's not fixed yet. I re-ran the test on recent main and the values still do not match the torch values. Please see the values in the red box in the image below.

image

Terminal output

py::test_arange[300]                   Metal | INFO     | Initializing device 0
                  Metal | INFO     | AI CLK for device 0 is:   1202 MHz
==============================
********** torch  ****
tensor([250, 251, 252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263,
        264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, 275, 276, 277,
        278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291,
        292, 293, 294, 295, 296, 297, 298, 299])
********** TT  ****
tensor([250., 251., 252., 253., 254., 255., 256., 256., 258., 258., 260., 260.,
        262., 262., 264., 264., 266., 266., 268., 268., 270., 270., 272., 272.,
        274., 274., 276., 276., 278., 278., 280., 280., 282., 282., 284., 284.,
        286., 286., 288., 288., 290., 290., 292., 292., 294., 294., 296., 296.,
        298., 298.], dtype=torch.bfloat16)

patrickroberts commented 2 months ago

Having looked into this, I don't believe this to be a mathematically solvable issue.

The problem is that std::bit_cast<uint32_t>(289.0f) == 0x4390'8000. The most significant 16 bits of this are 0x4390 if you round down, or 0x4391 if you round up, and either could reasonably be stored in the bfloat16 format depending on how you choose to round when truncating. These represent 288.0f and 290.0f respectively. bfloat16 has no representation for the value 289.0f, no matter how you adjust the rounding heuristics when truncating or extending.

Shifting by 15 bits side-steps the issue by not using the bfloat16 format at all: it drops the sign bit from the representation and instead stores the most significant mantissa bit that bfloat16 normally discards, which is how it is able to recover the value 289.0f.

[Images: IEEE 754 single-precision 32-bit float; bfloat16]
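
To make the 289.0f case concrete, here is a small Python sketch of the bit patterns involved. It only emulates the arithmetic with struct and does not call into tt-metal; the helper names are hypothetical.

```python
import struct

def float_to_bits(x: float) -> int:
    """Return the IEEE 754 float32 bit pattern of x as an unsigned 32-bit int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def bits_to_float(bits: int) -> float:
    """Interpret an unsigned 32-bit int as an IEEE 754 float32."""
    return struct.unpack("<f", struct.pack("<I", bits))[0]

bits = float_to_bits(289.0)
print(hex(bits))  # 0x43908000

# The two bfloat16 candidates: the upper 16 bits rounded down or rounded up.
low = bits >> 16
high = low + 1
neighbours = (bits_to_float(low << 16), bits_to_float(high << 16))
print(neighbours)  # (288.0, 290.0)

# Shifting by 15 and keeping 16 bits drops the sign bit and keeps one
# extra mantissa bit, so the stored pattern is no longer a bfloat16.
shifted15 = (bits >> 15) & 0xFFFF
print(hex(shifted15))  # 0x8721
recovered = bits_to_float(shifted15 << 15)   # undo the 15-bit shift
misread = bits_to_float(shifted15 << 16)     # read it as if it were bfloat16
print(recovered)  # 289.0
print(misread)    # a tiny negative number, nowhere near 289
```

The shift-15 pattern only round-trips if every consumer also undoes a 15-bit shift. Anything that reads the buffer as ordinary bfloat16 sees a value of tiny magnitude instead.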

patrickroberts commented 2 months ago

Another way to say this is that 255.0f is the "max safe integer" in bfloat16 format. Notice how the float32 bit patterns leave no valid bfloat16 representation for odd numbers above 256.0f:

#include <bit>      // std::bit_cast (C++20)
#include <cstdint>  // uint32_t

static_assert(std::bit_cast<uint32_t>(253.f) == 0x437d'0000);
static_assert(std::bit_cast<uint32_t>(254.f) == 0x437e'0000);
static_assert(std::bit_cast<uint32_t>(255.f) == 0x437f'0000);
static_assert(std::bit_cast<uint32_t>(256.f) == 0x4380'0000);
static_assert(std::bit_cast<uint32_t>(257.f) == 0x4380'8000); // no valid bfloat16
static_assert(std::bit_cast<uint32_t>(258.f) == 0x4381'0000);
static_assert(std::bit_cast<uint32_t>(259.f) == 0x4381'8000); // no valid bfloat16
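
The same pattern generalizes beyond 256. bfloat16 keeps 8 significant bits (7 stored mantissa bits plus the implicit leading 1), so the gap between adjacent representable values doubles at every power of two. A rough Python sketch for positive normal values, with a hypothetical helper name:

```python
import math

def bf16_spacing(x: float) -> float:
    """Gap between adjacent bfloat16 values around a positive normal x."""
    # frexp returns x = m * 2**e with 0.5 <= m < 1, so x lies in
    # [2**(e-1), 2**e). With 8 significant bits the step there is 2**(e-8).
    _, e = math.frexp(x)
    return 2.0 ** (e - 8)

for x in (128.0, 255.0, 256.0, 512.0, 1000.0):
    print(x, bf16_spacing(x))
# 128.0 1.0
# 255.0 1.0
# 256.0 2.0
# 512.0 4.0
# 1000.0 4.0
```

So in an arange from 0 to 1024, only even integers survive between 256 and 512, and only multiples of 4 between 512 and 1024.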

patrickroberts commented 2 months ago

Just to add some more context here:

import torch
import pytest
import ttnn

@pytest.mark.parametrize("end", (256, 300, 390, 1024))
def test_arange(end, device):
    torch.manual_seed(0)
    torch_data = torch.arange(250, end, 1, dtype=torch.bfloat16)

    tt_output_tensor_on_device = ttnn.arange(250, end, 1, device)

    tt_output = tt_output_tensor_on_device.cpu().to(ttnn.Layout.ROW_MAJOR).to_torch()

    tt_output = tt_output.flatten()
    print("==============================")
    torch.set_printoptions(sci_mode=False, threshold=10000)

    print("********** torch  ****")
    print(torch_data)

    print("********** TT  ****")
    print(tt_output)

    print(torch.eq(torch_data, tt_output))

gives the output:

********** torch  ****
tensor([250., 251., 252., 253., 254., 255., 256., 256., 258., 260., 260., 260.,
        262., 264., 264., 264., 266., 268., 268., 268., 270., 272., 272., 272.,
        274., 276., 276., 276., 278., 280., 280., 280., 282., 284., 284., 284.,
        286., 288., 288., 288., 290., 292., 292., 292., 294., 296., 296., 296.,
        298., 300.], dtype=torch.bfloat16)
********** TT  ****
tensor([250., 251., 252., 253., 254., 255., 256., 256., 258., 258., 260., 260.,
        262., 262., 264., 264., 266., 266., 268., 268., 270., 270., 272., 272.,
        274., 274., 276., 276., 278., 278., 280., 280., 282., 282., 284., 284.,
        286., 286., 288., 288., 290., 290., 292., 292., 294., 294., 296., 296.,
        298., 298.], dtype=torch.bfloat16)

While TT does not use the same rounding heuristics as torch, it still exhibits the same loss of information due to the bfloat16 format. TT does straight truncation, but torch seems to round up in half of the cases above 256.

Whether we want to try to exactly match the rounding heuristics is a different question. Torch uses a convention called "round to nearest even".
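
For comparison, here is a minimal Python sketch of the two rounding modes applied to float32 bit patterns. It is not torch's or tt-metal's actual implementation. The round-to-nearest-even version uses a common bias trick, the function names are hypothetical, and NaN and infinity are deliberately not handled.

```python
import struct

def float_to_bits(x: float) -> int:
    """Return the IEEE 754 float32 bit pattern of x as an unsigned 32-bit int."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def bits_to_float(bits: int) -> float:
    """Interpret an unsigned 32-bit int as an IEEE 754 float32."""
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def bf16_truncate(x: float) -> float:
    # Straight truncation: zero out the low 16 bits.
    return bits_to_float(float_to_bits(x) & 0xFFFF0000)

def bf16_round_nearest_even(x: float) -> float:
    # Assumes x is finite. Adding 0x7FFF plus the lowest kept bit rounds
    # to nearest and sends exact ties to the even mantissa.
    bits = float_to_bits(x)
    bias = 0x7FFF + ((bits >> 16) & 1)
    return bits_to_float(((bits + bias) >> 16 << 16) & 0xFFFFFFFF)

trunc = [bf16_truncate(float(n)) for n in range(250, 300)]
rne = [bf16_round_nearest_even(float(n)) for n in range(250, 300)]
print(trunc[6:12])  # [256.0, 256.0, 258.0, 258.0, 260.0, 260.0]
print(rne[6:12])    # [256.0, 256.0, 258.0, 260.0, 260.0, 260.0]
```

Under these assumptions the truncating version reproduces the TT listing above and the round-to-nearest-even version reproduces the torch listing, including the final 298., 300. pair.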