Which matmul variants and shapes need the AllReduce fused on the output? All shapes listed below are per-device matmul inputs.
matmul_1d and/or DRAM-sharded matmul (a reference sketch of the matmul + AllReduce pattern follows this list)
[32, 2048] @ [2048, 1280]
[32, 1024] @ [1024, 2048]
[32, 2048] @ [2048, 3584]
[32, 3584] @ [3584, 2048]
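For reference, a minimal sketch of the matmul + AllReduce pattern in question, assuming the inner (K) dimension is sharded across devices so each per-device matmul produces only a partial sum. This uses torch.distributed purely as a stand-in for the on-device collective and is not the ttnn API; the function name and shapes in the comments are illustrative.

```python
import torch
import torch.distributed as dist  # stand-in for the on-device collective


def matmul_allreduce(act, w_shard, group=None):
    """Per-device matmul whose output needs an AllReduce.

    Example per-device shapes from the list above: act [32, 2048],
    w_shard [2048, 1280]. Because K (=2048 here) is a shard of the full
    inner dimension, each device computes only a partial sum, and the
    partials must be summed across devices before the next op.
    Requires an initialized process group to actually run.
    """
    partial = act @ w_shard                                 # [32, 1280] partial result
    dist.all_reduce(partial, op=dist.ReduceOp.SUM, group=group)
    return partial                                          # fully reduced [32, 1280]
```

Fusing means doing that reduction as part of the matmul output path instead of as a separate op afterwards.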
matmul_2d - all of these activation shapes may be reshaped to push sequence into batch so the outputs fit in L1, i.e. [s, d] -> [b, s/b, d] (a sketch of this reshape follows the list)
[128k, 2048] @ [2048, 1280]
[128k, 1024] @ [1024, 2048]
[128k, 2048] @ [2048, 3584]
[128k, 3584] @ [3584, 2048]
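A quick sketch of that reshape, checking that [s, d] -> [b, s/b, d] leaves the matmul result unchanged while shrinking the per-slice M dimension. Here s is scaled down from 128k so the check runs quickly, and the batch factor b is an illustrative choice.

```python
import torch

s, d_in, d_out = 4096, 2048, 1280   # s scaled down from the 128k prefill case
b = 32                              # illustrative batch factor; any b dividing s works

act = torch.randn(s, d_in)
w = torch.randn(d_in, d_out)

# [s, d] -> [b, s/b, d]: same data, but each of the b slices has a smaller
# M dimension, so the matmul_2d output per slice is small enough for L1.
out_3d = act.reshape(b, s // b, d_in) @ w        # [b, s/b, d_out]
out = out_3d.reshape(s, d_out)

assert torch.allclose(out, act @ w, atol=1e-3)   # same result up to layout
```

The AllReduce on the output is still needed per slice; the reshape only changes how the output is tiled across L1.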
fyi @SeanNijjar
Below are four ways to run the MLP for TG llama, ranked from most to least optimized.