turboderp / exllamav2

A fast inference library for running LLMs locally on modern consumer-class GPUs
MIT License

Can you help me split EXL2 weights for multi gpu? #257

Closed. bob-just-bob closed this issue 3 weeks ago

bob-just-bob commented 6 months ago

Hi, great work here. Thanks a lot.

I would like to be able to split EXL2 weights so I can distribute EXL2-quantized linear layers across GPUs.

goal:


                +----------+
                |          | q_weights,
                |          | q_groups,
                |          | ...
                |          |
                +----------+
                  q_tensors
                  |      |
                  |      |  "shard", split q_weights in the second dimension?
                  |      |  - rearrange perm ... ?
                  v      v
             +------+  +-------+
             |      |  |       | q_weights (1/2),
             |      |  |       | q_groups (whole?)
             +------+  +-------+
               LL1       LL2
              device1   device2

          ...gather output(outputs...)

Happy to hear your pointers on how to implement it!

turboderp commented 6 months ago

I haven't gotten around to tensor parallelism yet, though it is somewhere on the roadmap. Maybe it's for V3, idk. But you're right, that would be the way to do it. I guess I can go over the various components briefly, as they're used by the matmul kernel.

So, the operation is a @ b -> c where b is the quantized matrix and all operands are row-major. The arguments, in order:

a: FP16 tensor in row-major format, size (m, k). The data needs to be contiguous, although a stride parameter could be added easily there.

b_q_weight: Quantized weights, as packed integers. Conceptually the size is (k, n), but it's stored as a uint32 matrix of size (k', n), where k' depends on the bitrate, which varies across groups. It's row-major overall, but every element packs a slice of one column. The packing is different for every bitrate:

- 2-bit group: b_q_weight[  0, 0] <- shuffle <- b'[0:16, 0]
- 3-bit group: b_q_weight[0:3, 0] <- shuffle <- b'[0:32, 0]
- 4-bit group: b_q_weight[  0, 0] <- shuffle <- b'[0: 8, 0]
- 5-bit group: b_q_weight[0:5, 0] <- shuffle <- b'[0:32, 0]
- 6-bit group: b_q_weight[0:3, 0] <- shuffle <- b'[0:16, 0]
- 8-bit group: b_q_weight[  0, 0] <- shuffle <- b'[0: 4, 0]

b' here means the permutation b[perm, :]. This is to allow act-order and group size to function at the same time without every row having a new group index. So what the kernel actually computes is a[:, perm] @ b[perm, :] = a @ b
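
For illustration, here's a quick numerical check of that identity with dense numpy stand-ins for a and b (perm playing the role of b_q_perm):

```python
# Sanity check: permuting the k dimension of both operands leaves the product unchanged.
import numpy as np

m, k, n = 4, 64, 32
rng = np.random.default_rng(0)
a = rng.standard_normal((m, k))
b = rng.standard_normal((k, n))
perm = rng.permutation(k)          # stands in for b_q_perm

assert np.allclose(a[:, perm] @ b[perm, :], a @ b)
```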

b[perm, :] is, by the way, the state that the weights tensor is in right after the rows have been quantized in descending order of activation. In GPTQ, the inverse permutation would be applied before packing the weights, but EXL2 skips that step.

Constructing a[:, perm] in SMEM is cheap as it turns out, but of course it presents a bit of an obstacle to tensor parallelism. You can split b' along the k dimension (at a group boundary) but you'd still have to index into all of a to get the corresponding slice of a[:, perm]. Since a is going to be relatively small, though, scattering the whole tensor is probably not going to be a bottleneck.
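
As a sketch of what such a k-split looks like on top of the permuted layout (dense numpy stand-ins again, with an assumed group boundary at k = 32):

```python
# Hypothetical 2-way k-split: each "device" holds a contiguous slice of
# b' = b[perm, :] and gathers the matching columns of a via perm; the partial
# products are then summed across devices.
import numpy as np

m, k, n = 4, 64, 32
rng = np.random.default_rng(0)
a = rng.standard_normal((m, k))
b = rng.standard_normal((k, n))
perm = rng.permutation(k)
b_perm = b[perm, :]                            # stored (permuted) layout

split = 32                                     # assumed to fall on a group boundary
c1 = a[:, perm[:split]] @ b_perm[:split, :]    # device 1
c2 = a[:, perm[split:]] @ b_perm[split:, :]    # device 2

assert np.allclose(c1 + c2, a @ b)
```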

The shuffle operation only reorders bits within each vertical slice for more efficient unpacking/dequantizing.
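
To make the packing concrete, here's a toy version of the 4-bit case (eight rows of one column slice into a single 32-bit word). The bit shuffle is deliberately left out, so this is not the exact stored layout, just the packing idea:

```python
# Toy packing of a 4-bit column slice: eight values from b'[r:r+8, col] into
# one 32-bit word. The real tensors additionally shuffle bits within the
# slice, which is omitted here.
def pack_4bit_slice(vals):
    assert len(vals) == 8 and all(0 <= v < 16 for v in vals)
    packed = 0
    for i, v in enumerate(vals):
        packed |= v << (4 * i)        # value i occupies bits [4*i, 4*i + 4)
    return packed                     # fits in a uint32

def unpack_4bit_slice(packed):
    return [(packed >> (4 * i)) & 0xF for i in range(8)]

vals = [3, 15, 0, 7, 9, 1, 12, 5]
assert unpack_4bit_slice(pack_4bit_slice(vals)) == vals
```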

b_q_scale: The quantized 4-bit scales for each group, conceptually size (g, n) where g is the number of groups, packed as uint32 elements in size (g, n/8).

b_q_scale_max: Size (n), dtype FP16. The maximum scale per column, premultiplied by 1/256.

Quantization is symmetric (there's a whole discussion to be had about that, I guess), which means there's no offset or q_zero as it'd be referred to for GPTQ. The formula for dequantizing is: (q_scale^2 / 256 * q_scale_max) * (q_weight - 2^(bitrate - 1))
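
Restated as a reference sketch (just the formula above written out in Python, not the kernel's code; here q_scale_max is the per-column max before the 1/256 premultiplication):

```python
# Literal reading of the dequantization formula above.
def dequantize(q_weight, q_scale, q_scale_max, bitrate):
    # q_weight:    quantized value in [0, 2^bitrate)
    # q_scale:     4-bit group scale in [0, 16)
    # q_scale_max: per-column maximum scale (before the 1/256 premultiply)
    # bitrate:     bits per weight for this group (2, 3, 4, 5, 6 or 8)
    scale = q_scale * q_scale / 256 * q_scale_max
    return scale * (q_weight - 2 ** (bitrate - 1))

# e.g. a 4-bit weight:
# dequantize(q_weight=11, q_scale=9, q_scale_max=0.02, bitrate=4)
```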

c: FP16 output, row-major, size (m, n)

size_m, size_n, size_k: Shape of the operation

groups: Total number of groups

b_q_group_map: uint16 list of size (2 * k) used by each thread to find its group index from its k index (necessary since you can have multiple group sizes in one quantized tensor). It alternates between group index and the count of rows (along the k dimension) remaining in each group. E.g. for the first group with group size 32 it would look like:

0, 32, 0, 31, 0, 30 ... 0, 2, 0, 1, 1, 32, 1, 31, 1, 30 ... 1, 2, 1, 1, 2, 32, 2, 31 ...
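
One possible way to build such a map from per-group row counts (hypothetical helper, not the library's code):

```python
# Build a b_q_group_map-style list: for each k index, (group index, rows
# remaining in the group), flattened into one list of length 2 * k.
def build_group_map(group_row_counts):
    group_map = []
    for g, rows in enumerate(group_row_counts):
        for remaining in range(rows, 0, -1):
            group_map += [g, remaining]
    return group_map

assert build_group_map([32, 32])[:6] == [0, 32, 0, 31, 0, 30]
```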

b_q_perm: The k permutation that's already pre-applied to b and is applied on-the-fly to a. uint16, size (k)

rows_8: The number of rows (along k) that use bitrate 8.

rows_6: The number of rows (along k) that use bitrate >= 6.

rows_5: The number of rows (along k) that use bitrate >= 5.

rows_4, rows_3, rows_2: etc.

clear: Boolean flag determining whether the output buffer is zeroed at the start of the operation. Since the output is accumulated directly into global memory with atomicAdd, you can add in residual connections for free this way.
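
In plain numpy terms the "free residual" amounts to something like this (illustrative only; the kernel does the accumulation with atomicAdd in global memory):

```python
# Skipping the clear step turns the output buffer into a fused residual add.
import numpy as np

m, k, n = 4, 64, 64
rng = np.random.default_rng(0)
a = rng.standard_normal((m, k))
b = rng.standard_normal((k, n))
residual = rng.standard_normal((m, n))

c = residual.copy()    # clear=False: start from the residual instead of zeros
c += a @ b             # the kernel accumulates into c
assert np.allclose(c, a @ b + residual)
```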

r_weights, r_weights_stride: FP16 tensor of size (m, 1) stored as (m, stride). Used in MoE layers for multiplying each output row by a separate routing weight. Also allows the kernel to exit immediately if all the routing weights are zero.


I think that's it for the quantized matmul kernel.

When m is larger than MAX_Q_GEMM_ROWS defined in exllamav2/exllamav2_ext/config.h, the function switches to reconstructing an FP16 matrix in GMEM and deferring the matmul to cuBLAS. This is much more efficient since the custom kernel is very GEMV-focused and doesn't scale well past m > ~32.

This reconstruction also applies the same permutation, so that what's actually written is b = b'[invperm, :]. I suppose that would be another slight obstacle to splitting the tensors, since you'd end up with a row-sparse matrix of size (k, n) if you split the weights along k. So it would make sense not to un-permute it and instead permute a in global memory before invoking the cuBLAS GEMM. That should still be reasonably fast.

To start with you could set MAX_Q_GEMM_ROWS to 1000000 to focus on the custom kernel first, or set it to 0 to disable the custom kernel and force the reconstruction instead.

I guess it would be much simpler to split the n dimension, since you would have to scatter all of a anyway. In that case you'd replicate the remaining q tensors (q_perm, q_group_map) across devices, but q_weight, q_scale and q_scale_max could be sliced at any column index that's a multiple of 32 (though that should probably be 64 for better ROCm performance).
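
A sketch of that n-split with dense stand-ins (split point chosen as an assumed multiple of 32):

```python
# Hypothetical 2-way column split: each device keeps all of a, q_perm and
# q_group_map, plus a column slice of the n-indexed tensors; the per-device
# outputs are concatenated along n.
import numpy as np

m, k, n = 4, 64, 128
rng = np.random.default_rng(0)
a = rng.standard_normal((m, k))
b = rng.standard_normal((k, n))

split = 64                    # any multiple of 32 (64 for better ROCm performance)
c1 = a @ b[:, :split]         # device 1: columns [0, split)
c2 = a @ b[:, split:]         # device 2: columns [split, n)

assert np.allclose(np.concatenate([c1, c2], axis=1), a @ b)
```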

Anyway, I hope that helps. Also nice to have this written down for reference, since it is getting a little complicated and I kind of lose track myself sometimes.

bob-just-bob commented 6 months ago

Thanks a lot for the write-up! Sounds tractable. I'll potentially give it a try and report back in the coming weeks.