ml-explore / mlx

MLX: An array framework for Apple silicon
https://ml-explore.github.io/mlx/
MIT License

Improvements in the quantizer and dequantization kernel #1061

Closed: angeloskath closed this 2 weeks ago

angeloskath commented 2 weeks ago

This PR makes two changes that work together to, hopefully, improve quantization quality across the board.

  1. We change the way we compute the scale and bias of each block:
     a. We set the bias to the block's min or max value, whichever has the larger absolute value.
     b. We set the scale so that the quantized range runs from min to max or from max to min, respectively (i.e. the scale is negative when the bias is the max).
     c. We adjust the scale so that 0 is quantized exactly as 0.
  2. For dequantization: since the scale is usually a float16, dividing it by 4096 destroys significant information and introduces quantization error. We change the kernel so that the scale is divided by 16 instead. The same issue does not arise in qmv (the quantized matrix-vector kernel), where everything is computed in float32.
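
Both changes can be illustrated with a small, hypothetical pure-Python sketch; it is one reading of the description above, not MLX's actual quantizer or Metal kernel. For item 1 it assumes dequantization computes `scale * q + bias` with unsigned codes `q` in `[0, 15]`; the function names and the `eps` guard for constant blocks are made up for illustration. For item 2 it assumes, as the description suggests, that the kernel keeps a pre-divided scale (for example `scale / 4096`) in float16 so it can multiply packed 4-bit values without shifting them; Python's `struct` format `"e"` is used to round to half precision. Any float16 scale below 2**-2 = 0.25 drops under float16's smallest normal value (2**-14) once divided by 4096 and can lose mantissa bits to gradual underflow, while dividing by 16 only does so for scales below 2**-10.

```python
import struct

# Hypothetical pure-Python sketch of the two changes; not MLX's actual code.
# Assumes dequantization computes w ~= scale * q + bias with q in [0, 2**bits - 1].


def quantize_block(w, bits=4, eps=1e-7):
    """Item 1: pick bias and scale for one block, then compute the integer codes."""
    n_bins = 2**bits - 1
    w_min, w_max = min(w), max(w)
    # (a) The bias is whichever extreme has the larger absolute value.
    use_min = abs(w_min) > abs(w_max)
    edge = w_min if use_min else w_max
    # (b) The scale steps from min to max (positive) or max to min (negative).
    #     eps is an assumed guard against constant blocks.
    scale = max((w_max - w_min) / n_bins, eps)
    if not use_min:
        scale = -scale
    # (c) Rescale so that edge / scale is exactly the integer q0; 0 then sits
    #     on code -q0 (exact in real arithmetic, up to float rounding here).
    q0 = round(edge / scale)
    if q0 != 0:
        scale = edge / q0
        bias = edge
    else:
        bias = 0.0  # edge is already ~0, so anchor the grid at 0 instead
    q = [min(max(round((x - bias) / scale), 0), n_bins) for x in w]
    return q, scale, bias


def dequantize_block(q, scale, bias):
    return [scale * qi + bias for qi in q]


def to_f16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def scale_roundtrip_error(scale, divisor):
    """Item 2: relative error after storing scale / divisor in float16."""
    stored = to_f16(scale / divisor)  # pre-divided scale kept in float16
    return abs(stored * divisor - scale) / scale


w = [-0.9, -0.2, 0.0, 0.3, 0.55]
q, scale, bias = quantize_block(w)
print(q, dequantize_block(q, scale, bias))

s = to_f16(0.0123)  # hypothetical float16 scale of a typical magnitude
print(scale_roundtrip_error(s, 16), scale_roundtrip_error(s, 4096))
```

In this example the bias is the block minimum -0.9, which is reconstructed exactly, and 0.0 comes back as 0.0. For the hypothetical scale 0.0123, dividing by 16 round-trips with zero error, while dividing by 4096 leaves a relative error of roughly 0.7%.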

Quantization performance

This is the quantization performance on the Wikitext-2 test set. The Q4_0 numbers are obtained by quantizing and dequantizing the weights in place using absmax quantization with a block size of 32.

[Figure: quant (quantization performance on Wikitext-2)]
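
For contrast with the affine scheme in item 1, the absmax baseline can be sketched as symmetric "fake quantization": each block of 32 gets a single scale equal to its largest magnitude divided by 7, there is no bias, and every value is rounded to a signed 4-bit code and immediately multiplied back. This is a hypothetical, generic sketch; GGML's actual Q4_0 format may differ in details such as how the scale is chosen, and `absmax_fake_quantize` is a made-up name.

```python
# Hypothetical sketch of symmetric absmax "fake quantization" (quantize, then
# dequantize in place) with 4-bit codes in [-8, 7] and block size 32.
# GGML's actual Q4_0 format may differ in details, e.g. how the scale is chosen.


def absmax_fake_quantize(weights, block_size=32, bits=4):
    qmax = 2 ** (bits - 1) - 1  # 7 for 4-bit signed codes
    out = []
    for start in range(0, len(weights), block_size):
        block = weights[start:start + block_size]
        amax = max(abs(x) for x in block)
        if amax == 0.0:
            out.extend([0.0] * len(block))  # all-zero block: nothing to scale
            continue
        scale = amax / qmax  # one scale per block and no bias
        for x in block:
            q = max(-qmax - 1, min(qmax, round(x / scale)))
            out.append(q * scale)
    return out


w = [0.0, 0.5, -1.0, 0.26] + [0.01 * i for i in range(28)] + [2.0, -0.1]
print(absmax_fake_quantize(w)[:4])
```

A symmetric grid always maps 0 to 0, but it spends half of its codes on the side of zero that a skewed block may barely use; item 1 instead anchors the grid at the larger-magnitude extreme and only then pins 0 onto it.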

Regarding the block-size discussion (which I cannot find now, @ivanfioravanti), I think 64 is a good compromise for the default, and 32 should be evaluated and used when the quality at 64 is not adequate. What do you think, @awni and @jagrit06?

[Figure: quant-blocks (block-size comparison)]
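
Part of the trade-off behind the 64 default is storage. Assuming each block (the `group_size` argument of `mlx.core.quantize`) stores one float16 scale and one float16 bias next to its 4-bit codes, the overhead per weight is 32 / group_size bits. This matches the shapes in the benchmark command below: a 4096x4096 weight packed into 4096x512 uint32 words plus 4096x64 float16 scales and biases, i.e. 64 blocks of 64 weights per row. The helper `bits_per_weight` is hypothetical.

```python
def bits_per_weight(group_size, bits=4, scale_bits=16, bias_bits=16):
    """Average storage per weight: the packed code plus this block's share of
    the (assumed float16) scale and bias."""
    return bits + (scale_bits + bias_bits) / group_size


for g in (32, 64, 128):
    print(f"group_size={g}: {bits_per_weight(g)} bits/weight")
# group_size=32: 5.0 bits/weight
# group_size=64: 4.5 bits/weight
# group_size=128: 4.25 bits/weight
```

Halving the block size from 64 to 32 therefore costs an extra 0.5 bits per weight (about 11% more memory for the quantized matrices), which is why 32 is worth using only where 64 is measurably worse.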

Throughput

The kernel change causes no measurable throughput degradation.

Before

$ python benchmarks/python/comparative/bench_mlx.py quant_matmul_t_64_4 --size 4096x4096 --size 4096x512 --size 4096x64 --size 4096x64 --dtype float16 --dtype uint32 --dtype float16 --dtype float16
6.557293891906738
$ python -m mlx_lm.lora --model mlx-community/NeuralBeagle14-7B-4bit-mlx --train --data ../../lora/data/
...
...
Iter 1: Val loss 2.866, Val took 8.981s
Iter 10: Train loss 2.323, Learning Rate 1.000e-05, It/sec 1.882, Tokens/sec 752.791, Trained Tokens 3999, Peak mem 6.265 GB
Iter 20: Train loss 1.691, Learning Rate 1.000e-05, It/sec 1.732, Tokens/sec 698.554, Trained Tokens 8032, Peak mem 6.265 GB

After

$ python benchmarks/python/comparative/bench_mlx.py quant_matmul_t_64_4 --size 4096x4096 --size 4096x512 --size 4096x64 --size 4096x64 --dtype float16 --dtype uint32 --dtype float16 --dtype float16
6.5276172161102295
$ python -m mlx_lm.lora --model mlx-community/NeuralBeagle14-7B-4bit-mlx --train --data ../../lora/data/
...
...
Iter 1: Val loss 2.834, Val took 8.946s
Iter 10: Train loss 2.334, Learning Rate 1.000e-05, It/sec 1.880, Tokens/sec 751.839, Trained Tokens 3999, Peak mem 6.265 GB
Iter 20: Train loss 1.699, Learning Rate 1.000e-05, It/sec 1.741, Tokens/sec 702.182, Trained Tokens 8032, Peak mem 6.265 GB
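
The before/after figures above can be compared as relative changes. The snippet only copies the numbers printed in the logs; it makes no assumption about the unit of the `bench_mlx.py` output, and `pct_change` is a hypothetical helper.

```python
# Figures copied from the "Before" and "After" logs above.
before = {"bench_mlx.py": 6.557293891906738,
          "Tokens/sec @ iter 10": 752.791,
          "Tokens/sec @ iter 20": 698.554}
after = {"bench_mlx.py": 6.5276172161102295,
         "Tokens/sec @ iter 10": 751.839,
         "Tokens/sec @ iter 20": 702.182}


def pct_change(old, new):
    """Signed relative change from old to new, in percent."""
    return 100.0 * (new - old) / old


for name in before:
    print(f"{name}: {pct_change(before[name], after[name]):+.2f}%")
# bench_mlx.py: -0.45%
# Tokens/sec @ iter 10: -0.13%
# Tokens/sec @ iter 20: +0.52%
```

Every metric moves by well under 1% in either direction, which is consistent with the change having no throughput cost.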