ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

Split encoders in non-concurrent context with a max ops per encoder #1085

Closed · awni closed this 1 week ago

awni commented 1 week ago

Speeds up generation, and training slightly. Some benchmarks:

Benchmarks on an M2 Ultra

LLM generation

python -m mlx_lm.generate --model mlx-community/NeuralBeagle14-7B-4bit-mlx --prompt "Write a story about Albert Einstein" --temp 0.0 --max-tokens 256

Pre: Prompt: 222.239 tokens-per-sec Generation: 107.239 tokens-per-sec

Post: Prompt: 223.145 tokens-per-sec Generation: 108.463 tokens-per-sec

MNIST

Pre: Test accuracy 0.936, Time 0.632 (s)
Post: Test accuracy 0.927, Time 0.625 (s)

LeNet

Pre: Test accuracy 0.981, Time 2.797 (s)
Post: Test accuracy 0.975, Time 2.757 (s)

ResNet

Pre: Throughput: 6462.77 samples/second, Peak memory 1.663 (GB)
Post: Throughput: 6474.81 samples/second, Peak memory 1.663 (GB)

Transformer training:

python main.py --gpu

Pre: Iter 40: Train loss 7.864, It/sec 5.881, Peak memory 5.534 (GB)
Post: Iter 40: Train loss 7.814, It/sec 5.902, Peak memory 5.534 (GB)

awni commented 1 week ago

The generation speed up is pretty nice. Running with a bigger command buffer (100 ops) gives even more speedup:

Generation: 114.676 tokens-per-sec

The main downside is introducing the additional dispatch* methods so that we can track dispatches on the command encoder. I'm not wedded to it.
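
For readers following along, here is a rough, self-contained sketch of the idea being discussed: route kernel dispatches through wrapper dispatch* methods that count ops, and once a configurable maximum is reached, end the current encoder and start a fresh one on the same command buffer. This is not MLX's actual implementation; the `SplittingEncoder`, `FakeEncoder`, and `FakeCommandBuffer` types below are illustrative stand-ins for the real Metal objects, and the limit value is just an example.

```cpp
// Hedged sketch, not MLX code: split encoders after a max number of dispatches.
#include <cstdio>
#include <memory>

// Stand-in for MTL::ComputeCommandEncoder.
struct FakeEncoder {
  void dispatchThreadgroups() { /* would encode one kernel dispatch */ }
  void endEncoding() { std::printf("encoder ended\n"); }
};

// Stand-in for MTL::CommandBuffer.
struct FakeCommandBuffer {
  std::unique_ptr<FakeEncoder> computeCommandEncoder() {
    std::printf("new encoder started\n");
    return std::make_unique<FakeEncoder>();
  }
};

// Hypothetical wrapper: counts dispatches and splits encoders at a max.
class SplittingEncoder {
 public:
  SplittingEncoder(FakeCommandBuffer& cb, int max_ops)
      : cb_(cb), max_ops_(max_ops), enc_(cb.computeCommandEncoder()) {}

  // A "dispatch*" style method as mentioned above: forwards to the
  // underlying encoder and bumps the op count.
  void dispatch_threadgroups() {
    maybe_split();
    enc_->dispatchThreadgroups();
    ++num_ops_;
  }

  void end() { enc_->endEncoding(); }

 private:
  void maybe_split() {
    if (num_ops_ >= max_ops_) {
      // End the current encoder and open a new one on the same buffer.
      enc_->endEncoding();
      enc_ = cb_.computeCommandEncoder();
      num_ops_ = 0;
    }
  }

  FakeCommandBuffer& cb_;
  int max_ops_;
  int num_ops_ = 0;
  std::unique_ptr<FakeEncoder> enc_;
};

int main() {
  FakeCommandBuffer cb;
  SplittingEncoder enc(cb, /*max_ops=*/10);  // e.g. 100 as tried above
  for (int i = 0; i < 25; ++i) {
    enc.dispatch_threadgroups();
  }
  enc.end();
}
```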