Closed: arpan-dhatt closed this issue 2 months ago
Thanks, that's good to have. For QMM we have much bigger low-hanging fruit that we should explore first (the dequantization needs to be updated like in qmv/qvm). Accumulating in float16 is risky; as you mentioned, it may be better to multiply in float16 but accumulate in float32. We'll see.
For now, I would not PR that, but if we do decide to go that way in the future I'll let you know and you can PR it then.
I will close this as accumulating in float16 is not something we want to do for numerical reasons. Feel free to reopen if I have missed something.
Question regarding the QMM, QMV, and QVM kernels. When profiling them with the Metal Debugger I noticed that, no matter the data type used for the operation itself (e.g. float16), all the simdgroup_matrix operations happen in fp32; fp16 utilization is effectively 0.
I went ahead and made the very trivial changes to quantized.metal to use fp16 for both multiplication and accumulation (set the type template parameter of BlockMMA to half in qmm_t, plus miscellaneous changes to static_cast expressions using float constants that wouldn't convert implicitly). For the QMV and QVM kernels I got basically no speedup (they are memory-bound for sure), but for QMM I'm seeing roughly a 5-6% speedup on M2 and M1 Max.
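To illustrate the idea, here's a simplified Metal sketch (not the actual BlockMMA code, and the tile pointers/strides are placeholders): the element type of the simdgroup_matrix fragments is what decides whether the simdgroup MMA runs in fp32 or fp16.

```metal
// Simplified illustration only, not MLX's BlockMMA internals.
// As/Bs are threadgroup tiles, lda/ldb their row strides (placeholders here).

// Current behavior: operands end up in fp32 fragments, so both the multiply
// and the accumulation happen in fp32.
simdgroup_float8x8 A, B;
simdgroup_float8x8 C = make_filled_simdgroup_matrix<float, 8, 8>(0.0f);
simdgroup_load(A, As, lda);
simdgroup_load(B, Bs, ldb);
simdgroup_multiply_accumulate(C, A, B, C);

// The change I tried: half fragments everywhere, i.e. fp16 multiply and
// fp16 accumulation.
simdgroup_half8x8 Ah, Bh;
simdgroup_half8x8 Ch = make_filled_simdgroup_matrix<half, 8, 8>(0.0h);
simdgroup_load(Ah, Ahs, lda);
simdgroup_load(Bh, Bhs, ldb);
simdgroup_multiply_accumulate(Ch, Ah, Bh, Ch);
```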
I added a benchmark to benchmarks/cpp/single_ops.cpp to check, and those are the results I'm getting. I also looked at the error between the QMM/QMV/QVM outputs and their un-quantized counterparts, and it was negligible. I then compiled it and ran a language model (mlx-community/Starling-LM-7B-beta); the outputs were also great and prompt processing speed increased.
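The error check was roughly along these lines (a sketch, not the exact benchmark code; it assumes the C++ ops quantize / dequantize / quantized_matmul with their default group_size = 64 and bits = 4):

```cpp
// Rough sketch of the accuracy check, assuming the mlx C++ API.
#include <iostream>
#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  int M = 512, K = 4096, N = 4096;  // set M = 1 to exercise the qmv/qvm paths

  auto x = random::normal({M, K}, float16);
  auto w = random::normal({N, K}, float16);

  // Quantize the weights, then run the quantized matmul (x @ w.T).
  auto [wq, scales, biases] = quantize(w);
  auto out_q = quantized_matmul(x, wq, scales, biases, /* transpose */ true);

  // Reference: matmul against the dequantized weights, so this measures only
  // the kernel's numerical error, not the quantization error itself.
  auto w_hat = dequantize(wq, scales, biases);
  auto out_ref = matmul(x, transpose(w_hat));

  auto err = max(abs(astype(out_q, float32) - astype(out_ref, float32)));
  std::cout << "max abs error: " << err.item<float>() << std::endl;
  return 0;
}
```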
I haven't run any rigorous perplexity measurements to make sure there isn't a problem, though. I can make a PR for this, but I'm a bit hesitant since it was a very simple change, and accumulating FP16 into FP16 can of course overflow, so I assumed that was why it wasn't done. After trying it, it seems fine though? I haven't looked at activation distributions across many models (or LLMs specifically) to know whether this is a good idea in general. Perhaps it would be best to multiply in FP16/BF16 and accumulate into FP32, assuming there's enough register space to do it.
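To make the overflow concern concrete, a tiny CPU-side illustration (the kernel doesn't literally do this, it's just the same accumulation pattern; assumes a compiler with _Float16 support, e.g. clang on Apple silicon):

```cpp
// An fp16 accumulator has two failure modes: it loses precision once the
// running sum is large relative to the addends, and it overflows to inf past
// ~65504. This loop shows the overflow case; an fp32 accumulator fed the same
// fp16 products stays exact here.
#include <cmath>
#include <cstdio>

int main() {
  _Float16 acc_h = (_Float16)0.0f;
  float acc_f = 0.0f;
  for (int i = 0; i < 4096; ++i) {
    _Float16 prod = (_Float16)32.0f;  // stand-in for one fp16 a*b product
    acc_h += prod;                    // fp16 accumulate: overflows to inf
    acc_f += (float)prod;             // fp32 accumulate: exact (131072)
  }
  std::printf("fp16 accumulator: %f (inf? %d)\n", (float)acc_h, std::isinf((float)acc_h));
  std::printf("fp32 accumulator: %f\n", acc_f);
  return 0;
}
```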
HEAD on this fork: https://github.com/arpan-dhatt/mlx