tenstorrent / tt-metal

TT-NN operator library, and TT-Metalium low level kernel programming model.
https://docs.tenstorrent.com/ttnn/latest/index.html
Apache License 2.0

[Llama2] Perf burndown #7408

Open cglagovichTT opened 7 months ago

cglagovichTT commented 7 months ago

This issue tracks the open issues the model team must solve in order to hit Llama2 perf targets.

Decode 128

We have a new perf target of 20 tok/s at seqlen = 128. This issue lists the problems we will have to overcome to hit that target.

| milestone | single layer latency (ms) | full model latency (ms) | t/s/u |
| --- | --- | --- | --- |
| current performance | 0.753 | 60.24 | 16.6 |
| target @ 100% | 0.625 | 50 | 20.0 |
| projected after burndown | 0.623 | 49.84 | 20.06 |
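
The relationship between the columns above can be sketched in a few lines of Python. This is only an illustration: the layer count of 80 is an assumption inferred from the table (60.24 ms / 0.753 ms), and the helper names are hypothetical. It also assumes the full model latency is just the per-layer latency times the number of layers, which is what the table's numbers imply.

```python
# Assumption: 80 decoder layers, inferred from 60.24 ms / 0.753 ms in the table.
NUM_LAYERS = 80


def full_model_latency_ms(layer_ms, num_layers=NUM_LAYERS):
    """Full-model decode latency, assuming it is dominated by the decoder layers."""
    return layer_ms * num_layers


def tokens_per_second_per_user(layer_ms, num_layers=NUM_LAYERS):
    """One decode step yields one token per user, so t/s/u = 1000 ms / latency."""
    return 1000.0 / full_model_latency_ms(layer_ms, num_layers)


# Per-layer latencies (ms) taken from the table above.
milestones = {
    "current performance": 0.753,
    "target @ 100%": 0.625,
    "projected after burndown": 0.623,
}

for name, layer_ms in milestones.items():
    print(
        f"{name}: {full_model_latency_ms(layer_ms):.2f} ms, "
        f"{tokens_per_second_per_user(layer_ms):.2f} t/s/u"
    )
```

Under these assumptions, hitting 20 t/s/u requires each layer to take 0.625 ms or less.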

Issues

Prefill 2k

We focus on 2k prefill for the same reason we focus on 2k decode. It is the tougher regime and needs more attention than smaller seqlen prefills.

Issues:

| milestone | single layer latency (ms) | full model latency (ms) | t/s |
| --- | --- | --- | --- |
| current performance | 20.16 | 1612.8 | 1270.5 |
| target @50% | 6.84 | 547 | 3745 |
| projected after burndown | 7.24 | 579.2 | 3537.1 |

cglagovichTT commented 7 months ago

fyi @uaydonat @davorchap

uaydonat commented 7 months ago

We discussed this in our team sync today.

On the majority of the tasks, we are blocked on support from the metal team.

We will try to take on the following tasks as time/priorities permit:

cglagovichTT commented 7 months ago

From discussions with @SeanNijjar: in order to enable him to overlap AllGather and matmul, it will be the model team's responsibility to hoist LayerNorm above AllGather. This will involve implementing a distributed LayerNorm, in which each chip computes statistics locally, we communicate statistics between all chips, and then each chip executes the rest of LayerNorm.
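
A minimal NumPy sketch of the three steps described above (local statistics, exchange of statistics, local normalization). This is not the tt-metal implementation: the chips are simulated as a Python list of shards, the "communication" is a plain sum over that list, and the 8-way split and all function names are hypothetical.

```python
import numpy as np


def layernorm_reference(x, gamma, beta, eps=1e-5):
    """Single-device LayerNorm over the last (hidden) dimension."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gamma + beta


def distributed_layernorm(x_shards, gamma_shards, beta_shards, eps=1e-5):
    """LayerNorm where each simulated chip holds one slice of the hidden dim."""
    # Step 1: each chip computes statistics on its local shard only.
    local_sums = [s.sum(axis=-1, keepdims=True) for s in x_shards]
    local_sq_sums = [(s * s).sum(axis=-1, keepdims=True) for s in x_shards]
    local_counts = [s.shape[-1] for s in x_shards]

    # Step 2: statistics are exchanged between chips. Only these small
    # per-row tensors need to travel, not the activations themselves.
    total_sum = sum(local_sums)
    total_sq_sum = sum(local_sq_sums)
    n = sum(local_counts)
    mean = total_sum / n
    var = total_sq_sum / n - mean**2  # Var[x] = E[x^2] - E[x]^2

    # Step 3: each chip finishes LayerNorm on its own shard.
    return [
        (s - mean) / np.sqrt(var + eps) * g + b
        for s, g, b in zip(x_shards, gamma_shards, beta_shards)
    ]


# Hypothetical shapes: 32 rows, hidden size 8192, split across 8 simulated chips.
rng = np.random.default_rng(0)
x = rng.standard_normal((32, 8192))
gamma = rng.standard_normal(8192)
beta = rng.standard_normal(8192)

out_shards = distributed_layernorm(
    np.split(x, 8, axis=-1), np.split(gamma, 8), np.split(beta, 8)
)
print(np.allclose(np.concatenate(out_shards, axis=-1),
                  layernorm_reference(x, gamma, beta)))
```

Because each chip's output is already normalized, the AllGather can run after the norm on the normalized shards, which is what hoisting LayerNorm above AllGather means here.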

cglagovichTT commented 7 months ago

decode 128 latest perf https://docs.google.com/spreadsheets/d/1I2uxNvaQxQHp6niyFmXvXYbqZgUyfEWcUzPE-TOBUEU/edit?usp=sharing

cglagovichTT commented 7 months ago
image
cglagovichTT commented 7 months ago

This is the projected perf after burndown at 800MHz.

image
cglagovichTT commented 7 months ago

We profiled FF2 with reduce-scatter shapes (32 x 4k @ 4k x 8k) and the perf is about the same, so that doesn't give us a free speedup.

cglagovichTT commented 7 months ago

Similarly, fusing FF1 and FF3 did not lead to gains. Separately, they take 120 us and 111 us (231 us combined). Fused, they take 237 us.

cglagovichTT commented 6 months ago

I was able to get LayerNorm from 24 us to 18.5 us by increasing granularity of core-to-core reads. I don't see a path forward for much more optimization, but I'll ask the original writer.

cglagovichTT commented 6 months ago

Decode TM plan of record:

We are doing this because update_cache parallelizes over the T dim. If we put batch in T, we can parallelize update_cache across 32 cores, which the Mixtral team has seen give a 66% speedup, saving 40 us. In addition, we should save some time on the pad -> transpose and transpose -> unpad combinations by folding them into the DM ops. I estimate this could save another ~30 us.

cglagovichTT commented 6 months ago

Decode perf measured at 1GHz: a 200 us improvement over 800MHz! The numbers can be found in the same spreadsheet.

image
cglagovichTT commented 6 months ago

Note that the large matmuls are running at 171 to 183 GB/s. The large AllGather appears to be running at 8.26 GB/s (multiply the in0BW column by 7 for AllGathers to get the true number). The small AllGathers attain even lower link utilization. AllGather did not benefit from 1GHz, so the estimates for potential gains from bidir 2-link still hold. In fact, those estimates are very pessimistic, since the current unidir 1-link AllGather is operating at only 30% link util anyway.

cglagovichTT commented 5 months ago

On xuncai/llama-perf, we have the newest decode perf.

image
cglagovichTT commented 5 months ago

Latest decode perf: 752us per layer

image

https://docs.google.com/spreadsheets/d/1I2uxNvaQxQHp6niyFmXvXYbqZgUyfEWcUzPE-TOBUEU/edit#gid=2062002106&range=A1