tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.

Inquiry on add_tiles and reduce_tile Packing Results in FP32 Mode Being Less Than 32-bit #10267

Closed: dongjin-na closed this issue 2 months ago

dongjin-na commented 3 months ago

Describe the situation

As far as I know, in fp32 dest acc mode (hereafter referred to as FP32 mode), the results of compute API operations are stored in the DST register with 32-bit precision. (This can be confirmed in issue #9410.)
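(As a quick host-side illustration of what the extra precision buys, using a value that shows up later in the DPRINT log: 374.75 is not exactly representable in bfloat16, so a 16-bit result would lose the fractional part, while a 32-bit DST keeps it.)

```python
import torch

x = torch.tensor(374.75, dtype=torch.float32)   # value also seen later in the DPRINT output
print(x.to(torch.bfloat16).item())              # 374.0  -- bfloat16 keeps only an 8-bit significand
print(x.item())                                 # 374.75 -- fp32 represents the value exactly
```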

The moreh_sum op uses different compute APIs depending on the dimension being reduced: reduce_tile for the h and w dims, and add_tiles for the batch dims.

Referring to issue #9410, I implemented FP32 mode in the moreh_sum op by adding the preserve_fp32_precision flag. (The PR with this implementation is currently under review.) To verify correct operation, I checked the intermediate results of each compute API as follows. I will explain with diagrams for better understanding.

Scenario using the reduce_tile API:

[Figure 1: h and w dim reduction workflow using the reduce_tile API]

  1. Read N tiles of input tensor from DRAM into the bfloat16 input CB one tile at a time.
  2. Call the reduce_tile compute API using the input CB and scaler CB.
  3. Store the result in DST register 0. Steps 1, 2, and 3 are repeated N times, with the result accumulating in DST register 0 each time.
  4. Pack DST register 0 into the FP32 intermediate CB and check the bit precision of the elements (a host-side sketch of this check follows the list).
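A host-side sketch of the bit-precision check in step 4 (the helper names are mine, not from the test; classifying into 16/19/32 bits is just one way to read the patterns, consistent with the figures in the table below):

```python
import struct

def fp32_bit_pattern(x: float) -> str:
    """IEEE-754 single-precision bit pattern of x, split into two
    16-bit halves like the DPRINT output in this issue."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    s = f"{bits:032b}"
    return f"{s[:16]} {s[16:]}"

def classify_precision(x: float) -> int:
    """16 if the value survives a bfloat16 round-trip (low 16 bits of
    the fp32 pattern are zero), 19 if only the low 13 bits are zero
    (tf32-like), otherwise 32."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    if (bits & 0x0000FFFF) == 0:
        return 16
    if (bits & 0x00001FFF) == 0:
        return 19
    return 32

print(fp32_bit_pattern(374.75), classify_precision(374.75))  # 0100001110111011 0110000000000000 19
print(fp32_bit_pattern(368.0), classify_precision(368.0))    # 0100001110111000 0000000000000000 16
```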

Scenario using the add_tiles API:

[Figure 2: batch dim reduction workflow using the add_tiles API]

  1. Read N tiles of input tensor from DRAM into the bfloat16 input CB one tile at a time.
  2. Call the add_tiles compute API using the input CB and an input2 CB filled with zeros.
  3. Store the result in DST register 0. Steps 1, 2, and 3 are repeated N times, with the result accumulating in DST register 0 each time (see the sketch after this list).
  4. Pack DST register 0 into the FP32 intermediate CB and check the bit precision of the elements.
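To make the effect of fp32 DST accumulation in step 3 concrete, here is a host-side torch sketch (my own illustration, not the device code): the same N bfloat16 tiles are summed once with the running total kept in fp32, and once with the running total rounded back to bfloat16 after every step.

```python
import torch

torch.manual_seed(0)
N, tile_shape = 16, (32, 32)
tiles = [torch.rand(tile_shape, dtype=torch.float32).bfloat16() for _ in range(N)]

# Running total kept in fp32, i.e. what FP32 dest-acc mode is expected to give.
acc_fp32 = torch.zeros(tile_shape, dtype=torch.float32)
for t in tiles:
    acc_fp32 += t.float()

# Running total rounded back to bfloat16 after every step,
# roughly what a 16-bit destination would give.
acc_bf16 = torch.zeros(tile_shape, dtype=torch.bfloat16)
for t in tiles:
    acc_bf16 = (acc_bf16.float() + t.float()).bfloat16()

ref = torch.stack([t.float() for t in tiles]).sum(dim=0)
print("MAE, fp32 accumulation:", (acc_fp32 - ref).abs().mean().item())
print("MAE, bf16 accumulation:", (acc_bf16.float() - ref).abs().mean().item())
```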
I summarized the results in the table below:

| Sum direction | Compute API | Bit precision | Intermediate CB value in chlkc_pack_data_format.h | Intermediate CB value in chlkc_unpack_data_format.h |
|---|---|---|---|---|
| h dim | reduce_tile(SUM, REDUCE_COL) | 19 | 0 (FP32) | 0 |
| w dim | reduce_tile(SUM, REDUCE_ROW) | 16 | 0 | 0 |
| batch dim | add_tiles | 19 | 0 | 0 |

I would like to inquire if it is correct that the bit precision available varies with different compute APIs in FP32 mode. I would appreciate it if you could let me know if there are any settings I might have missed.

Also, you can run the test program in the bit_precision branch.

To Reproduce
Steps to reproduce the behavior:

  1. git checkout origin/bit_precision
  2. CONFIG=Debug ./build_metal.sh
  3. source ./python_env/bin/activate
  4. export TT_METAL_DPRINT_CORES=0,0
  5. pytest tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_sum.py::test_each_dim_in_fp32_mode | grep print
```
$ export TT_METAL_DPRINT_CORES=0,0
$ pytest tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_sum.py::test_each_dim_in_fp32_mode | grep print

  Detecting chips (found 2)

(batch-dim) print the bit precision of the first element 13.7109 of the FP32 CB: 0100000101011011 0110000000000000
2024-07-17 11:44:50.731 | DEBUG    | tests.tt_eager.python_api_testing.unit_testing.misc.test_moreh_sum:test_each_dim_in_fp32_mode:369 - MAE=0.025203704833984375

(h-dim) print the bit precision of the first element 374.75 of the FP32 CB: 0100001110111011 0110000000000000
2024-07-17 11:44:51.442 | DEBUG    | tests.tt_eager.python_api_testing.unit_testing.misc.test_moreh_sum:test_each_dim_in_fp32_mode:369 - MAE=0.8543701171875

(w-dim) print the bit precision of the first element 368 of the FP32 CB: 0100001110111000 0000000000000000
2024-07-17 11:44:52.179 | DEBUG    | tests.tt_eager.python_api_testing.unit_testing.misc.test_moreh_sum:test_each_dim_in_fp32_mode:369 - MAE=13.4622802734375
```
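For reference, the printed bit patterns can be decoded back into float values on the host (plain IEEE-754, nothing TT-specific):

```python
import struct

def decode(pattern: str) -> float:
    """Turn a '16 bits<space>16 bits' DPRINT pattern back into the fp32 value."""
    bits = int(pattern.replace(" ", ""), 2)
    return struct.unpack(">f", struct.pack(">I", bits))[0]

print(decode("0100000101011011 0110000000000000"))  # 13.7109375  (batch-dim)
print(decode("0100001110111011 0110000000000000"))  # 374.75      (h-dim)
print(decode("0100001110111000 0000000000000000"))  # 368.0       (w-dim)
```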


razorback3 commented 3 months ago

Faced this issue while implementing Sum

amahmudTT commented 2 months ago

It seems to be the same issue as https://github.com/tenstorrent/tt-metal/issues/10337. The add_tiles and reduce_tile operations need to use srcA and srcB, which cannot hold fp32, whereas the preserve_fp32_precision flag directs them to be fp32.

dongjin-na commented 2 months ago

Dear @amahmudTT,

Your previous response suggests that this is the same issue as #10337, but there’s a key difference.

The main difference between this issue and #10337 lies in the input data type. In #10337, an fp32 CB is passed into the reduce_tile API, while in this issue, bfloat16 CBs are used for srcA and srcB.

In #9410, the fp32_dest_acc_en and preserve_fp32_precision flags ensured that the intermediate output maintained 32-bit precision during the pack from DST to fp32 CB, the unpack from fp32 CB to DST, and the DST accumulation.

Similarly, in this case, we want to maintain 32-bit precision results while using 16-bit input data.

I would like to inquire if it is correct that the bit precision available varies with different compute APIs in FP32 mode. I would appreciate it if you could let me know if there are any settings I might have missed.

rdjogoTT commented 2 months ago

@dongjin-na, @amahmudTT fyi I did some testing and here are my conclusions.

I rebased the bit_precision branch onto main and modified the stimulus in test_moreh_sum.py to be
`torch_input = torch.normal(0, 99999, input_shape, dtype=cpu_dtype, requires_grad=True)`
rather than
`torch_input = torch.rand(input_shape, dtype=cpu_dtype, requires_grad=True)`.

I printed the first 10 fp32 values from the intermediate CB and got the following results:

[screenshot: first 10 fp32 values printed from the intermediate CB]

You can see that for the batch and h-dim cases, the new stimulus better demonstrates that all 32 bits of precision can be used, depending on the input values. However, for w-dim, the output is in fact limited to 16 bits of precision. This is expected due to the implementation of reduce_tile with ReduceDim::REDUCE_ROW, which requires reloading Dest values back into SrcB (which cannot support full fp32 precision and instead uses fp16b).

This limitation in accuracy for w-dim was addressed in https://github.com/tenstorrent/tt-metal/commit/8ce20efa43c21295797897c61487e77074d61029 by using matmul instead. However, this fix is currently less efficient than it needs to be, and it also only works for reduce sum. I recommend you give this approach a try if you need fp32 accuracy.
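As a purely numerical illustration of the reload effect (host-side torch, under the assumption that the Dest-to-SrcB reload behaves like a bfloat16 round-trip; this is not the actual LLK code), a row sum whose intermediate result passes through bfloat16 between two passes loses accuracy compared with one kept in fp32 throughout:

```python
import torch

torch.manual_seed(0)
# Wide-dynamic-range stimulus, mirroring the torch.normal change above.
x = torch.normal(0, 99999, (32, 32)).bfloat16().float()

ref = x.sum(dim=1)                           # row sum kept in fp32 throughout

first_pass = x[:, :16].sum(dim=1)            # first pass, result held in fp32
reloaded = first_pass.bfloat16().float()     # stand-in for the fp16b SrcB reload
two_pass = reloaded + x[:, 16:].sum(dim=1)   # second pass

print("max abs error introduced by the reload:",
      (two_pass - ref).abs().max().item())
```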

Summary:

  * For batch-dim and h-dim, all 32 bits of precision can be used, depending on the input values.
  * For w-dim, reduce_tile with REDUCE_ROW is limited to 16 bits of precision because of the Dest-to-SrcB reload.
  * The matmul-based reduce sum in the commit above avoids the reload and can be used where fp32 accuracy is needed.

Let me know if there are any other questions/concerns or if we can close the issue.

dongjin-na commented 2 months ago

Dear @rdjogoTT, I apologize for the delay in my response. I will apply the commit and testing method you mentioned, and then I’ll get back to you with my feedback.

dongjin-na commented 2 months ago
rdjogoTT commented 2 months ago

It comes down to how the LLK is implemented right now. For h-dim, the LLK uses the FPU only once, and setting Dest to fp32 mode allows the FPU to produce outputs with up to 32 bits of precision even if the input is fp16b. For w-dim, the LLK uses the FPU twice, reloading the datums from Dest back into SrcB, which converts them back to fp16b because the SrcA/B registers are not large enough to support fp32. This is why the alternative method using matmul was developed: it avoids the Dest reload that limits accuracy.

Blackhole uses an identical LLK to Wormhole for this op, so it faces the same limitations.

dongjin-na commented 2 months ago

Dear @rdjogoTT,

Thank you for your detailed explanation. Based on your suggestions, I conducted tests using torch.normal, and both the add_tiles operation and the reduce_tile operation for h-dim appear to utilize precision close to 32 bits.

To summarize, could you please confirm the following?

  1. The reduce_tile operation for w-dim is limited to 16 bits of precision. Therefore, for sum reductions, it is recommended to use the matmul_tiles approach to achieve FP32 accuracy.
  2. The reduce_tile operation for h-dim and the add_tiles operation both support 32 bits of precision.

Additionally, I have one more question:

When unpacking the result of a reduce_tile API on h-dim, it is necessary to set the preserve_fp32_precision flag to maintain precision. (This is similar to an issue we discussed previously.) For example:

  1. Reduce for h-dim to FP32 CB: 0100010010000001 0100001110100000
  2. Unpack with preserve_fp32_precision=false
  3. Pack to FP32 CB: 0100010010000001 0100000000000000 (truncated; see the check below)
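The truncation in step 3 can be checked on the host: zeroing the low 13 bits of the original pattern reproduces the packed pattern exactly, i.e. only the top 19 bits (1 sign + 8 exponent + 10 mantissa) survive the round trip.

```python
import struct

def to_bits(pattern: str) -> int:
    return int(pattern.replace(" ", ""), 2)

before = to_bits("0100010010000001 0100001110100000")   # value before the unpack/pack round trip
after  = to_bits("0100010010000001 0100000000000000")   # value after it

assert (before & 0xFFFFE000) == after                    # low 13 bits zeroed
print(struct.unpack(">f", struct.pack(">I", before))[0]) # 1034.11328125
print(struct.unpack(">f", struct.pack(">I", after))[0])  # 1034.0
```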

Issue #11756 seems to impose a constraint on the preserve_fp32_precision flag for reduce_tile. Do you think this could impact FP32 precision?

Thank you in advance for your insights, and I look forward to your response.

rdjogoTT commented 2 months ago

@dongjin-na Regarding your first two points, yes that is correct.

Regarding the additional question, I just want to confirm with you that you want to enable preserve_fp32_precision for the intermediate buffer, and not the input buffers.

dongjin-na commented 2 months ago

@rdjogoTT, thanks for the check.

I think it's enough to enable preserve_fp32_precision for the intermediate buffer only. However, I will check whether there is any scenario where the intermediate buffer needs to maintain FP32 precision while an FP32 CB is used as input (expecting it to be unpacked to BF16), and get back to you.

rdjogoTT commented 2 months ago

Can this be closed? The only remaining question is when we will permit preserve_fp32_precision, and that will be taken care of in https://github.com/tenstorrent/tt-metal/issues/11756.

dongjin-na commented 2 months ago

I also think so. Let's continue in #11756.