ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

2-Pass Sdpa Inference Kernel #1597

Closed: angeloskath closed this 3 days ago

angeloskath commented 4 days ago

This PR aims to improve long-context generation performance by increasing parallelization when there are many keys/values. The benefits are mild on smaller machines and very significant on Ultra machines.
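The PR does not spell out the algorithm. A common way to parallelize attention over many keys is a split-KV reduction (the idea popularized as Flash-Decoding). The pure-Python sketch below assumes the kernel works along those lines. Pass 1 handles each block of keys independently, keeping a local max, a local sum of exponentials and an unnormalized output. Pass 2 rescales the partial results to a shared max and combines them. All function names are hypothetical and are not part of the MLX API.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def one_pass_attention(q, keys, values, scale):
    """Reference: softmax(scale * q.k) weighted sum of values, single query."""
    scores = [dot(q, k) * scale for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]

def pass1_block(q, keys, values, scale):
    """Pass 1 for one block: local max, local sum of exps, unnormalized output."""
    scores = [dot(q, k) * scale for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    dim = len(values[0])
    acc = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return m, sum(weights), acc

def two_pass_attention(q, keys, values, scale, block_size):
    # Pass 1: blocks are independent, so on a GPU each could be its own thread group.
    partials = [
        pass1_block(q, keys[i:i + block_size], values[i:i + block_size], scale)
        for i in range(0, len(keys), block_size)
    ]
    # Pass 2: rescale every block to the global max, then normalize once.
    g = max(m for m, _, _ in partials)
    total = sum(s * math.exp(m - g) for m, s, _ in partials)
    dim = len(values[0])
    out = [0.0] * dim
    for m, _, acc in partials:
        factor = math.exp(m - g)
        for d in range(dim):
            out[d] += acc[d] * factor
    return [x / total for x in out]
```

Pass 1 is where the extra parallelism comes from, since the number of independent work items grows with the number of keys. Pass 2 is a cheap reduction over the number of blocks, and the result matches the one-pass computation up to floating-point rounding.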

On smaller machines, the main benefit comes from accessing the keys and values in a more cache-friendly way when the model uses grouped-query attention (GQA). On Ultra machines, it comes from launching more thread groups, which lets the kernel use more of the chip.
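Both effects can be illustrated by counting. The PR does not describe the kernel's launch grid, so the layouts below are hypothetical. The sketch assumes Llama-style GQA grouping, where consecutive query heads share one KV head. It then compares a layout with one thread group per query head against one with a thread group per (KV head, key block). In the second layout, one K/V block read serves every query head in its group, and the thread-group count grows with sequence length.

```python
import math

def kv_head_for(q_head, n_q_heads, n_kv_heads):
    """GQA (assumed Llama-style): consecutive query heads share one KV head."""
    assert n_q_heads % n_kv_heads == 0
    return q_head // (n_q_heads // n_kv_heads)

def threadgroups_one_pass(batch, n_q_heads):
    # Hypothetical layout: one thread group per (batch, query head) walks all keys.
    return batch * n_q_heads

def threadgroups_two_pass(batch, n_kv_heads, seq_len, block_size):
    # Hypothetical layout: one thread group per (batch, KV head, key block),
    # serving every query head that maps to that KV head.
    return batch * n_kv_heads * math.ceil(seq_len / block_size)

# Example: 32 query heads sharing 8 KV heads, 4096 keys, blocks of 256 keys.
print(threadgroups_one_pass(1, 32))               # a fixed count, independent of seq_len
print(threadgroups_two_pass(1, 8, 4096, 256))     # grows with the number of key blocks
```

Under these assumptions, a single-token decode step with few heads launches only a handful of thread groups in the one-pass layout, which is not enough to occupy a large GPU. Splitting the keys into blocks multiplies that count.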

Speedup for M2 Max

The following speedup is measured in total tokens per second, not attention alone. Note that the Phi model, which does not improve, does not use GQA. The 1-pass SDPA kernel on the M2 Max already achieves roughly 350 to 380 GB/s of read bandwidth at a sequence length of ~2048, so there isn't much room for a speedup.
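To see why 350 to 380 GB/s leaves little headroom, it helps to work out how much data one SDPA call reads during decoding and compare it with the memory system's peak. The sketch below assumes K and V are stored in fp16 with shape (batch, n_kv_heads, seq_len, head_dim). The head counts and sizes in the example are hypothetical. The 400 GB/s figure is Apple's published memory bandwidth for the M2 Max.

```python
def kv_bytes_read(batch, n_kv_heads, seq_len, head_dim, bytes_per_elem=2):
    """Bytes of K and V one SDPA call reads for a single decoding step.

    K and V each have shape (batch, n_kv_heads, seq_len, head_dim); fp16 is 2 bytes.
    """
    return 2 * batch * n_kv_heads * seq_len * head_dim * bytes_per_elem

def read_bandwidth_gb_s(bytes_read, seconds):
    """Effective read bandwidth in GB/s (1 GB = 1e9 bytes)."""
    return bytes_read / seconds / 1e9

M2_MAX_PEAK_GB_S = 400.0  # Apple's published memory bandwidth for the M2 Max

# Hypothetical shapes: 8 KV heads, head_dim 128, 2048 keys, fp16.
nbytes = kv_bytes_read(batch=1, n_kv_heads=8, seq_len=2048, head_dim=128)
for achieved in (350.0, 380.0):
    seconds = nbytes / (achieved * 1e9)  # time one call would take at that bandwidth
    print(f"{achieved:.0f} GB/s: {seconds * 1e6:.1f} us per call, "
          f"{achieved / M2_MAX_PEAK_GB_S:.1%} of peak")
```

Attention at this stage is bandwidth-bound, so 350 to 380 GB/s is roughly 87.5% to 95% of the published peak. A faster kernel cannot read the KV cache much more quickly than that.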

[Figure: M2 Max speedup in total tokens per second (m2-max-speedup)]

Speedup for M2 Ultra

Again, the speedup is measured in total tokens per second, not attention alone. On the M2 Ultra every case speeds up, with or without GQA. Without GQA, the 2048 sequence length peaks at over 800 GB/s, so there is probably little room for further improvement there (longer sequences may still benefit).

[Figure: M2 Ultra speedup in total tokens per second (m2-ultra-speedup)]