tenstorrent / tt-metal

TT-NN operator library, and TT-Metalium low-level kernel programming model.
Apache License 2.0

JIT Build Error with Mixed Data Formats in Output CBs #6945

Open dongjin-na opened 6 months ago

dongjin-na commented 6 months ago

Describe the bug

I've found that all output CBs (circular buffers) used in a kernel must have the same data format. My dropout implementation needs two outputs: a bfloat16 output tensor and a uint32 random-state tensor. Using mixed data formats in the output CBs, however, fails at JIT build time.

To Reproduce

The relevant code examples are in the mixed_output_cbs branch of the tt-metal repository: test_mixed_data_format_output_cbs.cpp and test_same_data_format_output_cbs.cpp demonstrate the issue.

test_same_data_format_output_cbs.cpp works without issues; it creates two output CBs with the same data format (bfloat16).

```cpp
uint32_t output_cb_index = 16; // output operands start at index 16
uint32_t output_cb_addr = 300 * 1024; // unused in this excerpt
uint32_t num_output_tiles = 1;
tt_metal::CircularBufferConfig cb_output_config =
    tt_metal::CircularBufferConfig(num_output_tiles * single_tile_size, {{output_cb_index, tt::DataFormat::Float16_b}})
        .set_page_size(output_cb_index, single_tile_size);
auto cb_output = tt_metal::CreateCircularBuffer(program, core, cb_output_config);

////////////////////////////////////
// pass: second output CB uses the same format (bfloat16)
////////////////////////////////////
uint32_t output1_cb_index = 17;
tt_metal::CircularBufferConfig cb_output1_config =
    tt_metal::CircularBufferConfig(num_output_tiles * single_tile_size, {{output1_cb_index, tt::DataFormat::Float16_b}})
        .set_page_size(output1_cb_index, single_tile_size);
auto cb_output1 = tt_metal::CreateCircularBuffer(program, core, cb_output1_config);
```

test_mixed_data_format_output_cbs.cpp attempts to create two output CBs with different data formats (bfloat16 and uint32).

```cpp
uint32_t output_cb_index = 16; // output operands start at index 16
uint32_t output_cb_addr = 300 * 1024; // unused in this excerpt
uint32_t num_output_tiles = 1;
tt_metal::CircularBufferConfig cb_output_config =
    tt_metal::CircularBufferConfig(num_output_tiles * single_tile_size, {{output_cb_index, tt::DataFormat::Float16_b}})
        .set_page_size(output_cb_index, single_tile_size);
auto cb_output = tt_metal::CreateCircularBuffer(program, core, cb_output_config);

////////////////////////////////////
// issue point: second output CB uses a different format (UInt32)
////////////////////////////////////
uint32_t output1_cb_index = 17;
tt_metal::CircularBufferConfig cb_output1_config =
    tt_metal::CircularBufferConfig(num_output_tiles * 4096, {{output1_cb_index, tt::DataFormat::UInt32}})
        .set_page_size(output1_cb_index, 4096);
auto cb_output1 = tt_metal::CreateCircularBuffer(program, core, cb_output1_config);
```
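The UInt32 CB above is sized with a literal 4096 instead of `single_tile_size`. That value follows from the tile geometry. A minimal sketch, assuming 32x32 tiles and covering only the two formats in this issue; the `Format` enum and the helpers are hypothetical and are not tt-metal's `tt::DataFormat` API:

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the two formats discussed here (not tt::DataFormat).
enum class Format { Float16_b, UInt32 };

constexpr uint32_t kTileHeight = 32;  // assumption: 32x32 tiles
constexpr uint32_t kTileWidth = 32;

// Bytes per element: bfloat16 is 2 bytes, uint32 is 4 bytes.
constexpr uint32_t datum_bytes(Format f) {
    return f == Format::Float16_b ? 2u : 4u;
}

// Bytes in one tile, i.e. one CB page in the snippets above.
constexpr uint32_t tile_bytes(Format f) {
    return kTileHeight * kTileWidth * datum_bytes(f);
}

static_assert(tile_bytes(Format::Float16_b) == 2048, "bfloat16 tile size");
static_assert(tile_bytes(Format::UInt32) == 4096, "uint32 tile size");
```

So a UInt32 CB needs twice the page size of a bfloat16 CB holding the same number of tiles.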

This results in a JIT build error with the following message:


```
$ make ENABLE_PROFILER=1 build
$ make ENABLE_PROFILER=1 tests
$ TT_METAL_SLOW_DISPATCH_MODE=1 ./build/test/tt_metal/test_mixed_data_format_output_cbs
                  Metal | INFO     | AI CLK for device 0 is:   1000 MHz
                  Verif | INFO     | Created a random vector of size 1048576
                 Always | FATAL    | Not all buffer data-formats within this operand are the same
                 Always | FATAL    | Not all buffer data-formats within this operand are the same
                 Always | FATAL    | Not all buffer data-formats within this operand are the same
terminate called after throwing an instance of 'std::runtime_error'
  what():  TT_FATAL @ tt_metal/jit_build/data_format.cpp:105: data_format[i] == last_valid_format
info:
Not all buffer data-formats within this operand are the same
backtrace:
 --- tt::get_pack_data_format(tt::DataFormat*, tt::DataFormat*)
 --- /home/ubuntu/tt-metal/build/lib/libtt_metal.so(+0x363861) [0x7f7c7dc9a861]
 --- /lib/x86_64-linux-gnu/libstdc++.so.6(+0xd6df4) [0x7f7c7d7fbdf4]
 --- /lib/x86_64-linux-gnu/libpthread.so.0(+0x8609) [0x7f7c7d90f609]
 --- /lib/x86_64-linux-gnu/libc.so.6(clone+0x43) [0x7f7c7d5dc353]

Aborted (core dumped)
```

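The failing assertion (`data_format[i] == last_valid_format` in `tt::get_pack_data_format`) suggests the host collapses the formats of all output CBs into one pack format and rejects any mismatch. The following is a hypothetical, self-contained model of that behavior, not the tt-metal source; the enum, the 32-entry CB array, and the output range 16 to 23 are assumptions made for illustration:

```cpp
#include <array>
#include <cassert>
#include <stdexcept>

// Hypothetical format tags; Invalid marks a CB index the program does not use.
enum class Format { Invalid, Float16_b, UInt32 };

constexpr int kNumCBs = 32;           // assumption: CB indices 0..31
constexpr int kFirstOutputCB = 16;    // "output operands start at index 16"
constexpr int kFirstIntermedCB = 24;  // "24 (the start of intermediate buffer)"

using CBFormats = std::array<Format, kNumCBs>;

// Model: all used output CBs are treated as one operand with a single pack
// format, so any two differing formats in that range are rejected.
Format output_pack_format(const CBFormats& formats) {
    Format last_valid = Format::Invalid;
    for (int i = kFirstOutputCB; i < kFirstIntermedCB; ++i) {
        if (formats[i] == Format::Invalid) continue;  // skip unused CBs
        if (last_valid != Format::Invalid && formats[i] != last_valid) {
            throw std::runtime_error(
                "Not all buffer data-formats within this operand are the same");
        }
        last_valid = formats[i];
    }
    return last_valid;
}
```

Under this model, bfloat16 at index 16 plus UInt32 at index 17 throws, while a UInt32 CB placed outside the 16 to 23 range is never compared against the bfloat16 output.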
razorback3 commented 6 months ago

@jliangTT Could you set the priority of this issue to P0 or P1, depending on the Metal team's capacity? I think this is the last blocker for completing the dropout implementation.

jliangTT commented 6 months ago

@davorchap mentioned there is precedent for this (using different data types across CBs). We will provide an example so the team can compare.

davorchap commented 6 months ago

@ttmtrajkovic can we use `PACK(( pack_reconfig_data_format(mm_out_cb_id) ));` to switch from bfloat16 to uint32, which is the use case here?

The host asserts that this case is not supported.

ttmtrajkovic commented 6 months ago

@davorchap, @razorback3,

There's no problem with the reconfig function. The issue is that our infra currently doesn't support more than one output (CB IDs >= 16) with different formats; we haven't needed this so far. Changing the CB ID of the uint32 data to 24 (the start of the intermediate buffers) fixes the problem, and the test passes.

@razorback3, do you need a second output at ID 17 for this op, or just an intermediate buffer to store some values while the op runs? Could you use an intermediate CB (ID 24) for the uint32 data? I will follow up on multi-output support with heterogeneous formats.

Milos
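The workaround relies on CB indices being partitioned into roles. A minimal sketch, using the boundaries stated in this thread (outputs start at 16, intermediates at 24) and assuming 32 CB indices in total; the `CBRole` enum and `cb_role` helper are hypothetical:

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Hypothetical role of a CB index, based on the boundaries mentioned above.
enum class CBRole { Other, Output, Intermediate };

constexpr uint32_t kFirstOutputCB = 16;    // output operands start at 16
constexpr uint32_t kFirstIntermedCB = 24;  // intermediate buffers start at 24
constexpr uint32_t kNumCBs = 32;           // assumption: indices 0..31

CBRole cb_role(uint32_t cb_index) {
    if (cb_index >= kNumCBs) throw std::out_of_range("CB index out of range");
    if (cb_index >= kFirstIntermedCB) return CBRole::Intermediate;
    if (cb_index >= kFirstOutputCB) return CBRole::Output;
    return CBRole::Other;  // inputs and other non-output CBs
}
```

In the failing repro, both CB 16 and CB 17 classify as outputs; moving the uint32 CB to 24 makes it an intermediate, so only one format remains among the outputs.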

dongjin-na commented 6 months ago

@ttmtrajkovic, I will use intermediate CB 24 for the uint32 second output.

@razorback3, I think I can continue implementing the dropout op following his guidance.

razorback3 commented 6 months ago

@dongjin-na I understand.

@ttmtrajkovic It seems this issue is no longer a blocker for Moreh's development plan, but why don't we keep it open and close it once multi-output support with heterogeneous formats is completed? I understand this would have lower priority than other issues.