Open sorasoras opened 1 month ago
The optimizations that they list on their GitHub page help with more efficient compute, but the bottleneck for GEMV kernels on the hardware that I'm interested in (GPUs, desktop CPUs) is not compute but memory bandwidth. Something like this is only going to be useful on e.g. phones that have comparatively little compute and where conserving battery life is very important.
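For context, a rough roofline estimate (with assumed, purely illustrative hardware numbers) shows why batch-size-1 GEMV tends to be bandwidth-bound on desktop-class hardware:

```python
# Roofline-style estimate for a single GEMV y = W @ x (batch size 1).
# Hypothetical numbers: a 4096x4096 fp16 weight matrix on a machine with
# ~50 GB/s DRAM bandwidth and ~1 TFLOP/s of fp16 compute (assumptions).
n = 4096
bytes_per_weight = 2                    # fp16
flops = 2 * n * n                       # one multiply + one add per weight
bytes_moved = bytes_per_weight * n * n  # every weight is read exactly once

intensity = flops / bytes_moved         # FLOPs per byte of traffic
print(f"arithmetic intensity: {intensity:.2f} FLOP/byte")

bandwidth = 50e9   # bytes/s (assumed)
compute = 1e12     # FLOP/s  (assumed)
t_mem = bytes_moved / bandwidth
t_cmp = flops / compute
print(f"memory time: {t_mem*1e6:.0f} us, compute time: {t_cmp*1e6:.0f} us")
# Memory time dominates by a wide margin: the kernel is bandwidth-bound,
# so packing weights into fewer bits helps even if compute stays the same.
```

With an arithmetic intensity of only ~1 FLOP/byte, the kernel sits far below the machine balance point of any desktop CPU or GPU, which is the basis for the memory-bandwidth argument above.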
The specific comparisons they draw are also fishy:
Note that they did not compare with types from #8151, so their numbers are inflated. From my tests, a lookup table was slower than unpacking with SIMD.
See also https://github.com/ggerganov/llama.cpp/pull/7931#discussion_r1646471294, and the associated commit replacing the q22_grid lookup table with SIMD-based unpacking. (although the current implementation in #8151 doesn't shuffle anymore)
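For illustration, the two decoding strategies can be sketched in scalar Python (illustrative only; the real llama.cpp kernels do this with SIMD, and `q22_grid` here names only the idea of a per-byte table, not the actual llama.cpp data):

```python
# Two ways to decode 2-bit packed weights, sketched in scalar Python.

# (a) Lookup-table style: one table entry per packed byte yields all
#     four 2-bit values at once, as a q22_grid-style table would.
GRID = [tuple((b >> (2 * i)) & 3 for i in range(4)) for b in range(256)]

def decode_lut(packed):
    out = []
    for byte in packed:
        out.extend(GRID[byte])
    return out

# (b) Shift-and-mask unpacking: the same result computed directly,
#     which maps onto cheap SIMD shifts and ANDs instead of table loads.
def decode_unpack(packed):
    return [(byte >> (2 * i)) & 3 for byte in packed for i in range(4)]

packed = [0b11100100, 0b00011011]
assert decode_lut(packed) == decode_unpack(packed) == [0, 1, 2, 3, 3, 2, 1, 0]
```

Which of the two wins depends on the relative throughput of shuffle/table-lookup versus shift/mask instructions on the target core, which is exactly what the linked discussion is about.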
But maybe there's something which can still be learned from their implementation. I don't have the same hardware on which they ran comparisons, so I can't directly compare the numbers.
Assuming that the gains in compute efficiency are real, their technique would be very useful for large-batch matrix-matrix multiplication or FlashAttention on CPUs, where you are compute bound rather than I/O bound. I would still assume that a CPU+GPU system would be better, but it would still be a very good addition to llama.cpp I think.
Add to that, I would imagine a large 1.58-bit model (14B+) would be a lot more compute bound than I/O bound on CPU, or even GPU.
@sorasoras @JohannesGaessler Thank you for your interest in T-MAC. Our previous repo (as of the time of this post) may not fully showcase the benefits of T-MAC, and we have made some updates over the last month. To address your concerns here:
> The optimizations that they list on their GitHub page help with more efficient compute but the bottleneck for GEMV kernels on the hardware that I'm interested in (GPUs, desktop CPUs) is not compute but memory bandwidth.
> They say that llama.cpp is "dequantization-based" which it definitely is not for batch size 1. Edit: I guess they consider conversion to 8 bit "dequantization".
Yes, we consider the conversion to 8-bit as dequantization, since it converts a low-precision data type to a higher-precision one.
The T-MAC evaluation is conducted for GPTQ asymmetric quantization with w2g128/w4g128 (2-bit/4-bit weights, group size 128). The bits per weight (bpw) are 2.25/4.25 respectively.
We are working on migrating to the latest llama.cpp version. Hopefully, our modifications can be merged into the mainstream.
Thank you for the plots, those look much more convincing.
We agree with you that GEMV is memory bottlenecked. However, the situation changes when GEMV is quantized to low-bit. For instance, after fp16 -> 4-bit/2-bit quantization, the memory footprint is reduced by 4x/8x respectively, but the computation cost remains the same, and there is even additional dequantization overhead.
This is only true if the weights are dequantized to FP16/FP32 and the dot product is calculated using floating point arithmetic. If you convert both the weights and activations to 8 bit you can use SIMD instructions on CPUs that are comparatively faster (make sure to report whether things like AVX2 are available on the tested hardware). Similarly, on CUDA GPUs the `__dp4a` instruction or int8 tensor cores can be used.
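To make this concrete, here is a scalar emulation of what `__dp4a` computes and how an int8 dot product builds on it (illustrative sketch only; on real GPUs the four int8 lanes are packed into one 32-bit register):

```python
# __dp4a semantics: a dot product of four signed 8-bit values
# accumulated into a 32-bit integer, in one instruction.
def dp4a(a4, b4, c):
    return c + sum(x * y for x, y in zip(a4, b4))

# An int8 x int8 dot product over a row then needs len(a)/4 such steps,
# plus a single float multiply at the end to apply the two scales.
def int8_dot(a, b, scale_a, scale_b):
    acc = 0
    for i in range(0, len(a), 4):
        acc = dp4a(a[i:i+4], b[i:i+4], acc)
    return acc * scale_a * scale_b

print(int8_dot([1, -2, 3, 4, 5, -6, 7, 8],
               [1,  1, 1, 1, 2,  2, 2, 2], 0.5, 0.25))  # 4.25
```

The whole inner loop stays in integer arithmetic; floating point only appears once per block to apply the quantization scales, which is why this path is so much cheaper than dequantizing to FP16/FP32 first.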
llama.cpp already converts activations to 8-bit, but we still observe that after quantizing weights to lower bits, the GEMV becomes more compute bottlenecked.
According to our black-box profile, LUT SIMD instructions (tbl/pshuf) also have better throughput (lower CPI) than FMA instructions (dot/fma, ...). Additionally, T-MAC needs fewer TBL instructions at lower bit-widths.
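The table-lookup idea can be sketched as follows. This is a simplified illustration of LUT-based mixed-precision GEMV for 1-bit ({-1, +1}) weights, not T-MAC's actual data layout:

```python
# For every group of 4 activations, precompute the partial sums for all
# 2^4 = 16 possible weight bit patterns once; every weight row then
# reuses these tables with pure lookups and additions, no multiplies.

def build_tables(activations, g=4):
    tables = []
    for i in range(0, len(activations), g):
        group = activations[i:i+g]
        table = []
        for pattern in range(1 << g):
            # bit j chooses whether activation j contributes +a or -a
            table.append(sum(a if (pattern >> j) & 1 else -a
                             for j, a in enumerate(group)))
        tables.append(table)
    return tables

def lut_dot(weight_patterns, tables):
    # one table lookup and one add per group of 4 weights
    return sum(tables[k][p] for k, p in enumerate(weight_patterns))

acts = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
# weights +1,+1,-1,+1, -1,-1,+1,+1 encoded as one 4-bit pattern per group
patterns = [0b1011, 0b1100]
ref = 1 + 2 - 3 + 4 - 5 - 6 + 7 + 8
assert lut_dot(patterns, build_tables(acts)) == ref == 8
```

With a group size of 4 the tables have 16 entries, which is exactly what a single tbl/pshufb instruction can index in one shot; multi-bit weights decompose into bit planes that reuse the same tables, which is consistent with the claim that cost scales with bit-width.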
Results for an AVX2 CPU (Surface Book 3) are provided here.
We agree with you on the statement regarding CUDA GPUs. From our insights, GPUs aren't well-suited for LUT due to their limited on-chip memory per core. Conversely, placing the LUT in shared memory leads to slow random access caused by bank conflicts.
@kaleid-liner I'd be interested in end-to-end performance comparisons between T-MAC and `TQ1_0` and `TQ2_0` from #8151 with the TriLM 3.9B model. These don't use lookup tables, but are significantly faster than the other low-bit types in llama.cpp. (`TQ2_0` is around twice as fast as `Q2_K` on most platforms)
I think T-MAC is still faster than these improved types, but I did not yet figure out how to build T-MAC on NixOS (especially regarding TVM), and I don't have common machines with the ones tested in the profiling data for T-MAC.
For example, on an AWS `t4g.small` instance (2 vCPU Arm Neoverse N1 with 2GB of RAM), when compiling llama.cpp from #8151 with `-mcpu=native` (the default in the Makefile), this is the performance I get with the TriLM 3.9B model at `TQ2_0`, `TQ1_0`, and `Q2_K` (for comparison):
model | size | params | backend | threads | test | t/s |
---|---|---|---|---|---|---|
llama ?B TQ2_0 - 2.06 bpw ternary | 1.08 GiB | 3.99 B | CPU | 1 | pp512 | 5.10 ± 0.02 |
llama ?B TQ2_0 - 2.06 bpw ternary | 1.08 GiB | 3.99 B | CPU | 1 | tg128 | 4.54 ± 0.01 |
llama ?B TQ2_0 - 2.06 bpw ternary | 1.08 GiB | 3.99 B | CPU | 2 | pp512 | 10.34 ± 0.00 |
llama ?B TQ2_0 - 2.06 bpw ternary | 1.08 GiB | 3.99 B | CPU | 2 | tg128 | 8.33 ± 0.02 |
llama ?B TQ1_0 - 1.69 bpw ternary | 946.45 MiB | 3.99 B | CPU | 1 | pp512 | 3.11 ± 0.00 |
llama ?B TQ1_0 - 1.69 bpw ternary | 946.45 MiB | 3.99 B | CPU | 1 | tg128 | 2.87 ± 0.00 |
llama ?B TQ1_0 - 1.69 bpw ternary | 946.45 MiB | 3.99 B | CPU | 2 | pp512 | 6.23 ± 0.00 |
llama ?B TQ1_0 - 1.69 bpw ternary | 946.45 MiB | 3.99 B | CPU | 2 | tg128 | 5.48 ± 0.01 |
llama ?B Q2_K - Medium | 1.43 GiB | 3.99 B | CPU | 1 | pp512 | 2.20 ± 0.00 |
llama ?B Q2_K - Medium | 1.43 GiB | 3.99 B | CPU | 1 | tg128 | 2.06 ± 0.00 |
llama ?B Q2_K - Medium | 1.43 GiB | 3.99 B | CPU | 2 | pp512 | 4.37 ± 0.02 |
llama ?B Q2_K - Medium | 1.43 GiB | 3.99 B | CPU | 2 | tg128 | 3.91 ± 0.02 |
Note that I've used the `-r 2` option with `llama-bench` to only compute 2 repetitions instead of 5, so the error bounds may be wider than usual.
`TQ1_0` and `TQ2_0` use a block size of 256 elements (like `Q2_K`).
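As a sanity check on the sizes above, the 2.06 bpw figure for `TQ2_0` is consistent with a 256-element block of 2-bit packed weights plus one fp16 scale (the exact layout here is my assumption for illustration, not taken from the PR):

```python
# Deriving the TQ2_0 bits-per-weight figure from an assumed block layout:
# 256 elements at 2 bits each, plus one fp16 scale per block.
block_elems = 256
qs_bytes = block_elems * 2 // 8    # 2 bits per weight -> 64 bytes
scale_bytes = 2                    # one fp16 scale per block (assumed)
bpw = (qs_bytes + scale_bytes) * 8 / block_elems
print(f"{bpw:.4f} bpw")            # 2.0625, reported as 2.06
```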
I expect T-MAC to still be faster because of its more optimized memory layout and better tiling compared to `TQ2_0` (only a single row is computed at a time in the `vec_dot` of `TQ2_0`; the weights are not interleaved).
@compilade Thanks! I will compare T-MAC against TQ1_0 and TQ2_0. I also expect T-MAC to still be faster because it does much less computation.
Prerequisites
Feature Description
https://arxiv.org/pdf/2407.00088

T-MAC (Table-based Matrix-Activation Computation) is a method designed to enable efficient deployment of low-bit Large Language Models (LLMs) on edge devices using CPUs. Key aspects:

- **Purpose:** T-MAC addresses the challenge of deploying weight-quantized LLMs on edge devices with limited resources, focusing on efficient mixed-precision matrix multiplication (mpGEMM) without relying on GPUs.
- **Core Technique:** It uses a lookup table (LUT)-based approach to directly support mpGEMM without the need for weight dequantization, transforming traditional data-type-centric multiplication into bit-wise table lookup operations.
- **Performance Improvements:** Up to 4x increase in throughput and 70% reduction in energy consumption compared to llama.cpp. For the BitNet-b1.58-3B model: 30 tokens/s with a single core on M2-Ultra, 71 tokens/s with eight cores on M2-Ultra, and 11 tokens/s on Raspberry Pi 5.
- **Key Features:** Scales linearly with weight bit-width, eliminates multiplications and reduces additions, and supports various activation types (fp8, fp16, int8) using fast table lookup and add instructions.
- **Implementation Techniques:** LUT-centric data layout for efficient on-chip memory usage, table quantization and mirror consolidation to reduce table size, and utilization of tbl/pshuf instructions for fast table lookup on CPUs.
- **Evaluation:** Tested on various edge devices including Apple M2 Ultra, Jetson AGX Orin, Surface Book 3, and Raspberry Pi 5. Achieved up to 6.6x speedup (average 3.6x) compared to llama.cpp, and an end-to-end LLM inference speedup of 2.8x for the Llama-2-7B-2bit model.
- **Significance:** T-MAC provides a practical solution for deploying LLMs on edge devices using widely available CPUs, making LLM inference speed on CPUs comparable or even superior to GPUs on the same devices in some cases.
- **Availability:** T-MAC is open-sourced and available on GitHub for further development and implementation.
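A minimal sketch of the "scales linearly with weight bit-width" claim: a w-bit (unsigned, for simplicity) weight vector can be split into w one-bit planes, each handled with the same lookup/accumulate machinery, so the work grows proportionally to w. Real quantization formats add signs, scales, and zero-points on top of this.

```python
def dot_1bit(bits, acts):
    # a single bit plane: each bit selects whether the activation is added
    return sum(a for b, a in zip(bits, acts) if b)

def dot_wbit(weights, acts, w):
    # unsigned w-bit weights: value = sum over planes of bit * 2^plane,
    # so the full dot product is a 2^plane-weighted sum of 1-bit dots
    total = 0.0
    for plane in range(w):
        bits = [(x >> plane) & 1 for x in weights]
        total += (1 << plane) * dot_1bit(bits, acts)
    return total

weights = [3, 0, 2, 1]                 # 2-bit unsigned weights
acts = [1.0, 2.0, 3.0, 4.0]
ref = sum(wt * a for wt, a in zip(weights, acts))  # plain dot product
assert dot_wbit(weights, acts, 2) == ref == 13.0
```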
Motivation
Looks like a good addition to current Bitnet 1.58bit to speed it up even further
Possible Implementation
https://github.com/microsoft/T-MAC