ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

Block sparse mm #1058

Closed jagrit06 closed 2 weeks ago

jagrit06 commented 2 weeks ago

Proposed changes

Adds an operation and a primitive that gather matrices on the fly before the matmul.
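For readers unfamiliar with the pattern, here is a minimal NumPy sketch of what "gather, then matmul" computes. It is a reference for the semantics only, not the PR's implementation: the function name `gather_then_matmul` and its argument names are hypothetical, and the sketch materializes the gathered copies explicitly, which is the extra work that gathering "on the fly" would presumably avoid.

```python
import numpy as np

def gather_then_matmul(a, b, lhs_indices, rhs_indices):
    """Hypothetical reference: select one matrix per index from each stack, then multiply."""
    a_sel = a[lhs_indices]   # (N, M, K): explicit copy of the selected left matrices
    b_sel = b[rhs_indices]   # (N, K, P): explicit copy of the selected right matrices
    return a_sel @ b_sel     # batched matmul, result shape (N, M, P)

rng = np.random.default_rng(0)
tokens = rng.standard_normal((4, 1, 8))    # 4 inputs, each a 1x8 row vector
experts = rng.standard_normal((3, 8, 5))   # 3 weight matrices, each 8x5

token_ids = np.arange(4)                   # use every input once
expert_ids = np.array([2, 0, 2, 1])        # assumed routing: which matrix each input uses

out = gather_then_matmul(tokens, experts, token_ids, expert_ids)
print(out.shape)
```

Each output row `out[i]` equals `tokens[token_ids[i]] @ experts[expert_ids[i]]`, which is the per-token expert selection an MoE layer needs.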


awni commented 2 weeks ago

Very cool!! Can't wait to try this in an MOE!

awni commented 2 weeks ago

Some MOE benchmarks:

Generation

python -m mlx_lm.generate --model Qwen/Qwen1.5-MoE-A2.7B-Chat --prompt "Write a story about Einstein" --max-tokens 256 --temp 0.0
Pre: 31.285 tokens-per-sec
Post: 72.387 tokens-per-sec

LoRA

python -m mlx_lm.lora --train --iters 50 --model Qwen/Qwen1.5-MoE-A2.7B-Chat --data ../lora/data

Pre: Iter 30: Train loss 1.475, Learning Rate 1.000e-05, It/sec 1.692, Tokens/sec 291.262, Trained Tokens 5325, Peak mem 29.248 GB
Post: Iter 30: Train loss 1.466, Learning Rate 1.000e-05, It/sec 2.724, Tokens/sec 468.749, Trained Tokens 5325, Peak mem 28.051 GB
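To put the numbers above side by side, the following short script computes the pre/post ratios. It uses only the figures quoted in this comment; the variable and function names are illustrative.

```python
def speedup(pre, post):
    """Return post/pre, i.e. how many times faster the post-change run is."""
    return post / pre

# Figures copied from the benchmark output above.
generation = speedup(31.285, 72.387)      # tokens-per-sec, generation
lora_iters = speedup(1.692, 2.724)        # It/sec, LoRA training at iter 30
lora_tokens = speedup(291.262, 468.749)   # Tokens/sec, LoRA training at iter 30
peak_mem_saved_gb = 29.248 - 28.051       # change in peak memory, GB

print(f"Generation speedup:   {generation:.2f}x")
print(f"LoRA It/sec speedup:  {lora_iters:.2f}x")
print(f"LoRA Tok/sec speedup: {lora_tokens:.2f}x")
print(f"Peak memory saved:    {peak_mem_saved_gb:.3f} GB")
```

That works out to roughly 2.3x faster generation and 1.6x faster LoRA training, with about 1.2 GB lower peak memory.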