tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.
https://docs.tenstorrent.com/ttnn/latest/index.html
Apache License 2.0

[Feature Request] Let `reduce_tile`, `add_tiles`, `matmul_tiles` and `mul_tiles` take registers as input. #7767

Open marty1885 opened 7 months ago

marty1885 commented 7 months ago

Is your feature request related to a problem? Please describe.

I'm working on a GEMV implementation for LLM inference. Since the SFPU does not support GEMV natively, I had to build my own from tile multiplication and reduction. Currently it looks like this:

#define REDUCE_OP PoolType::SUM
#define REDUCE_DIM ReduceDim::REDUCE_SCALAR

// setup code.....
// cb_ones is a tile that is full of 1s

constexpr uint32_t dst_reg = 0;
acquire_dst(tt::DstMode::Full);
cb_wait_front(cb_in0, 1);
cb_wait_front(cb_in1, 1);
mul_tiles_init();
mul_tiles(cb_in0, cb_in1, 0, 0, dst_reg);
cb_reserve_back(cb_intermid0, 1);
pack_tile(dst_reg, cb_intermid0); // Packs tile to cb_intermid0 ---+
cb_push_back(cb_intermid0, 1);                                  // |
release_dst(tt::DstMode::Full);                                 // |
cb_pop_front(cb_in0, 1);                                        // |
cb_pop_front(cb_in1, 1);                                        // |
                                                                // |
acquire_dst(tt::DstMode::Full);                                 // |
// Then we immediately consume cb_intermid0   vvvvvvv--------------+
reduce_init<true>(REDUCE_OP, REDUCE_DIM, cb_intermid0, cb_ones);
reduce_tile(REDUCE_OP, REDUCE_DIM, cb_intermid0, cb_ones, 0, 0, cb_out0);
pack_tile(0, cb_out0);
release_dst(tt::DstMode::Full);
cb_push_back(cb_out0, 1);
cb_pop_front(cb_intermid0, 1);

I suppose this is inefficient, since the intermediate tile has to travel from the SFPU out to L1 memory and back. The same applies to adding tiles, so this feature would also help when chaining a bias add after a matrix multiplication.
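For readers unfamiliar with the kernel above, here is a minimal host-side sketch of what the two steps compute. It is plain C++ that runs on a CPU, not tt-metal kernel code. The names `Tile`, `mul_tiles_ref` and `reduce_scalar_sum_ref` are hypothetical, and it assumes a 32x32 tile and that reducing against a tile of ones amounts to a plain sum.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Host-side reference model only; none of these names exist in tt-metal.
constexpr std::size_t kTileDim = 32;
using Tile = std::array<float, kTileDim * kTileDim>;

// Step 1: elementwise product, the role mul_tiles plays in the kernel above.
Tile mul_tiles_ref(const Tile& a, const Tile& b) {
    Tile out{};
    for (std::size_t i = 0; i < out.size(); ++i) {
        out[i] = a[i] * b[i];
    }
    return out;
}

// Step 2: sum every element into one scalar, assumed to match
// PoolType::SUM with ReduceDim::REDUCE_SCALAR and a scaler tile of ones.
float reduce_scalar_sum_ref(const Tile& t) {
    float acc = 0.0f;
    for (float v : t) {
        acc += v;
    }
    return acc;
}
```

Calling `reduce_scalar_sum_ref(mul_tiles_ref(a, b))` yields the dot product of the two tiles' 1024 elements. The temporary returned by `mul_tiles_ref` corresponds to the tile that the kernel round-trips through `cb_intermid0`.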

Describe the solution you'd like

I'd like an API that lets me reduce directly from the output of mul_tiles. Something like the following:

#define REDUCE_OP PoolType::SUM
#define REDUCE_DIM ReduceDim::REDUCE_SCALAR

// setup code.....
// cb_ones is a tile that is full of 1s

constexpr uint32_t intermid_reg = 0;
constexpr uint32_t dst_reg = 1;
tile_regs_acquire();
cb_wait_front(cb_in0, 1);
cb_wait_front(cb_in1, 1);
mul_tiles_init();
mul_tiles(cb_in0, cb_in1, 0, 0, intermid_reg);// +
cb_pop_front(cb_in0, 1);                      // |
cb_pop_front(cb_in1, 1);                      // |
// Directly chain the output from multiplication into reduce
//                                         vvvvvv+
reduce_init<true>(REDUCE_OP, REDUCE_DIM, intermid_reg, cb_ones);
reduce_tile(REDUCE_OP, REDUCE_DIM, intermid_reg, cb_ones, 0, 0, cb_out0);
tile_regs_commit();
tile_regs_wait();
pack_tile(0, cb_out0); // pack before releasing the registers
tile_regs_release();
cb_push_back(cb_out0, 1);

jliangTT commented 7 months ago

@davorchap, @jvasilje: this user came from the community. The request seems architectural in nature. Any feedback?

jliangTT commented 7 months ago

assigning to myself while triaging.

davorchap commented 7 months ago

looping in @ttmtrajkovic

marty1885 commented 7 months ago

> matmul in GS can efficiently handle b=4 (4 rows in a tile)

Is there an API to get GS to do batch=4? That sounds much better than what I can get from the matmul API (1/32 utilization) or from multiply-then-reduce (also very low utilization).

ttmtrajkovic commented 7 months ago

Support for loading from DST into SRC registers could be added, which would save the time spent spilling to and reloading from an intermediate buffer. However, the move from DST to SRC would have to be done at tile level or, in the case of Grayskull, at sub-tile level, because the SRC registers are not as big as DST. This API doesn't exist yet in tt-metal, and adding it is not a priority.

  • GEMV: using a native matmul in the FPU is likely to be faster (even at lower utilization) than a mul -> reduce chain.
  • matmul in GS can efficiently handle b=4 (4 rows in a tile).
  • matmul in WH can efficiently handle b=8 (8 rows in a tile).
  • Milos can provide more insight.

Using the SFPU for any type of matrix operation is slow. For GEMV, it would be more efficient to pad the vector to a 32x32 tile and use the matmul_tiles API. GS and WH can handle operands smaller than 32x32 in hardware: per cycle, GS multiplies 4x16 by 16x16 and WH multiplies 8x16 by 16x16. In that case, though, the bottleneck is moving data to the FPU, since there is no data reuse, so overall efficiency will be low. Even so, it would still be better than any other matrix-vector multiplication we could build. There is currently no support, and no plan, for operands smaller than 32x32.
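The padding approach can be sketched on the host as follows. This is a plain C++ reference model, not tt-metal code. The helper names are hypothetical, and placing the vector in row 0 of a zero-filled tile is an assumption made for illustration. `row_utilization` only expresses the fraction of tile rows that carry useful data, which is where the 1/32 figure mentioned earlier in the thread comes from.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Host-side reference model only; none of these names exist in tt-metal.
constexpr std::size_t kTileDim = 32;
using Row = std::array<float, kTileDim>;
using Tile = std::array<Row, kTileDim>;

// Place a 32-element vector in row 0 of a zero-filled 32x32 tile
// (assumed layout for this sketch).
Tile pad_vector_to_tile(const Row& v) {
    Tile t{};  // value-initialized, so every element starts at zero
    t[0] = v;
    return t;
}

// Plain 32x32 by 32x32 matrix product, standing in for a tile matmul.
Tile matmul_ref(const Tile& a, const Tile& b) {
    Tile c{};
    for (std::size_t i = 0; i < kTileDim; ++i) {
        for (std::size_t k = 0; k < kTileDim; ++k) {
            for (std::size_t j = 0; j < kTileDim; ++j) {
                c[i][j] += a[i][k] * b[k][j];
            }
        }
    }
    return c;
}

// Fraction of the tile's 32 rows that hold useful data.
constexpr float row_utilization(std::size_t useful_rows) {
    return static_cast<float>(useful_rows) / static_cast<float>(kTileDim);
}
```

With a single vector, only row 0 of the result is meaningful and `row_utilization(1)` is 1/32. Batching 4 or 8 vectors into the same tile raises that to 4/32 or 8/32, which matches the b=4 and b=8 row counts quoted for GS and WH.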

jliangTT commented 7 months ago

based on the current comment, will assign it as P3 and can bump if the discussion progresses.