tenstorrent / tt-metal

TT-NN operator library, and TT-Metalium low level kernel programming model.
Apache License 2.0

LLK for reshuffling tile rows into different locations of a new tile (with reduction) #9817

Closed: yan-zaretskiy closed this issue 1 month ago

yan-zaretskiy commented 2 months ago

This request is in the context of the embedding backward operation (https://github.com/tenstorrent/tt-metal/issues/6232). In this operation we iterate over tiles of the output gradient (from the forward operation) and propagate them to the weight gradient by extracting rows of the output gradient and adding them up at the locations specified by the index_ids tensor:

[Screenshot from 2024-06-28 illustrating the operation described above]
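
For reference, the math being asked for is a scatter-add over rows. A minimal host-side sketch in plain Python (not a tt-metal API; the function name and the names `out_grad` and `weight_grad` are hypothetical, and tiles/tensors are modeled as lists of rows):

```python
def embedding_backward(out_grad, index_ids, num_embeddings):
    """Host-side reference: weight_grad[index_ids[i]] += out_grad[i] for every row i."""
    hidden = len(out_grad[0])
    # Start from an all-zero weight gradient of shape [num_embeddings, hidden].
    weight_grad = [[0.0] * hidden for _ in range(num_embeddings)]
    for row, idx in zip(out_grad, index_ids):
        # Rows that share an index are accumulated (summed), not overwritten.
        for h in range(hidden):
            weight_grad[idx][h] += row[h]
    return weight_grad


# Three input rows; rows 0 and 2 both map to embedding 2.
print(embedding_backward([[1, 2], [3, 4], [5, 6]], [2, 0, 2], num_embeddings=3))
# -> [[3.0, 4.0], [0.0, 0.0], [6.0, 8.0]]
```

The summation on repeated indices is what makes this harder than a plain gather: two input rows can land on the same output row.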

Each core handles a vertical slice of the entire input/output (that is, we split on the hidden size). Each core then iterates over every tile of the input, extracts each row from it, and adds that row to some other tile in the output. To perform the addition, I need to be able to place the row at a different row position in a new tile. I'm asking for an LLK to help me do this.
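
To make the addressing concrete, a small sketch of where a row ends up and how the hidden dimension could be divided among cores. It assumes 32x32 tiles; `dest_location` and `core_hidden_tiles` are hypothetical helpers, and the particular even split across cores is an illustration, not necessarily how the op partitions work:

```python
TILE_H = 32  # assumed tile height (32x32 tiles)


def dest_location(index_id, tile_h=TILE_H):
    """Map an embedding index to (tile row of weight_grad, row offset inside that tile)."""
    return divmod(index_id, tile_h)


def core_hidden_tiles(num_hidden_tiles, num_cores, core_id):
    """Hypothetical split of hidden-dimension tile columns into contiguous per-core slices.

    When the split is uneven, the first `extra` cores get one additional tile column.
    """
    base, extra = divmod(num_hidden_tiles, num_cores)
    start = core_id * base + min(core_id, extra)
    count = base + (1 if core_id < extra else 0)
    return range(start, start + count)


print(dest_location(70))  # (2, 6): row 6 of tile row 2
print([list(core_hidden_tiles(10, 3, c)) for c in range(3)])
```

Because each core owns whole tile columns of the hidden dimension, cores never write to the same output tile, so the per-row reshuffle only has to be correct within one core.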

The first idea that came to my mind is very specific to the embedding backward op. It would accept a tile of the output gradient, a tile of index values, and a range of indices. It would then extract all tile rows whose indices fall within the requested range and reshuffle them (with summation) according to where they need to be placed in the weight gradient tensor. The point is to minimize the number of separate extractions of input rows destined for the same output tile. Here's an illustration, assuming 8x8 tiles for simplicity:

[Screenshot from 2024-06-26: illustration of the proposed reshuffle with 8x8 tiles]
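
A plain-Python sketch of what the embedding-specific LLK described above would compute, under two assumptions not spelled out in the request: each gradient row carries exactly one index, and the requested range `[lo, lo + tile_h)` covers exactly one output tile. The function and parameter names are hypothetical:

```python
def reshuffle_in_range(grad_tile, row_index_ids, lo, tile_h):
    """Sum every row of grad_tile whose index lies in [lo, lo + tile_h)
    into row (index - lo) of a new zero tile; rows outside the range are skipped."""
    width = len(grad_tile[0])
    out = [[0.0] * width for _ in range(tile_h)]
    for row, idx in zip(grad_tile, row_index_ids):
        if lo <= idx < lo + tile_h:
            dst = out[idx - lo]
            for c in range(width):
                dst[c] += row[c]
    return out


# 4x2 tile for brevity; output tile covers indices 4..7, index 1 belongs elsewhere.
grad = [[1, 1], [2, 2], [3, 3], [4, 4]]
print(reshuffle_in_range(grad, [5, 1, 6, 5], lo=4, tile_h=4))
# -> [[0.0, 0.0], [5.0, 5.0], [3.0, 3.0], [0.0, 0.0]]
```

The caller would invoke this once per output tile whose index range intersects the indices present in the input tile, then add the result to that output tile.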

The downside is that it's far too specific and probably useless for anything else. As an alternative, I was hoping to at least get an LLK that extracts any row from a tile and places it into a new tile at some other row index, so that I can then add it to the output tile manually. Could that be useful for other ops?
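
The more general fallback reduces to two steps: move one row into an otherwise-zero tile, then do an ordinary element-wise add. A sketch of those semantics with hypothetical names (`move_row`, `add_tiles`), using lists of rows as tiles:

```python
def move_row(src_tile, src_row, dst_row):
    """Return a new zero tile whose row dst_row is a copy of src_tile[src_row]."""
    tile_h, width = len(src_tile), len(src_tile[0])
    out = [[0.0] * width for _ in range(tile_h)]
    out[dst_row] = list(src_tile[src_row])
    return out


def add_tiles(a, b):
    """Element-wise tile addition, as a regular binary add would perform."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


acc = [[10, 10], [10, 10]]
src = [[1, 2], [3, 4]]
# Add row 1 of src into row 0 of the accumulator tile.
print(add_tiles(acc, move_row(src, src_row=1, dst_row=0)))
```

The zero padding is what lets an unmodified tile add do the accumulation; the cost is one full-tile add per moved row, which is the overhead the embedding-specific variant tries to avoid.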

@rtawfik01 @ttmtrajkovic @TT-BrianLiu

davorchap commented 2 months ago

This sounds great. Thanks @yan-zaretskiy

The generalized feature you describe could be useful for reshuffling rows in general.

We could provide, for each row in the original tile, its destination row index in the new tile?
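
One way to read this suggestion is a per-row destination map, with rows that collide being summed. A sketch of those semantics; the function name and the `DROP` sentinel for "don't write this row" are hypothetical additions, not part of the proposal:

```python
DROP = -1  # hypothetical sentinel: this source row is not written anywhere


def reshuffle_rows(src_tile, dest_idx):
    """Row i of src_tile goes to row dest_idx[i] of a new zero tile.

    Rows that land on the same destination are summed, so one primitive covers
    both a pure permutation and the reduction needed by embedding backward.
    """
    width = len(src_tile[0])
    out = [[0.0] * width for _ in range(len(src_tile))]
    for row, d in zip(src_tile, dest_idx):
        if d == DROP:
            continue
        for c in range(width):
            out[d][c] += row[c]
    return out


print(reshuffle_rows([[1, 1], [2, 2], [3, 3]], [2, DROP, 2]))
# -> [[0.0, 0.0], [0.0, 0.0], [4.0, 4.0]]
```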

yan-zaretskiy commented 2 months ago

@davorchap If I'm allowed to dream, then I can think of the following best case scenario:

cglagovichTT commented 2 months ago

It sounds like this sort of row-shuffling support could be useful in update_cache. Update cache needs to take an input tile, a cache tile, batch_idx, and seq_idx, and do cache_tile[seq_idx, :] = input_tile[batch_idx, :]
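
The update_cache case is a single-row move with overwrite rather than accumulation. A sketch on list-of-rows tiles (the function name is hypothetical):

```python
def update_cache_tile(cache_tile, input_tile, batch_idx, seq_idx):
    """cache_tile[seq_idx, :] = input_tile[batch_idx, :]

    Unlike embedding backward there is no reduction: the destination row is
    overwritten. A copy is returned so the caller's cache tile is left untouched.
    """
    updated = [list(r) for r in cache_tile]
    updated[seq_idx] = list(input_tile[batch_idx])
    return updated


cache = [[0, 0], [0, 0], [0, 0]]
print(update_cache_tile(cache, [[7, 8], [9, 10]], batch_idx=1, seq_idx=2))
# -> [[0, 0], [0, 0], [9, 10]]
```

So a row-reshuffle primitive would need a mode that replaces the destination row instead of adding to it.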

cglagovichTT commented 2 months ago

@yieldthought could you use an op like this for shuffling tokens between experts?

razorback3 commented 2 months ago

FYI. @hschoi4448

ttmtrajkovic commented 2 months ago

adding @rdjogoTT

yieldthought commented 2 months ago

I think we could use this for mixture-of-experts models as we currently implement them (each device has its own expert): a key step is to compute a weighted sum of the experts' outputs, reshuffling their output rows back into the correct positions as specified by the indices from a top-k operation.
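
The combine step described here is a weighted scatter-add. A sketch assuming each expert output row comes with the token index it belongs to and a gating weight, both taken from top-k; all names are hypothetical:

```python
def combine_expert_outputs(expert_out, token_idx, gate_weight, num_tokens):
    """out[token_idx[k]] += gate_weight[k] * expert_out[k] for every expert output row k."""
    hidden = len(expert_out[0])
    out = [[0.0] * hidden for _ in range(num_tokens)]
    for row, t, w in zip(expert_out, token_idx, gate_weight):
        # Several experts may have processed the same token; their outputs are summed.
        for h in range(hidden):
            out[t][h] += w * row[h]
    return out


# Token 1 was routed to two experts (weights 0.5 and 0.25), token 0 to one.
print(combine_expert_outputs([[2, 4], [6, 8], [10, 12]], [1, 0, 1], [0.5, 1.0, 0.25], 2))
# -> [[6.0, 8.0], [3.5, 5.0]]
```

Apart from the per-row scaling, this is the same row-reshuffle-with-summation pattern as the embedding backward case.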

TT-BrianLiu commented 1 month ago

The LLK changes have been merged in https://github.com/tenstorrent/tt-metal/pull/10495.

The reshuffle API should be usable now, but only on Wormhole.