Faced this issue while implementing Sum
It seems to be the same issue as https://github.com/tenstorrent/tt-metal/issues/10337. The add_tile & reduce_tile operations need to use srcA and srcB, which cannot be fp32, even though the preserve_fp32_precision flag directs them to be.
Dear @amahmudTT,
Your previous response suggests that this is the same issue as #10337, but there’s a key difference.
The main difference between this issue and #10337 lies in the input data type. In #10337, an fp32 CB is passed into the reduce_tile API, while in this issue, bfloat16 CBs are used for srcA and srcB.
In #9410, the fp32_dest_acc_en and preserve_fp32_precision flags ensured that the intermediate output maintained 32-bit precision during the pack from DST to fp32 CB, the unpack from fp32 CB to DST, and the DST accumulation.
Similarly, in this case, we want to maintain 32-bit precision results while using 16-bit input data.
I would like to inquire if it is correct that the bit precision available varies with different compute APIs in FP32 mode. I would appreciate it if you could let me know if there are any settings I might have missed.
@dongjin-na, @amahmudTT fyi I did some testing and here are my conclusions.
I rebased the bit_precision branch onto main and modified the stimulus in test_moreh_sum.py to be
`torch_input = torch.normal(0, 99999, input_shape, dtype=cpu_dtype, requires_grad=True)`
rather than
`torch_input = torch.rand(input_shape, dtype=cpu_dtype, requires_grad=True)`
I printed the first 10 fp32 values from the intermediate CB and got the following results:
You can see that for the batch and h-dim cases, the new stimulus better demonstrates that all 32 bits of precision can be used, depending on the input values. For w-dim, however, the output is in fact limited to 16 bits of precision. This is expected due to the implementation of reduce_tile with ReduceDim::REDUCE_ROW, which requires reloading Dest values back into SrcB (which cannot hold full fp32 precision and instead uses fp16b).
This limitation in accuracy for w-dim was addressed in https://github.com/tenstorrent/tt-metal/commit/8ce20efa43c21295797897c61487e77074d61029 by using matmul instead. However, this fix is currently less efficient than it could be, and it also only works for reduce sum. I recommend you give this approach a try if you need fp32 accuracy.
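For intuition about why matmul side-steps the problem, here is a minimal host-side PyTorch sketch (illustrative only, not the device kernel, and the shapes are made up): a sum along the w dimension is mathematically a matmul with an all-ones column vector, so the accumulation can stay on the fp32 Dest path instead of going through the SrcB reload.

```python
import torch

x = torch.rand(1, 1, 32, 64, dtype=torch.float32)

# reduce_tile<SUM, REDUCE_ROW> computes a sum along the w dimension ...
ref = x.sum(dim=-1, keepdim=True)

# ... which is equivalent to a matmul with an all-ones column vector, the trick
# the commit above uses so the accumulation stays in the fp32 Dest register
# instead of being reloaded into the fp16b-limited SrcB register.
ones = torch.ones(x.shape[-1], 1, dtype=torch.float32)
via_matmul = x @ ones

assert torch.allclose(ref, via_matmul, atol=1e-4)
```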
Summary:
Let me know if there are any other questions/concerns or if we can close the issue.
Dear @rdjogoTT, I apologize for the delay in my response. I will apply the commit and testing method you mentioned, and then I’ll get back to you with my feedback.
It comes down to how the LLK is implemented right now. For h-dim, the LLK uses the FPU only once, and setting Dest to fp32 mode allows the FPU to produce outputs with up to 32 bits of precision even if the input is fp16b. For w-dim, the LLK uses the FPU twice, reloading the datums from Dest back into SrcB, which converts them back to fp16b because the SrcA/B registers are not large enough to support fp32. This is why the alternative method using matmul was developed: it avoids the Dest reload that limits accuracy.
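To make the precision impact of that SrcB reload concrete, here is a small illustrative PyTorch sketch (not the actual LLK behavior, just the numeric effect of rounding an intermediate value to bfloat16 between two additions):

```python
import torch

# 257 = 2^8 + 1 needs 9 significand bits; bfloat16 keeps only 8, fp32 keeps 24.
x = torch.tensor([257.0], dtype=torch.float32)

stays_fp32   = x + 1.0                                       # 258.0
through_bf16 = x.to(torch.bfloat16).to(torch.float32) + 1.0  # 257.0: 257 was already rounded to 256

print(stays_fp32.item(), through_bf16.item())
```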
Blackhole uses an identical LLK to wormhole for this OP, so it faces the same limitations.
Dear @rdjogoTT,
Thank you for your detailed explanation.
Based on your suggestions, I conducted tests using torch.normal, and both the add_tiles and reduce_tile (h-dim) operations appear to utilize precision close to 32 bits.
To summarize, could you please confirm the following?

- The reduce_tile operation for w-dim is limited to 16 bits of precision. Therefore, for sum reductions, it is recommended to use the matmul_tiles approach to achieve FP32 accuracy.
- The reduce_tile operation for h-dim and the add_tiles operation both support 32 bits of precision.

Additionally, I have one more question:
When unpacking the result of a reduce_tile call on h-dim, it is necessary to set the preserve_fp32_precision flag to maintain precision. (This is similar to an issue we discussed previously.) For example:
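(The original snippet is not reproduced here; below is a minimal hedged sketch of what such a configuration could look like. The class and field names, e.g. ttnn.WormholeComputeKernelConfig and preserve_fp32_precision, are assumptions based on how the flags are referenced in this thread, not a verified API.)

```python
import ttnn

# Hypothetical compute-kernel configuration enabling both precision flags
# discussed in this thread (names assumed, not verified):
compute_kernel_config = ttnn.WormholeComputeKernelConfig(
    math_fidelity=ttnn.MathFidelity.HiFi4,
    math_approx_mode=False,
    fp32_dest_acc_en=True,          # keep Dest accumulation in fp32
    preserve_fp32_precision=True,   # unpack fp32 CB datums to Dest without truncating to bf16
    packer_l1_acc=False,
)
```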
Issue #11756 seems to impose a constraint on the preserve_fp32_precision flag for reduce_tile. Do you think this could impact FP32 precision?
Thank you in advance for your insights, and I look forward to your response.
@dongjin-na Regarding your first two points, yes that is correct.
Regarding the additional question, I just want to confirm with you that you want to enable preserve_fp32_precision for the intermediate buffer, and not the input buffers.
@rdjogoTT, thanks for the check.
I think it’s enough to enable preserve_fp32_precision for the intermediate buffer only. However, I will check and confirm if there is any scenario where the intermediate buffer needs to maintain FP32 precision while using FP32 CB as input (expecting it to be unpacked to BF16) and get back to you.
Can this be closed? Since the only remaining question is regarding when we will permit preserve_fp32_precision, and that will be taken care of in https://github.com/tenstorrent/tt-metal/issues/11756.
I also think so. Let's continue in #11756.
Describe the situation
As far as I know, in fp32 dest acc mode (hereafter referred to as FP32 mode), the results of compute API operations are stored in the DST register with 32-bit precision. (This can be confirmed in issue #9410.)
The moreh_sum op uses different compute APIs depending on the dimension being reduced:

- add_tiles for batch dims
- reduce_tile<SUM, REDUCE_ROW> for w dim
- reduce_tile<SUM, REDUCE_COL> for h dim

Referring to issue #9410, I implemented FP32 mode in the moreh_sum op by adding the preserve_fp32_precision flag. (The implemented PR is currently under review.) To verify its correct operation, I checked the intermediate results of each compute API as follows. I will explain with diagrams for better understanding.
Scenario using the reduce_tile API:
*Figure 1: h and w dim reduction using reduce_tile API workflow*
Perform the reduce_tile compute API using the input CB and the scaler CB.

Scenario using the add_tiles API:
*Figure 2: batch dims reduction using add_tiles API workflow*
Perform the add_tiles compute API using the input CB and an input2 CB filled with zeros.

The generated chlkc_pack_data_format.h and chlkc_unpack_data_format.h headers were referenced to check the pack/unpack CB data formats.
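As a host-side reference for what these scenarios compute, here is a hedged PyTorch sketch (illustrative only; tile layout and device details are omitted, and the shapes are made up):

```python
import torch

x = torch.rand(4, 1, 32, 64, dtype=torch.float32)  # (N, C, H, W), example shape

# w-dim and h-dim reductions, i.e. what the two reduce_tile paths compute per tile:
w_sum = x.sum(dim=-1, keepdim=True)   # reduce_tile<SUM, REDUCE_ROW>
h_sum = x.sum(dim=-2, keepdim=True)   # reduce_tile<SUM, REDUCE_COL>

# Batch-dim reduction: repeated add_tiles into an intermediate that starts out as
# zeros (the "input2 CB filled with zeros" on the first iteration).
acc = torch.zeros(1, 1, 32, 64, dtype=torch.float32)
for n in range(x.shape[0]):
    acc = acc + x[n]

assert torch.allclose(acc, x.sum(dim=0, keepdim=True))
```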
I would like to inquire if it is correct that the bit precision available varies with different compute APIs in FP32 mode. I would appreciate it if you could let me know if there are any settings I might have missed.
Also, you can run the test program in the bit_precision branch.

To Reproduce
Steps to reproduce the behavior: