fyi @uaydonat @davorchap
We discussed this in our team sync today.
On the majority of the tasks, we are blocked on support from the metal team.
We will try to take on the following tasks as time/priorities permit:
From discussions with @SeanNijjar: in order to enable him to overlap AllGather with matmul, it will be the model team's responsibility to hoist LayerNorm above the AllGather. This requires implementing a distributed LayerNorm, in which each chip computes statistics locally, the statistics are communicated between all chips, and then each chip executes the rest of LayerNorm on its own shard.
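For reference, here is a minimal numpy sketch of the math (the function name and the simulated stats exchange are illustrative, not tt-metal APIs): each chip reduces its shard to a sum and a sum of squares, those two values per row are exchanged instead of the full activations, and the normalization finishes locally.

```python
import numpy as np

def distributed_layernorm(x_shards, gamma_shards, beta_shards, eps=1e-5):
    """Illustrative distributed LayerNorm over a hidden dim sharded across chips.

    x_shards: one [batch, hidden/num_chips] array per chip.
    """
    hidden = sum(s.shape[-1] for s in x_shards)

    # Step 1: local statistics per chip (no cross-chip traffic yet).
    local_sums = [s.sum(axis=-1, keepdims=True) for s in x_shards]
    local_sqsums = [(s * s).sum(axis=-1, keepdims=True) for s in x_shards]

    # Step 2: exchange statistics (a tiny AllGather: two values per row per chip,
    # simulated here with a Python sum).
    total_sum = sum(local_sums)
    total_sqsum = sum(local_sqsums)
    mean = total_sum / hidden
    var = total_sqsum / hidden - mean ** 2

    # Step 3: each chip normalizes and scales only its own shard.
    inv_std = 1.0 / np.sqrt(var + eps)
    return [(s - mean) * inv_std * g + b
            for s, g, b in zip(x_shards, gamma_shards, beta_shards)]

# Sanity check against a single-chip LayerNorm on the full tensor.
x = np.random.randn(32, 8192)
gamma, beta = np.ones(8192), np.zeros(8192)
out = np.concatenate(
    distributed_layernorm(np.split(x, 8, axis=-1), np.split(gamma, 8), np.split(beta, 8)),
    axis=-1)
ref = (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + 1e-5) * gamma + beta
assert np.allclose(out, ref)
```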
This is the projected perf after burndown at 800MHz.
We profiled FF2 with reduce-scatter shapes (32 x 4k @ 4k x 8k) and the perf is about the same, so that doesn't give us a free speedup.
Similarly, fusing FF1 and FF3 did not lead to gains. Separately they take 120 us and 111 us; fused, they take 237 us.
I was able to get LayerNorm from 24 us down to 18.5 us by increasing the granularity of core-to-core reads. I don't see a path to much more optimization, but I'll check with the op's original author.
Decode TM plan of record:
- Change tt::tt_metal::NlpCreateHeads such that it outputs Q, K, and V padded and transposed y-z, sharded on 32 cores.
- Change tt::tt_metal::NlpConcatHeads such that it takes transposed, padded input.

We are doing this because update_cache will parallelize over the T dim. If we put batch in T, then we can parallelize update_cache onto 32 cores, which the Mixtral team has seen gives a 66% speedup, saving 40 us (see the sketch below). In addition, we should save some time on the pad -> transpose and transpose -> unpad combinations by folding them into the DM ops. I estimate we could save another ~30 us from this.
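For intuition, a minimal numpy sketch of the intended layout change. The shapes, the meaning of the y-z transpose, and the mapping of batch onto the T dim are assumptions for illustration only, not the actual NlpCreateHeads/update_cache semantics.

```python
import numpy as np

# Hypothetical decode-time K tensor as NlpCreateHeads might produce it today
# (shapes are illustrative): [seq=1, batch=32, n_kv_heads=1, head_dim=128].
k = np.random.randn(1, 32, 1, 128)

# Pad the heads dim up to a full 32-wide tile so downstream ops stay tile-aligned.
k_padded = np.pad(k, ((0, 0), (0, 0), (0, 32 - k.shape[2]), (0, 0)))

# "Transpose y-z" = swap dims 1 and 2. After the swap, the 32 users (batch) sit
# on the dim we assume update_cache parallelizes over, so the cache update can
# be spread across 32 cores (one user per core) instead of serializing.
k_t = np.swapaxes(k_padded, 1, 2)   # [1, padded heads, batch, head_dim]

# Each user's data is now contiguous along the batch-in-T axis.
assert np.array_equal(k_t[0, :, 7, :], k_padded[0, 7, :, :])
```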
Decode perf measured at 1GHz: a 200us improvement over 800MHz! The numbers can be found in the same spreadsheet.
Note that the large matmuls are running at 171 to 183 GB/s.
It appears that the large AllGather is running at 8.26 GB/s
(multiply the in0BW column by 7 for AllGathers to get the true number). The small AllGathers attain even lower link utilization. AllGather did not benefit from 1GHz, so the estimates for potential gains from bidirectional 2-link still hold. In fact, those estimates are quite pessimistic, since the current unidirectional 1-link AllGather is only operating at 30% link utilization anyway.
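For bookkeeping, a tiny helper matching the arithmetic above; the ring size default and the link-capacity parameter are placeholders to fill in, not measured values.

```python
def allgather_true_bw(in0_bw_gb_s: float, ring_size: int = 8) -> float:
    # Per the note above, the profiler's in0BW column only counts the local shard,
    # so scale by the (ring_size - 1) remote shards to get the true bandwidth.
    return in0_bw_gb_s * (ring_size - 1)

def link_util(true_bw_gb_s: float, link_capacity_gb_s: float) -> float:
    # link_capacity_gb_s is a placeholder; substitute the real per-link figure.
    return true_bw_gb_s / link_capacity_gb_s

# Example: for an AllGather whose in0BW reads X GB/s on an 8-chip ring,
#   util = link_util(allgather_true_bw(X), LINK_CAPACITY_GB_S)
```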
On the xuncai/llama-perf branch, we have the newest decode perf.
Latest decode perf: 752us per layer
This issue tracks the open problems the model team must solve in order to hit the Llama2 perf targets.
Decode 128
We have a new perf target of 20 tok/s at seqlen = 128. This issue lists the problems we will have to overcome to hit that target.
Issues
- create_nlp_heads (Kevin/Jack)
- concat_nlp_heads (Kevin/Jack)

Prefill 2k
We focus on 2k prefill for the same reason we focus on 2k decode. It is the tougher regime and needs more attention than smaller seqlen prefills.
Issues: