Open dongjin-na opened 6 months ago
@jliangTT Would you set the priority of this issue to P0 or P1 depending on the Metal team's capacity? I think this would be the last issue for completing the dropout implementation.
@davorchap mentioned there is precedent for doing this (not using the same data type for all CBs). Will provide an example so the team can compare.
@ttmtrajkovic can we use PACK(( pack_reconfig_data_format(mm_out_cb_id) )); to reconfigure from bfloat16 to uint32, which is the use case here?
The host asserts, saying we can't support this case.
@davorchap, @razorback3,
There's no problem with the reconfig function. The issue was that our infra currently doesn't support more than one output CB (output_ids >= 16) when the outputs have different formats; we haven't had the need for it so far. Changing the CB id of the uint32 data to 24 (the start of the intermediate buffers) fixes the problem and the test passes.
@razorback3, do you need a second output at id = 17 for this op, or just an intermediate buffer to store some values while the op is working? Could you use the uint32 CB as an intermediate (id = 24)? I will follow up on multi-output support with heterogeneous formats.
Milos
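To make the constraint and the workaround concrete, here is a minimal, self-contained sketch. It is a toy model, not the tt-metal API: the DataFormat enum, the constants, and check_output_formats are hypothetical names introduced only for illustration. The index ranges are assumptions taken from the comment above (output CBs start at id 16, intermediate CBs start at id 24).

```cpp
#include <cstdint>
#include <map>
#include <optional>
#include <string>

// Hypothetical model for illustration only; these are not tt-metal types.
enum class DataFormat { Bfloat16, UInt32 };

// Assumed index ranges, taken from the discussion above:
// output CBs start at id 16, intermediate CBs start at id 24.
constexpr std::uint32_t kFirstOutputCb = 16;
constexpr std::uint32_t kFirstIntermediateCb = 24;

// A CB counts as an output if its id falls in [16, 24).
inline bool is_output_cb(std::uint32_t cb_id) {
    return cb_id >= kFirstOutputCb && cb_id < kFirstIntermediateCb;
}

// Returns an error message if two output CBs use different data formats,
// and std::nullopt otherwise. CBs outside the output range are ignored.
std::optional<std::string> check_output_formats(
    const std::map<std::uint32_t, DataFormat>& cbs) {
    std::optional<DataFormat> first_output_format;
    for (const auto& [cb_id, format] : cbs) {
        if (!is_output_cb(cb_id)) {
            continue;  // inputs and intermediates are not constrained here
        }
        if (!first_output_format) {
            first_output_format = format;
        } else if (*first_output_format != format) {
            return "output CB " + std::to_string(cb_id) +
                   " has a different data format from the other output CBs";
        }
    }
    return std::nullopt;
}
```

In this model, a configuration with bfloat16 at id 16 and uint32 at id 17 is rejected, while moving the uint32 CB to id 24 passes the check, which mirrors the workaround described above. The sketch needs C++17 for std::optional and structured bindings.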
@ttmtrajkovic, I will use an intermediate CB (id = 24) for uint32 as the second output.
@razorback3, I think I can continue implementing dropout op following his guide.
@dongjin-na I understand.
@ttmtrajkovic It seems this issue is no longer a blocker for Moreh's development plan, but why don't we close it only after multi-output support with heterogeneous formats is completed? That said, I understand this would have lower priority than other issues.
Describe the bug
I've discovered that all output CBs (Circular Buffers) used in a kernel must have the same data format. My requirement for implementing a dropout operation includes bfloat16 for the output tensor and uint32 for the random state output tensor. However, using mixed data formats in output CBs seems problematic.
To Reproduce
The relevant code examples can be found in the mixed_output_cbs branch of the tt-metal repository. There, test_mixed_data_format_output_cbs.cpp and test_same_data_format_output_cbs.cpp demonstrate the issue.
test_same_data_format_output_cbs.cpp works without issues, creating two output CBs with the same data format (bfloat16).
test_mixed_data_format_output_cbs.cpp attempts to create two output CBs with different data formats (bfloat16 and uint32). This results in a JIT build error with the following message: