tenstorrent / tt-metal

:metal: TT-NN operator library and TT-Metalium low-level kernel programming model.
Apache License 2.0

TG Llama3 MLP CCL scheme #11064

Open mikevin920 opened 2 months ago

mikevin920 commented 2 months ago

Below are four ways to run the MLP for TG Llama3, ranked from most optimized to least.

  1. ff1 fused with all reduce + ff2 (Requires ttnn.allreduce() to be fused with matmul)
  2. ff1 fused with allgather + reduce op + ff2 (Requires ttnn.allgather() to be fused with matmul)
  3. ff1 + line_reduce_scatter + line_all_gather + ff2 (Requires ttnn.line_reduce_scatter() to be implemented)
  4. ff1 + line_all_gather + fast local reduction + ff2 (Current implementation; sketched below)
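
A minimal host-side NumPy sketch of option 4's data flow, simulating what each device computes. The 4-device split and the shapes are illustrative assumptions, not the real TG mesh configuration, and the real ttnn ops/sharding are not modeled:

```python
import numpy as np

# Option 4: ff1 -> line_all_gather -> fast local reduction -> ff2.
# num_devices and shapes are illustrative assumptions.
num_devices, s, d, h = 4, 32, 8192, 3584

x = np.random.randn(s, d)
w1 = np.random.randn(d, h)
w2 = np.random.randn(h, d)

# K-dim sharding: each device holds a slice of the activation and of ff1's rows.
x_shards = np.split(x, num_devices, axis=1)
w1_shards = np.split(w1, num_devices, axis=0)

# ff1: every device computes a partial [s, h] result from its shard.
partials = [xs @ ws for xs, ws in zip(x_shards, w1_shards)]

# line_all_gather: afterwards, every device holds all of the partials.
gathered_on_each_device = [partials for _ in range(num_devices)]

# Fast local reduction: each device sums the gathered partials itself
# (options 1-3 would instead reduce on the fabric, before or during the gather).
reduced = [np.sum(np.stack(parts), axis=0) for parts in gathered_on_each_device]

# ff2 consumes the fully reduced activation.
outputs = [r @ w2 for r in reduced]
assert np.allclose(outputs[0], (x @ w1) @ w2)
```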
cglagovichTT commented 2 months ago

Which matmul variants and shapes need an AllReduce fused on the output? All shapes below are per-device matmul inputs.

Decode

matmul_1d and/or matmul dram sharded

  1. fused_qkv: [32, 2048] @ [2048, 1280]
  2. dense: [32, 1024] @ [1024, 2048]
  3. ff1/ff3: [32, 2048] @ [2048, 3584]
  4. ff2: [32, 3584] @ [3584, 2048]
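
For context on why the output needs an AllReduce at all: with the K dimension sharded across devices, each per-device matmul produces only a partial sum of the full result. A small NumPy check using the dense shapes above; the 4-way split (full K = 4 * 1024) is an illustrative assumption:

```python
import numpy as np

num_devices, m, k_per_dev, n = 4, 32, 1024, 2048

x = np.random.randn(m, k_per_dev * num_devices)
w = np.random.randn(k_per_dev * num_devices, n)

x_shards = np.split(x, num_devices, axis=1)   # each [32, 1024]
w_shards = np.split(w, num_devices, axis=0)   # each [1024, 2048]

partials = [xs @ ws for xs, ws in zip(x_shards, w_shards)]  # per-device [32, 2048]

# Summing the partials (what AllReduce does across devices) recovers the result.
assert np.allclose(sum(partials), x @ w)
```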

Prefill

matmul_2d - All of these activation shapes may be reshaped to push sequence into batch so the outputs fit in L1, like [s, d] -> [b, s/b, d] (see the sketch after this list).

  1. fused_qkv: [128k, 2048] @ [2048, 1280]
  2. dense: [128k, 1024] @ [1024, 2048]
  3. ff1/ff3: [128k, 2048] @ [2048, 3584]
  4. ff2: [128k, 3584] @ [3584, 2048]
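
A small sketch of that reshape. The batch factor b = 32 is an illustrative assumption (it must divide s evenly), and a shorter s stands in for the 128k prefill length to keep the example lightweight:

```python
import numpy as np

# [s, d] -> [b, s/b, d]: push sequence into batch so per-slice outputs fit in L1.
s, d, b = 8192, 2048, 32   # real prefill uses s = 128k; b = 32 is an assumption

act = np.random.randn(s, d).astype(np.float32)
act_batched = act.reshape(b, s // b, d)   # here: [32, 256, 2048]

assert act_batched.shape == (b, s // b, d)
```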

fyi @SeanNijjar