tenstorrent / tt-metal

HiFi2 and HiFi4 give different results for bfloat16 @ bfloat16 matmul #12453

Open · yieldthought opened this issue 3 weeks ago

yieldthought commented 3 weeks ago

Describe the bug
HiFi2 and HiFi4 give different results for bfloat16 @ bfloat16 matmul on Wormhole.

To Reproduce
On the branch yieldthought/12453, run:
pytest tests/ttnn/unit_tests/operations/test_matmul.py::test_matmul_fidelity

This does a 32x1024x1024 matmul twice, once with HiFi2 and once with HiFi4, and prints the max abs error. Suspiciously, despite the inputs being random, the max absolute error between the HiFi2 and HiFi4 outputs is exactly 1.000000.
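
For anyone without the branch checked out, the comparison looks roughly like this (a minimal sketch, assuming the ttnn Python API surface of ttnn.open_device, ttnn.from_torch, ttnn.WormholeComputeKernelConfig, ttnn.MathFidelity, and the compute_kernel_config argument to ttnn.matmul; the actual test on the branch may differ in details):

```python
import torch
import ttnn

# Rough sketch of the comparison, not the exact test from the branch.
M, K, N = 32, 1024, 1024
torch.manual_seed(0)
a = torch.randn(M, K, dtype=torch.bfloat16)
b = torch.randn(K, N, dtype=torch.bfloat16)

device = ttnn.open_device(device_id=0)
a_tt = ttnn.from_torch(a, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
b_tt = ttnn.from_torch(b, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)

outputs = {}
for fidelity in (ttnn.MathFidelity.HiFi2, ttnn.MathFidelity.HiFi4):
    cfg = ttnn.WormholeComputeKernelConfig(
        math_fidelity=fidelity,
        math_approx_mode=False,
        fp32_dest_acc_en=True,
        packer_l1_acc=False,
    )
    out = ttnn.matmul(a_tt, b_tt, compute_kernel_config=cfg)
    outputs[fidelity] = ttnn.to_torch(out)

max_abs_err = (outputs[ttnn.MathFidelity.HiFi2] - outputs[ttnn.MathFidelity.HiFi4]).abs().max()
print(f"max abs error HiFi2 vs HiFi4: {max_abs_err.item():.6f}")
ttnn.close_device(device)
```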

Expected behavior
I expect HiFi2 and HiFi4 to give bit-for-bit identical results: bfloat16 has a 7-bit mantissa and our multipliers are 7-bit and 5-bit, so 2 passes over the 5-bit input should be sufficient.
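
Spelling out the arithmetic behind that expectation (my reading of it, not something stated in the test): with the hidden bit, a bfloat16 significand is 1 + 7 = 8 bits, and two passes of a 5-bit multiplier input cover 2 x 5 = 10 >= 8 bits, so two fidelity phases should already see every mantissa bit of the operand fed through the 5-bit port.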

yieldthought commented 3 weeks ago
| MxKxN        | Max abs error |
|--------------|---------------|
| 32x512x512   | 0.000000      |
| 32x1024x1024 | 1.000000      |
| 32x2048x2048 | 1.000000      |
| 32x4096x4096 | 2.000000      |

🤷

yieldthought commented 3 weeks ago

If I set fp32_dest_acc_en=False for both kernels, I get an MAE of 1.500000 on 32x1024x1024 🤔
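
In terms of the sketch above, that just means flipping the accumulation flag in the compute kernel config (again assuming ttnn.WormholeComputeKernelConfig and the same loop variable):

```python
cfg = ttnn.WormholeComputeKernelConfig(
    math_fidelity=fidelity,
    math_approx_mode=False,
    fp32_dest_acc_en=False,  # accumulate in the narrower dest format instead of fp32
    packer_l1_acc=False,
)
```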

ncvetkovicTT commented 2 weeks ago

@yieldthought Hey, thanks for raising the issue and taking the time to explain the behavior. In short, what happens is the following:

The HW multipliers are 5x7 bits wide for WH/BH, but in the binary representation of a normalized floating-point number there is an implied '1' before the binary point, which means the multiplier is effectively 4x6 if we consider explicit mantissa bits only. This means that HiFi2 and HiFi4 give the same result as long as the second operand has a '0' in its least significant mantissa bit, i.e. as long as it can be represented as:

| Whole part | Decimal_0 | Decimal_1 | Decimal_2 | Decimal_3 | Decimal_4 | Decimal_5 | Decimal_6 |
|------------|-----------|-----------|-----------|-----------|-----------|-----------|-----------|
| b .        | b         | b         | b         | b         | b         | b         | 0         |
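
To make this concrete, here is a rough Python sketch of the idea. The chunk sizes and phase ordering are illustrative assumptions (srcA as a 1+4-bit high chunk plus a 3-bit low chunk, srcB as a 1+6-bit high chunk plus its single low bit), not the precise HW sequence:

```python
# Rough illustration of the fidelity-phase idea, not the exact HW sequence.
# Assumed split: srcA contributes a 5-bit chunk (hidden 1 + 4 explicit bits)
# plus a 3-bit low chunk, srcB a 7-bit chunk (hidden 1 + 6 explicit bits)
# plus its single low bit. HiFi2 is modelled as the two phases that use
# only srcB's high chunk, HiFi4 as all four phases.

def bf16_significand(mantissa_bits: int) -> int:
    """8-bit significand of a normalized bfloat16: hidden 1 + 7 explicit bits."""
    return 0x80 | (mantissa_bits & 0x7F)

def phase_products(a_mant: int, b_mant: int):
    a = bf16_significand(a_mant)
    b = bf16_significand(b_mant)
    a_hi, a_lo = a >> 3, a & 0x7   # 5-bit / 3-bit chunks of srcA
    b_hi, b_lo = b >> 1, b & 0x1   # 7-bit / 1-bit chunks of srcB
    # Shift each partial product back to the scale of the full a * b.
    return (a_hi * b_hi) << 4, (a_lo * b_hi) << 1, (a_hi * b_lo) << 3, a_lo * b_lo

for b_mant in (0x3C, 0x3D):        # srcB LSB = 0 vs LSB = 1
    p0, p1, p2, p3 = phase_products(0x5B, b_mant)
    hifi2 = p0 + p1                # srcB's low bit never participates
    hifi4 = p0 + p1 + p2 + p3      # equals the exact significand product
    print(f"b mantissa {b_mant:#04x}: HiFi2={hifi2}, HiFi4={hifi4}, diff={hifi4 - hifi2}")
```

With this split, the HiFi2 sum matches the full product exactly whenever srcB's last mantissa bit is 0, and falls short by srcA's significand times that bit otherwise, which is why the error only shows up for some values.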

Beware: for GS the multiplier is 5x5 😅

You can find out more here, where I also mention the branch with the examples: ncvetkovic/12453