turboderp / exllama

A more memory-efficient rewrite of the HF transformers implementation of Llama for use with quantized weights.
MIT License

Question about sampling and kernel fusion #234

Closed: sleepwalker2017 closed this issue 11 months ago

sleepwalker2017 commented 11 months ago

I noticed that the sampling stage for a batched input uses a for-loop to process each item. Is that always the case? For large batches, the loop launches many small CUDA kernels, which is inefficient.

BTW, as far as I know, the code uses separate GEMMs to compute proj_q, proj_k and proj_v, so it seems there is no kernel fusion in the implementation. Is that right?

I want to confirm that I haven't missed some configuration option and measured the performance incorrectly. Thank you!

```python
def batched_sample(self, logits, temperature, top_k, top_p, min_p, typical, num = 1):

    # A batch of one goes straight to the single-sequence sampler
    if logits.shape[0] == 1: return self.sample(logits, temperature, top_k, top_p, min_p, typical, num)

    # Otherwise sample each batch row separately, one self.sample() call per row
    # (note that `num` is not forwarded in this path)
    samples = []
    scores = []
    for i in range(logits.shape[0]):
        t, s = self.sample(logits[i, :, :], temperature, top_k, top_p, min_p, typical)
        samples.append(t)
        scores.append(s)

    return torch.cat(samples, dim = 0), torch.cat(scores, dim = 0)
```

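
For comparison, here is a minimal sketch of sampling every row in one vectorized pass instead of one `sample` call per row. It assumes NumPy and covers only temperature plus top-k (no top-p, min-p or typical filtering); the name `batched_top_k_sample` is hypothetical and not part of ExLlama.

```python
import numpy as np


def batched_top_k_sample(logits: np.ndarray, temperature: float, top_k: int,
                         rng: np.random.Generator) -> tuple[np.ndarray, np.ndarray]:
    """Sample one token per row of a (batch, vocab) logits array, with no per-row loop."""
    batch, vocab = logits.shape
    k = min(top_k, vocab)
    scaled = logits / temperature
    # Indices of the k largest logits in every row (unordered within the top k)
    top_idx = np.argpartition(-scaled, k - 1, axis=1)[:, :k]
    top_logits = np.take_along_axis(scaled, top_idx, axis=1)
    # Numerically stable softmax over the k candidates, for all rows at once
    top_logits -= top_logits.max(axis=1, keepdims=True)
    probs = np.exp(top_logits)
    probs /= probs.sum(axis=1, keepdims=True)
    # Inverse-CDF sampling: one uniform draw per row
    cdf = np.cumsum(probs, axis=1)
    u = rng.random((batch, 1))
    choice = np.minimum((cdf < u).sum(axis=1), k - 1)  # clamp guards against round-off
    rows = np.arange(batch)
    return top_idx[rows, choice], probs[rows, choice]


# Usage with random stand-in logits for a batch of 64 and a 32000-token vocabulary
rng = np.random.default_rng(0)
logits = rng.standard_normal((64, 32000)).astype(np.float32)
tokens, token_probs = batched_top_k_sample(logits, temperature=0.8, top_k=40, rng=rng)
```

The per-row work becomes a handful of whole-array operations, so the cost no longer grows with the number of Python-level calls; whether this actually beats the loop for a given batch size is something to measure rather than assume.
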
turboderp commented 11 months ago

> noticed that the sampling stage for a batched input uses for-loop to calculate each item

It does, yes. This is because the batched sampling was added in a bit of a rush. But sampling is done on the CPU anyway.

Keep in mind that even for a "large" batch of, say, 16, you've only got 16x32000 floats to work on, and you're not doing any type of work that really calls for the lock-step parallelism that CUDA is good at. You do a lot of two-step normalization, sorting and indexing, most of which works way better on the CPU.
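
To make the CPU-side work concrete, here is a minimal single-row sketch of those steps: temperature scaling, sorting, a first softmax over the top-k candidates, nucleus (top-p) truncation and a second normalization before the draw. It is plain Python for illustration only; `sample_one` is a hypothetical name, and the order of the filters is an assumption, not a description of ExLlama's `sample`.

```python
import math
import random


def sample_one(logits: list[float], temperature: float, top_k: int, top_p: float,
               rng: random.Random) -> int:
    """Hypothetical single-row sampler: temperature, top-k, then top-p."""
    # Sort token ids by logit, highest first, and keep the top_k candidates
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:top_k]
    # First normalization: softmax over the candidates, shifted by the max for stability
    m = logits[order[0]] / temperature
    weights = [math.exp(logits[i] / temperature - m) for i in order]
    total = sum(weights)
    probs = [w / total for w in weights]
    # Keep the smallest prefix whose cumulative probability reaches top_p
    kept, cum = 0, 0.0
    for p in probs:
        kept += 1
        cum += p
        if cum >= top_p:
            break
    # Second normalization over the surviving candidates, then draw one token id
    s = sum(probs[:kept])
    return rng.choices(order[:kept], weights=[p / s for p in probs[:kept]], k=1)[0]
```
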

There is one C++ extension function for applying the repetition penalty, which I guess highlights just how trivial (and quick) a process like that is compared to the gymnastics required to do the same thing on (GPU) tensors. There are more of those operations that I was going to implement in C++ but just haven't gotten around to.
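
For a sense of how little work this is, here is a sketch of the widely used CTRL / Hugging Face-style rule (divide positive logits by the penalty, multiply negative ones by it) for every token already in the context. This rule is an assumption for illustration; ExLlama's C++ extension may compute its penalty differently, and `apply_repetition_penalty` is a hypothetical name.

```python
def apply_repetition_penalty(logits: list[float], previous_tokens: list[int],
                             penalty: float) -> list[float]:
    """Return a copy of logits with every previously seen token made less likely."""
    out = list(logits)
    for tok in set(previous_tokens):  # each token is penalized once, however often it appeared
        # With penalty > 1, dividing a positive logit and multiplying a negative one both lower it
        out[tok] = out[tok] / penalty if out[tok] > 0 else out[tok] * penalty
    return out
```

It is a single pass over the context tokens with a scalar update each, which is the kind of indexing-heavy work that is simple on the CPU and awkward to express as GPU tensor operations.
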

> BTW, as far as I know, the code uses separate GEMMs to calculate proj_q, proj_k, proj_v, seems there is no kernel fusion in the implementation, is that the fact?

There is a lot of kernel fusion in the implementation, just not that one in particular. Some implementations combine the Q, K and V projections because they're all multiplied by the hidden state at the same point. But the main benefit isn't having fewer kernel launches, it's having less Python code in between those kernel launches. If you're launching kernels back to back in C++, the difference on the CPU side is on the order of a few microseconds, while on the GPU a 4096x4096 matrix (say) is already large enough that whatever you do with it will take almost precisely 1/3 of the time it would take to do the same thing to a 4096x12288 matrix, especially since most of the time you'll be multiplying by a tiny hidden state that easily fits in the cache.
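
A back-of-the-envelope check of the 1/3 figure, assuming a purely bandwidth-bound GEMV, fp16 weights (2 bytes each; a 4-bit GPTQ matrix would be roughly a quarter of that) and 1 TB/s of DRAM bandwidth. `stream_time_us` is a hypothetical helper.

```python
def stream_time_us(rows: int, cols: int, bytes_per_weight: float,
                   bandwidth_gb_s: float) -> float:
    """Microseconds needed to read a rows x cols weight matrix once from DRAM."""
    total_bytes = rows * cols * bytes_per_weight
    return total_bytes / (bandwidth_gb_s * 1e9) * 1e6


# Assumed: fp16 weights (2 bytes each) and 1 TB/s (1000 GB/s) of bandwidth
single = stream_time_us(4096, 4096, 2, 1000)  # one projection, e.g. q_proj
fused = stream_time_us(4096, 12288, 2, 1000)  # q, k and v concatenated
print(f"single: {single:.1f} us, fused: {fused:.1f} us, ratio: {fused / single:.2f}")
# single: 33.6 us, fused: 100.7 us, ratio: 3.00
```

Under these assumptions the GPU time scales with the bytes read, so fusing the three projections mostly saves launch overhead, which is small next to roughly 34 us of reading per projection.
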

But more importantly, this is GPTQ, and those matrices will be quantized with different act-order permutations. ExLlama relies on reordering the hidden state going into each linear layer, and that reordering would be different for q_proj, k_proj and v_proj.
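
A toy illustration of the act-order point, assuming each layer's act-order permutation can be modeled as a reordering of its input dimensions: permuting the hidden state and the weight rows the same way leaves `x @ W` unchanged, but q and k each carry their own permutation. Concatenating W_q, W_k and W_v column-wise into one fused matrix would force a single shared row order. The random permutations and helper names are hypothetical stand-ins, not ExLlama's storage format.

```python
import random


def matvec(x: list[float], w: list[list[float]]) -> list[float]:
    """y = x @ W for a length-n vector x and an n x m row-major matrix W."""
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]


def permute(seq: list, perm: list[int]) -> list:
    """Reorder seq so that element i of the result is seq[perm[i]]."""
    return [seq[p] for p in perm]


n, m = 6, 3
rng = random.Random(0)
x = [rng.uniform(-1, 1) for _ in range(n)]
w_q = [[rng.uniform(-1, 1) for _ in range(m)] for _ in range(n)]
w_k = [[rng.uniform(-1, 1) for _ in range(m)] for _ in range(n)]

# Hypothetical act-order permutations: each layer's rows are stored in their own order
perm_q = rng.sample(range(n), n)
perm_k = rng.sample(range(n), n)

# Each layer reorders the hidden state to match its own stored row order
y_q = matvec(permute(x, perm_q), permute(w_q, perm_q))
y_k = matvec(permute(x, perm_k), permute(w_k, perm_k))
```

Since the stored row orders of W_q and W_k differ, a fused `[W_q | W_k | W_v]` matrix could not be fed by one reordered copy of the hidden state.
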

sleepwalker2017 commented 11 months ago

It all seems OK. I'm not familiar with GPTQ, so I'll read up on it to understand your last point. Thank you for the detailed answer!

sleepwalker2017 commented 11 months ago

I think you are right: fusing the proj_q, proj_k and proj_v kernels doesn't reduce the cost of reading and writing data from memory.

BTW, I'd like to ask about kernel optimization, if you have any advice: how can we judge whether a kernel is optimized well enough? For memory-bound kernels like GEMV, how do we find the upper bound? Comparing measured memory throughput against the theoretical bandwidth doesn't seem right, because the GPU has multiple levels of memory and the theoretical bandwidth only describes DRAM, so the two metrics don't seem comparable. Do you have any experience or knowledge here?

Thank you!!

sleepwalker2017 commented 11 months ago

Hi, for single-batch inference the performance is good: about a 2x speedup compared with FasterTransformer.

When I run example_batch.py with a batch size of 64, the sampling stage loops 64 times over a 32000-entry logit vector, on the order of 64 × 32000 × log(top_k) operations, which is really not a small cost.

Here is the Nsight Systems profiling result. The sampling stage takes 51 ms, which is more than 50% of the cost of the 40-layer decoder. Since it runs for every token, I think the sampling stage has a large impact on overall performance.

Since the decoder is made up of GEMV kernels, it is memory-bound, so users will always want to batch requests to increase the compute intensity.

[Image: Nsight Systems profiling result]
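
To put a number on "memory-bound", here is a sketch of the arithmetic intensity (FLOPs per byte of DRAM traffic) of `Y = X @ W` as the batch grows, assuming fp16 weights and activations, weights read exactly once, and no cache effects. The helper name is hypothetical.

```python
def arithmetic_intensity(batch: int, n_in: int, n_out: int,
                         bytes_per_weight: float = 2.0,
                         bytes_per_act: float = 2.0) -> float:
    """FLOPs per byte of DRAM traffic for Y = X @ W, X: batch x n_in, W: n_in x n_out."""
    flops = 2 * batch * n_in * n_out                    # one multiply + one add per weight per row
    weight_bytes = n_in * n_out * bytes_per_weight      # weights read once
    act_bytes = batch * (n_in + n_out) * bytes_per_act  # X read, Y written
    return flops / (weight_bytes + act_bytes)


for b in (1, 16, 64):
    print(b, round(arithmetic_intensity(b, 4096, 4096), 2))
```

For a 4096x4096 layer this is about 1 FLOP per byte at batch 1 and grows roughly linearly with the batch, which is why batching turns bandwidth-bound GEMVs into more compute-friendly GEMMs.
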

turboderp commented 11 months ago

ExLlama isn't really intended for datacenters, though. The target use case is a single client, like a local chatbot or some such.

For e.g. a 33B model running on a 24 GB GPU, there are only a few GB left over after the weights, so running multiple batches really isn't an option unless you severely limit the available sequence length. And while it'd be nice to batch forward passes and turn all those GEMV operations into GEMM, it doesn't really help a single user who can't begin generating a new token until the previous token is chosen.
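
Rough arithmetic behind the "only a few GB left" point, assuming LLaMA-33B's shape (60 layers, hidden size 6656), an fp16 key/value cache and full multi-head attention. `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(n_layers: int, hidden: int, seq_len: int, batch: int,
                   bytes_per_elem: int = 2) -> int:
    """Bytes for keys and values across all layers, positions and sequences."""
    return 2 * n_layers * hidden * seq_len * batch * bytes_per_elem  # 2 = one K and one V


GIB = 1024 ** 3
# Assumed LLaMA-33B shape: 60 layers, hidden size 6656, fp16 cache
for batch in (1, 4):
    size = kv_cache_bytes(60, 6656, 2048, batch)
    print(f"batch {batch}: {size / GIB:.2f} GiB for 2048-token sequences")
# batch 1: 3.05 GiB for 2048-token sequences
# batch 4: 12.19 GiB for 2048-token sequences
```

Under these assumptions a single full-length sequence already needs about 3 GiB of cache, so a batch of four would not fit in what is left after the weights.
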

In this situation, if you can manage, say, 100 tokens per second, then the sampling code is run 100 times per second on a single 1x32000 logit vector, so that's the sensible case to optimize for.

The sampling code isn't especially complex, though, and you can easily do as oobabooga did in TGW, keeping the output logits in VRAM and passing them on to HuggingFace samplers. They do it for consistency between AutoGPTQ and ExLlama, of course, not for efficiency, since TGW is also a single-user focused application.

As for kernel optimization, I don't know when a kernel is optimized well enough. You can work out a lower bound for minimum latency given just the size of the model. E.g. on a 4090 with 1 TB/s of theoretical memory bandwidth and a 4 GB model, you won't be able to surpass 250 tokens/second, roughly speaking. You have to account for the embedding layer, which only supplies one row per token (and ExLlama keeps it in system RAM anyway), but on the other hand you have a bunch of other VRAM access going on as well. So with attention, normalization, etc., I think ExLlama's quantized matmul kernel is pretty close to optimal for the single-token case.
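
The same lower bound written out as arithmetic: if every weight has to be read once per token, bandwidth divided by bytes read per token caps single-sequence speed, and a measured speed can be expressed as a fraction of that cap. The 4 GB and 1 TB/s figures come from the example above; the 200 tokens/s input is made up for illustration, not a measurement, and both helper names are hypothetical.

```python
def max_tokens_per_second(bytes_read_per_token: float, bandwidth_bytes_per_s: float) -> float:
    """Upper bound on single-sequence decode speed if all weights are read once per token."""
    return bandwidth_bytes_per_s / bytes_read_per_token


def bandwidth_efficiency(measured_tokens_per_s: float, bytes_read_per_token: float,
                         bandwidth_bytes_per_s: float) -> float:
    """Fraction of theoretical DRAM bandwidth sustained at a measured decode speed."""
    return measured_tokens_per_s * bytes_read_per_token / bandwidth_bytes_per_s


print(max_tokens_per_second(4e9, 1e12))      # 250.0
print(bandwidth_efficiency(200, 4e9, 1e12))  # hypothetical 200 tokens/s measurement
```
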

For multiple tokens it becomes more complicated, of course. Once the hidden state doesn't fit so easily in the L1 cache, you have to start worrying about strategies for making good use of shared memory. This is one thing I'm working on for V2, mostly for the sake of speculative sampling where you're processing maybe 4 tokens at a time.

sleepwalker2017 commented 11 months ago

OK, thank you for the answers. I have no further questions now.