tenstorrent / tt-metal

:metal: TT-NN operator library, and TT-Metalium low level kernel programming model.
Apache License 2.0

Falcon7b prefill MLP perf optimizations #9723

Open pavlepopovic opened 4 days ago

pavlepopovic commented 4 days ago

Part of: #8349

The goal is to optimise the MLP matmuls (FF1 & FF2) at sequence lengths 128, 1024, and 2048. The max we can hit is calculated with the following formula, which is the same for FF1 & FF2:

(M * N * K / 64) * 16

M, N, and K are the matmul dimensions in tiles: M is seq_length / 32, and N and K are 144 and 576 (K = 144, N = 576 for FF1; K = 576, N = 144 for FF2). 64 is the number of cores available on WH, and 16 is a somewhat debatable number representing the number of cycles needed for a single tile computation (using LoFi). I'm aware of this number varying between 16 and 22 depending on various factors. Using 16 cycles (and counting one cycle as 1 ns), these are the numbers we'll use as our max:

- 128 sequence length: 83k ns
- 1024 sequence length: 663k ns
- 2048 sequence length: 1.327 mil ns
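The formula can be checked with a few lines of Python. This is a sketch: the function name and parameters are illustrative, and it assumes a 1 GHz clock so that cycles map one-to-one onto ns, which is what the max numbers above imply.

```python
def theoretical_max_ns(seq_len, n_tiles=144, k_tiles=576, num_cores=64,
                       cycles_per_tile=16, clock_ghz=1.0):
    """Lower bound on matmul runtime: (M * N * K / num_cores) * cycles_per_tile.

    M, N, K are counted in 32x32 tiles. cycles_per_tile=16 is the LoFi estimate
    from this issue; clock_ghz=1.0 is an assumption (1 cycle ~= 1 ns).
    """
    m_tiles = seq_len // 32
    tile_matmuls = m_tiles * n_tiles * k_tiles  # number of tile-level multiplies
    cycles = tile_matmuls / num_cores * cycles_per_tile
    return cycles / clock_ghz


for seq_len in (128, 1024, 2048):
    print(f"seq_len={seq_len}: ~{theoretical_max_ns(seq_len):,.0f} ns")
```

This prints 82,944, 663,552 and 1,327,104 ns, matching the 83k / 663k / 1.327 mil figures. Swapping N and K gives the same result, which is why FF1 and FF2 share one max. Passing cycles_per_tile=22 gives the pessimistic end of the 16 to 22 range (912,384 ns at 1K).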

Current state on main

Here is the perf for the 3 important sequence lengths:

- 128: FF1 = 518k ns, FF2 = 448k ns
- 1024: FF1 = 2.64 mil ns, FF2 = 2.38 mil ns
- 2048: FF1 = 5.28 mil ns, FF2 = 4.76 mil ns

All of these matmuls use (1, 1) subblocks (to avoid di/dt issues), HiFi2 math fidelity, and bfp8 weights, and are 2D block-sharded matmuls. When running the Falcon prefill model at these sequence lengths, we get an output, K cache, and V cache PCC of 0.99 in all variants (single chip and multi chip).
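PCC here is the Pearson correlation coefficient between the device output and a reference output. A minimal pure-Python sketch of the metric, for readers unfamiliar with it; the function name `pcc` and the sample data are hypothetical, and this is not the helper the tt-metal tests actually call.

```python
import math


def pcc(golden, actual):
    """Pearson correlation coefficient between two equally sized flat sequences.

    Sketch only: real tests compare flattened tensors. A constant input has zero
    variance, so the coefficient is undefined and is rejected explicitly.
    """
    if len(golden) != len(actual) or not golden:
        raise ValueError("inputs must be non-empty and of equal length")
    n = len(golden)
    mean_g = sum(golden) / n
    mean_a = sum(actual) / n
    cov = sum((g - mean_g) * (a - mean_a) for g, a in zip(golden, actual))
    var_g = sum((g - mean_g) ** 2 for g in golden)
    var_a = sum((a - mean_a) ** 2 for a in actual)
    if var_g == 0 or var_a == 0:
        raise ValueError("PCC is undefined for constant inputs")
    return cov / math.sqrt(var_g * var_a)


reference = [0.10, -0.40, 0.25, 0.80, -0.05]
device_out = [0.11, -0.38, 0.24, 0.79, -0.07]  # hypothetical, slightly noisy output
print(round(pcc(reference, device_out), 4))
```

A PCC of 1.0 means the outputs are perfectly linearly correlated; the 0.99 vs 0.97 differences discussed in this issue are small drops on that scale.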

I've run some experiments to try to speed up these matmuls, and here's what I've found. (Note: seq_len 128 uses a bad matmul variant that none of the experiments helped, and it needs fixing ASAP. Also, 2048 is always 2x of 1024, so when talking about perf I'll focus on 1K only.)

Experiments

I've run experiments with increasing the subblock sizes (on an N150 machine, where we don't see di/dt issues), reducing math fidelity, and using bfp4 weights.

Using proper subblocks

When we set proper subblocks, we immediately see perf gains on the matmuls. Here is the state of main at 1K seq_length with proper subblocks set: FF1 = 1.62 mil ns, FF2 = 1.50 mil ns (compared to 2.64 and 2.38 mil ns with (1, 1) subblocks).
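To illustrate what "proper subblocks" means, the sketch below enumerates output subblock shapes (out_subblock_h, out_subblock_w) for a per-core output block. It assumes two constraints: the subblock must evenly divide per_core_M and per_core_N, and h * w must fit in an 8-tile destination register. These constraints and the helper names are assumptions for illustration, not the exact checks in the tt-metal matmul op. The per-core shapes (4, 72) and (4, 18) are the 1K FF1 and FF2 values discussed later in this issue.

```python
from itertools import product

MAX_DST_TILES = 8  # assumption: 8 output tiles fit in DST at once


def valid_subblocks(per_core_m, per_core_n, max_tiles=MAX_DST_TILES):
    """All (h, w) subblock shapes that tile the per-core block and fit in DST."""
    return [
        (h, w)
        for h, w in product(range(1, per_core_m + 1), range(1, per_core_n + 1))
        if per_core_m % h == 0 and per_core_n % w == 0 and h * w <= max_tiles
    ]


def largest_subblock(per_core_m, per_core_n):
    """Pick the subblock with the most tiles; ties go to the wider shape (arbitrary)."""
    return max(valid_subblocks(per_core_m, per_core_n),
               key=lambda hw: (hw[0] * hw[1], hw[1]))


for name, per_core_m, per_core_n in (("FF1", 4, 72), ("FF2", 4, 18)):
    print(name, "(1, 1) ->", largest_subblock(per_core_m, per_core_n))
```

Under these assumptions FF1 gets (1, 8) and FF2 gets (4, 2), i.e. 8 output tiles per DST round trip instead of 1, which is one way to see why (1, 1) leaves so much perf on the table.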

One important thing here: while running various experiments, I noticed that none of them (fidelities, bfp4 weights) produced any perf gain as long as the subblocks were (1, 1). I think this is because in that case the matmuls are TRISC bound, so none of these settings affect matmul perf. The experiments below therefore assume that subblocks are properly set.

Using LoFi instead of HiFi2

If we use LoFi, the prefill PCC stays the same on all sequence lengths, on both single and multi chip variants: it is 0.99 in all cases. Perf improves immediately: FF1 = 1.33 mil ns and FF2 = 1.20 mil ns (down from 1.62 and 1.50 mil ns with HiFi2).
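A quick sanity check on the LoFi gain, as a sketch. It assumes the usual model where math fidelity sets the number of multiply passes per tile (LoFi = 1, HiFi2 = 2, HiFi3 = 3, HiFi4 = 4), so a purely math-bound matmul would gain about 2x going from HiFi2 to LoFi. The measured ratios come from the numbers above; the helper name is hypothetical.

```python
# Assumption: number of FPU passes per tile for each math fidelity level.
FIDELITY_PHASES = {"LoFi": 1, "HiFi2": 2, "HiFi3": 3, "HiFi4": 4}


def ideal_speedup(from_fidelity, to_fidelity):
    """Speedup if the matmul were limited purely by FPU passes."""
    return FIDELITY_PHASES[from_fidelity] / FIDELITY_PHASES[to_fidelity]


# Measured 1K seq_len numbers (ns) from this issue: (HiFi2, LoFi), bfp8 weights.
measured = {"FF1": (1.62e6, 1.33e6), "FF2": (1.50e6, 1.20e6)}

for name, (hifi2_ns, lofi_ns) in measured.items():
    print(f"{name}: measured {hifi2_ns / lofi_ns:.2f}x, "
          f"ideal {ideal_speedup('HiFi2', 'LoFi'):.1f}x")
```

The measured 1.22x (FF1) and 1.25x (FF2) fall well short of 2x, which is consistent with the observation below that, once LoFi is on, the matmuls are no longer purely compute bound and the weight data format starts to matter.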

Using BFP4 weights

When we use BFP4 weights on both FF1 and FF2, there is a PCC drop: output and V cache PCC can drop to 0.97, while K cache PCC remains 0.99. If we use BFP4 on only one of the MLP matmuls (FF1 or FF2, it does not matter which), PCC also drops, but only to 0.98. Using bfp4 instead of bfp8 (with HiFi2 on) did not bring much perf benefit: FF1 = 1.56 mil ns, FF2 = 1.47 mil ns (down from 1.62 and 1.50 mil ns with bfp8). This is because with HiFi2 on, the matmul is compute bound, so reducing the amount of data doesn't help much.
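To quantify how much data BFP4 saves, here is a sketch of per-tile and total weight sizes. It assumes a block-float tile layout of 1024 mantissas plus 64 shared exponent bytes per 32x32 tile (one 8-bit exponent per 16 values), which gives 1088 bytes for bfp8 and 576 bytes for bfp4; the helper names are hypothetical.

```python
TILE_ELEMENTS = 32 * 32
SHARED_EXPONENT_BYTES = 64  # assumption: one 8-bit exponent per 16 values


def bfp_tile_bytes(mantissa_bits):
    """Bytes for one 32x32 block-float tile with the given mantissa width."""
    return TILE_ELEMENTS * mantissa_bits // 8 + SHARED_EXPONENT_BYTES


def weight_mib(k_tiles, n_tiles, mantissa_bits):
    """Size of a [k_tiles, n_tiles] weight matrix in MiB."""
    return k_tiles * n_tiles * bfp_tile_bytes(mantissa_bits) / 2**20


# FF1 weights are [144, 576] tiles and FF2 weights are [576, 144]: same size.
for bits in (8, 4):
    print(f"bfp{bits}: {bfp_tile_bytes(bits)} B/tile, "
          f"{weight_mib(144, 576, bits):.1f} MiB per MLP weight")
```

Under this layout BFP4 moves about 53% of the bytes of BFP8 (576 / 1088), a saving that only pays off once the matmul is no longer math bound.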

BFP4 + LoFi

If we now use BFP4 + LoFi, interestingly, the PCC is the same as with BFP4 alone (output and V cache can drop to 0.97 or 0.98, depending on whether BFP4 is set on both matmuls or just one). Now the matmuls are no longer compute bound, and using BFP4 helps immediately: FF1 = 1.02 mil ns, FF2 = 0.820 mil ns (down from 1.33 and 1.20 mil ns in the bfp8 + LoFi case). (@davorchap any ideas how to squeeze this more?) If we compare these numbers to the theoretical max (663k ns for both matmuls), we see that they are pretty close: roughly 65% of the max for FF1 and 81% for FF2.
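The progression at 1K seq_len can be summarized as a fraction of the theoretical max (663,552 ns from the formula at the top, i.e. ~663k). A short sketch using only the measurements reported in this issue; the names are illustrative.

```python
THEORETICAL_MAX_NS = 663_552  # 1024 seq_len, 16 cycles/tile, 1 cycle ~= 1 ns

# (FF1, FF2) runtimes in ns at 1024 seq_len, as reported in this issue.
measured_1k = {
    "main: (1,1) subblocks, HiFi2, bfp8": (2.64e6, 2.38e6),
    "proper subblocks, HiFi2, bfp8": (1.62e6, 1.50e6),
    "proper subblocks, LoFi, bfp8": (1.33e6, 1.20e6),
    "proper subblocks, LoFi, bfp4": (1.02e6, 0.820e6),
}


def fraction_of_max(measured_ns, max_ns=THEORETICAL_MAX_NS):
    """Theoretical max divided by measured runtime (1.0 would be ideal)."""
    return max_ns / measured_ns


for config, (ff1_ns, ff2_ns) in measured_1k.items():
    print(f"{config:36s} FF1 {fraction_of_max(ff1_ns):6.1%}  "
          f"FF2 {fraction_of_max(ff2_ns):6.1%}")
```

With BFP4 + LoFi this comes out at about 65% for FF1 and 81% for FF2, up from about 25% and 28% on main.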

Next, I wanted to see whether that perf is compute bound or DRAM/NoC bound, so I went through the kernels and disabled the DRAM reads and multicasts. Here's what I found:

Disabling DRAM reads

Activations and outputs are sharded on both matmuls, so I only needed to disable the weight DRAM reads and use dummy data instead. With that, perf is exactly the same as when reading weights from DRAM, so the matmuls aren't DRAM bound :)

Disabling mcasts in addition to DRAM reads

There are two multicasts in these matmuls (the shard mcast between neighbouring cores, and the weight mcast down the column). When both are disabled and dummy data is used instead, perf gets slightly better: FF1 = 944k ns, FF2 = 792k ns (compared to 1.02 mil and 820k ns).
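One way to read these two experiments is as the share of runtime that each data-movement path adds on top of compute. A sketch of that arithmetic (the helper name is hypothetical). It only measures the exposed, non-overlapped cost, since a disabled path may partly overlap with compute.

```python
def exposed_fraction(baseline_ns, disabled_ns):
    """Share of the baseline runtime removed by replacing a path with dummy data."""
    return (baseline_ns - disabled_ns) / baseline_ns


# BFP4 + LoFi at 1K seq_len, numbers from this issue (ns).
baseline = {"FF1": 1.02e6, "FF2": 820e3}
no_dram = {"FF1": 1.02e6, "FF2": 820e3}           # weight DRAM reads off: "exactly the same"
no_dram_no_mcast = {"FF1": 944e3, "FF2": 792e3}   # DRAM reads and both mcasts off

for mm in ("FF1", "FF2"):
    print(f"{mm}: DRAM reads {exposed_fraction(baseline[mm], no_dram[mm]):.1%}, "
          f"mcasts {exposed_fraction(no_dram[mm], no_dram_no_mcast[mm]):.1%}")
```

That gives 0% for DRAM reads and roughly 7.5% (FF1) and 3.4% (FF2) for the mcasts, which suggests most of the remaining time sits on the compute side.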

FF1 & FF2 difference in perf

The theoretical value for both of these is the same, but in practice their perf differs because the two matmuls have different inner dim and N. In tiles:

- FF1: [seq_len // 32, 144] x [144, 576]
- FF2: [seq_len // 32, 576] x [576, 144]

For FF1 (at 1K seq_len), this leads to per_core_M of 4 and per_core_N of 72, which doesn't look balanced, and the max in0_block_w we can choose before running out of L1 memory is 3. For FF2, per_core_M is 4 and per_core_N is 18, and we can choose 8 as in0_block_w. In addition, FF1 also runs a fused GELU. If we remove it, in the best scenario FF1 drops from 944k to 870k ns, bringing it closer to the 792k ns of FF2. I remember that on Buda we had an option to run GELU on the packer thread on WH_B0 instead of on the math thread, which might get rid of the GELU difference; I don't think that's enabled in Metal at the moment. @ttmtrajkovic @rtawfik01 did that provide any gain, and is it supported on BH as well?
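The per-core numbers above follow from splitting the matmul over the 64 cores. A sketch, assuming an 8x8 grid with M split across rows and N across columns, and assuming each core consumes one in0_block_w x per_core_N block of weight (in1) tiles per K step. The grid shape, buffer model, and helper names are assumptions, not read from the op code.

```python
GRID_ROWS, GRID_COLS = 8, 8  # assumption: 64 WH cores as an 8x8 grid


def per_core_dims(m_tiles, n_tiles, rows=GRID_ROWS, cols=GRID_COLS):
    """(per_core_M, per_core_N) for a 2D block-sharded matmul."""
    if m_tiles % rows or n_tiles % cols:
        raise ValueError("M and N must divide evenly across the grid")
    return m_tiles // rows, n_tiles // cols


def k_loop(k_tiles, in0_block_w, per_core_n):
    """(number of K blocks, weight tiles per K block on one core)."""
    if k_tiles % in0_block_w:
        raise ValueError("in0_block_w must divide K")
    return k_tiles // in0_block_w, in0_block_w * per_core_n


m_tiles = 1024 // 32
# name: (K tiles, N tiles, in0_block_w currently used)
matmuls = {"FF1": (144, 576, 3), "FF2": (576, 144, 8)}
for name, (k_tiles, n_tiles, in0_block_w) in matmuls.items():
    pcm, pcn = per_core_dims(m_tiles, n_tiles)
    blocks, in1_tiles = k_loop(k_tiles, in0_block_w, pcn)
    print(f"{name}: per_core_M={pcm}, per_core_N={pcn}, "
          f"{blocks} K blocks, {in1_tiles} weight tiles per block")
```

Even at in0_block_w = 3, FF1 needs 216 weight tiles per K block versus 144 for FF2 at in0_block_w = 8, because its per_core_N is 4x larger; under this model that is consistent with FF1 hitting the L1 limit first.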

pavlepopovic commented 4 days ago

might be interesting @pavlejosipovic @s-jovic @davorchap

pavlejosipovic commented 4 days ago

Ff1 should be a bit slower due to gelu in case it's compute bound?

pavlepopovic commented 4 days ago

> Ff1 should be a bit slower due to gelu in case it's compute bound?

Yes, it's mentioned above.

pavlejosipovic commented 4 days ago

Well, it seems close to 80% of the theoretical number, which is the best we have seen so far?

pavlepopovic commented 4 days ago

> Well, it seems close to 80% of the theoretical number, which is the best we have seen so far?

Yes, provided we are happy with the bfp4 PCC, and decode is happy with it as well, since both need to use the same weights.

pavlepopovic commented 2 days ago

I've done a couple more experiments:

I've checked decode PCCs in all tests with bfp4 MLP weights. Here are the results, given as PCC (out, k, v):

- 1 device, 128 kv length: bfp8 = [0.89, 0.92, 0.92], bfp4 = [0.89, 0.92, 0.91] (slight decrease)
- 1 device, 128 kv length, l1_sharded: bfp8 = [0.89, 0.92, 0.92], bfp4 = [0.87, 0.92, 0.91] (slight decrease)
- 1 device, 1024 kv length: bfp8 = [0.91, 0.94, 0.96], bfp4 = [0.90, 0.92, 0.94] (slight decrease)
- 1 device, 1024 kv length, l1_sharded: bfp8 = [0.94, 0.95, 0.97], bfp4 = [0.92, 0.95, 0.96] (slight decrease)
- 1 device, 2047 kv length: bfp8 = [0.96, 0.99, 0.98], bfp4 = [0.94, 0.99, 0.98] (slight decrease)
- 1 device, 2047 kv length, l1_sharded: bfp8 = [0.92, 0.93, 0.95], bfp4 = [0.93, 0.95, 0.96] (slight increase, weird?)
- 4 devices, 128 kv length: bfp8 = [0.86, 0.89, 0.90], bfp4 = [0.85, 0.88, 0.88] (slight decrease)
- 4 devices, 1024 kv length: bfp8 = [0.91, 0.93, 0.93], bfp4 = [0.89, 0.93, 0.94] (out PCC decreases a little, v gets better, weird?)
- 4 devices, 2047 kv length: bfp8 = [0.93, 0.89, 0.88], bfp4 = [0.91, 0.90, 0.88] (out PCC decreases a little, k gets better, weird?)

@skhorasganiTT does this PCC seem good enough to proceed with bfp4 on decode?
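To make the comparison easier to scan, the sketch below tabulates the bfp4 minus bfp8 PCC difference per test, using only the decode numbers listed above (the data layout and short config labels are just for illustration).

```python
# config -> (bfp8 [out, k, v], bfp4 [out, k, v]); decode PCCs from this issue.
decode_pcc = {
    "1 dev, 128": ([0.89, 0.92, 0.92], [0.89, 0.92, 0.91]),
    "1 dev, 128, l1_sharded": ([0.89, 0.92, 0.92], [0.87, 0.92, 0.91]),
    "1 dev, 1024": ([0.91, 0.94, 0.96], [0.90, 0.92, 0.94]),
    "1 dev, 1024, l1_sharded": ([0.94, 0.95, 0.97], [0.92, 0.95, 0.96]),
    "1 dev, 2047": ([0.96, 0.99, 0.98], [0.94, 0.99, 0.98]),
    "1 dev, 2047, l1_sharded": ([0.92, 0.93, 0.95], [0.93, 0.95, 0.96]),
    "4 dev, 128": ([0.86, 0.89, 0.90], [0.85, 0.88, 0.88]),
    "4 dev, 1024": ([0.91, 0.93, 0.93], [0.89, 0.93, 0.94]),
    "4 dev, 2047": ([0.93, 0.89, 0.88], [0.91, 0.90, 0.88]),
}

# Round to 2 decimals to strip floating-point noise from the subtraction.
deltas = {
    cfg: [round(b4 - b8, 2) for b8, b4 in zip(bfp8, bfp4)]
    for cfg, (bfp8, bfp4) in decode_pcc.items()
}
worst_drop = -min(d for row in deltas.values() for d in row)
lowest_bfp4 = min(p for _, bfp4 in decode_pcc.values() for p in bfp4)

for cfg, row in deltas.items():
    print(f"{cfg:24s} delta (out, k, v) = {row}")
print(f"worst drop: {worst_drop:.2f}, lowest bfp4 PCC: {lowest_bfp4:.2f}")
```

Every drop is at most 0.02, and the lowest bfp4 decode PCC is 0.85 (4 devices, 128 kv length, out).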

Decode MLP perf gets a lot better just from using bfp4, since decode is DRAM bound: FF1 goes down from 457k to 247k, and FF2 from 425k to 298k.
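If decode MLP were purely DRAM bound, the speedup should approach the ratio of bfp8 to bfp4 weight bytes. A sketch comparing the two, assuming block-float tile sizes of 1088 B (bfp8) and 576 B (bfp4) per 32x32 tile; the ratios are unit-independent.

```python
BFP8_TILE_BYTES = 1088  # assumption: 1024 mantissa bytes + 64 exponent bytes
BFP4_TILE_BYTES = 576   # assumption: 512 mantissa bytes + 64 exponent bytes

ideal = BFP8_TILE_BYTES / BFP4_TILE_BYTES  # upper bound for a fully DRAM-bound op

# Decode MLP runtimes reported in this issue: (bfp8, bfp4), same units.
decode_runtime = {"FF1": (457e3, 247e3), "FF2": (425e3, 298e3)}

for name, (bfp8_t, bfp4_t) in decode_runtime.items():
    speedup = bfp8_t / bfp4_t
    print(f"{name}: {speedup:.2f}x of an ideal {ideal:.2f}x ({speedup / ideal:.0%})")
```

FF1 gets almost the full byte-ratio speedup (1.85x of 1.89x), while FF2 gets roughly three quarters of it (1.43x), so something other than weight reads may also limit FF2 in decode.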

I've also realised that on prefill, using bfp4 frees up L1 space, so we can increase in0_block_w on both FF1 (from 3 to 6) and FF2 (from 8 to 12). With that, our best bfp4 perf is: FF1 = 889k ns, FF2 = 793k ns (down from 1.02 mil and 820k ns with the smaller in0_block_w).
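A rough way to see why BFP4 lets in0_block_w grow: the bytes of one weight (in1) block per core stay about the same. The sketch assumes each core buffers in0_block_w x per_core_N weight tiles per K step (any double buffering scales both cases equally) and 1088 / 576 B tile sizes for bfp8 / bfp4. It is an illustrative model with hypothetical helper names, not the op's real L1 allocation.

```python
TILE_BYTES = {"bfp8": 1088, "bfp4": 576}  # assumption: block-float 32x32 tile sizes


def in1_block_bytes(in0_block_w, per_core_n, dtype):
    """Bytes of one weight block (in0_block_w x per_core_N tiles) on one core."""
    return in0_block_w * per_core_n * TILE_BYTES[dtype]


def k_blocks(k_tiles, in0_block_w):
    """Number of K-loop iterations; in0_block_w must divide K."""
    if k_tiles % in0_block_w:
        raise ValueError("in0_block_w must divide K")
    return k_tiles // in0_block_w


# name: (K tiles, per_core_N at 1K, (old in0_block_w, dtype), (new in0_block_w, dtype))
configs = {
    "FF1": (144, 72, (3, "bfp8"), (6, "bfp4")),
    "FF2": (576, 18, (8, "bfp8"), (12, "bfp4")),
}
for name, (k_tiles, per_core_n, old, new) in configs.items():
    for w, dtype in (old, new):
        print(f"{name} {dtype} in0_block_w={w:2d}: "
              f"{in1_block_bytes(w, per_core_n, dtype):,} B per block, "
              f"{k_blocks(k_tiles, w)} K blocks")
```

Under this model the per-block footprint stays in the same ballpark (about 235 KB vs 249 KB for FF1, 157 KB vs 124 KB for FF2) while the number of K-loop iterations drops (48 to 24 for FF1, 72 to 48 for FF2), which fits the perf gain reported above.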