turboderp / exllama

A more memory-efficient rewrite of the HF transformers implementation of Llama for use with quantized weights.
MIT License

batch inference doesn't improve performance compared to sequential #150

Closed nivibilla closed 1 year ago

nivibilla commented 1 year ago

While testing out the batched inference on openllama, I notice that for a single prompt it takes 2.56 seconds but for 8 prompts it takes 24.62 secs. Essentially no improvement in performance.

I might be misunderstanding, but shouldn't it take somewhere closer to 3-5 seconds?

For comparison, the huggingface batch inference in 4-bit does 512 prompts in approx. 4.5 min, which comes to about 4 seconds per batch of 8.

nivibilla commented 1 year ago

Ah I see

https://github.com/turboderp/exllama/blob/e61d4d31d44bd4dd3380fc773509c870ba74cb9f/generator.py#L62

It is just looping over the prompts sequentially. Makes sense that there is not an improvement in speed. I guess I need to figure out how the huggingface guys are doing it and then maybe copy the code over to here.

aljungberg commented 1 year ago

> https://github.com/turboderp/exllama/blob/e61d4d31d44bd4dd3380fc773509c870ba74cb9f/generator.py#L62
>
> It is just looping over the prompts sequentially. Makes sense that there is not an improvement in speed. I guess I need to figure out how the huggingface guys are doing it and then maybe copy the code over to here.

You're looking at the sampler, which indeed works sequentially but is a tiny fraction of total generation time.

In my testing, a batch size of 5 is over 3 times as fast (measured in tokens delivered per second). Keep in mind that if your context size is long, that is already "batched" (in a sense), so if you're already maxing out your compute there, batching won't help.

juliensalinas commented 1 year ago

I just tested the example_batch.py script in the following conditions:

The generation of the 4 examples took around 35 seconds on average.

Then I tested with the first example only ("Once upon a time,"): the generation took around 12 seconds on average.

nikshepsvn commented 1 year ago

The improvement batching gives increases greatly with batch size, but then each batch needs to be smaller to fit into memory. It's a hard position to be in, given that exllama is heavily optimized for consumer GPUs with somewhat limited VRAM, but if you try it out on larger-VRAM cards (like the A6000) with a batch_size over 6 you will see bigger differences.

nivibilla commented 1 year ago

Thanks everyone, I was testing on an A10G. So maybe I didn't have enough VRAM to see the benefits. I will test this out on a bigger GPU and get back

turboderp commented 1 year ago

> Keep in mind that if your context size is long, that is already "batched" (in a sense), so if you're already maxing out your compute there, batching won't help.

This is correct for prompt processing, where processing 1x2000 tokens is more or less the same operation as 10x200 tokens, and either would take about ten times as long as processing 1x200 tokens. Once it gets to generating tokens one at a time, batching should make a big difference regardless of the sequence length, I guess until it gets so long that attention becomes a bottleneck.

Batching can be counterproductive if prompts have very different lengths, since they have to be padded to the same length. And if the completions end up being different lengths you can't stop generating any of them until they've all produced an EOS token.
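As an illustration of the padding point, here is a minimal sketch (not ExLlama's own code) of left-padding a batch of tokenized prompts to a common length, together with a mask that marks which positions hold real tokens:

```python
import torch

def left_pad(batch_token_ids, pad_id):
    # Left-pad variable-length prompts to a common length so they can be
    # stacked into a single [batch, seq_len] tensor. Returns the padded ids
    # and a mask that is True over real tokens and False over padding.
    max_len = max(len(ids) for ids in batch_token_ids)
    ids = torch.full((len(batch_token_ids), max_len), pad_id, dtype=torch.long)
    mask = torch.zeros(len(batch_token_ids), max_len, dtype=torch.bool)
    for i, seq in enumerate(batch_token_ids):
        ids[i, max_len - len(seq):] = torch.tensor(seq, dtype=torch.long)
        mask[i, max_len - len(seq):] = True
    return ids, mask

# Example: three prompts of different lengths, padded with a hypothetical pad id 0.
ids, mask = left_pad([[5, 6, 7], [8, 9], [10, 11, 12, 13]], pad_id=0)
```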

SinanAkkoyun commented 1 year ago

Hey, right off the bat, thank you SO much for the awesome repo, the speedups are incredible!!! I wanted to ask approximately when efficient batching will be on the roadmap? If soon, I'd like to work on an implementation of continuous batching, just like HF's text-generation-inference does.

nivibilla commented 1 year ago

@SinanAkkoyun I'm using vLLM at the moment, which is way faster than HF TGI; they use paged attention. https://github.com/vllm-project/vllm

They have plans to support exllama in the future. I'm not good enough at C++ to implement it myself. But maybe this is of interest to you @turboderp

SinanAkkoyun commented 1 year ago

@nivibilla How awesome is that, thank you for the info! :)

turboderp commented 1 year ago

I've looked at vLLM before and it is pretty fast, it just doesn't have any options for quantization. Paged attention looks complex, but as far as I can tell it doesn't accomplish much beyond what ExLlama already does with a preallocated cache. The main benefit in either case would be avoiding the constant concatenations of the very large cache tensors that HF models rely on, leading to significant overhead and memory fragmentation.
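To illustrate the concatenation overhead being described, here is a toy sketch (illustrative shapes only, not ExLlama's or HF's actual code) contrasting a cache grown by torch.cat with a pre-allocated cache written in place:

```python
import torch

heads, head_dim, max_seq_len = 32, 128, 2048
new_kv = torch.randn(1, heads, 1, head_dim, dtype=torch.float16)

# HF-style: grow the cache by concatenation on every generated token.
# Each torch.cat allocates a fresh tensor and copies the whole cache over.
cache = torch.zeros(1, heads, 0, head_dim, dtype=torch.float16)
for step in range(8):
    cache = torch.cat([cache, new_kv], dim=2)

# Pre-allocated: reserve space for the full sequence length once, write in place.
cache = torch.zeros(1, heads, max_seq_len, head_dim, dtype=torch.float16)
for step in range(8):
    cache[:, :, step:step + 1, :] = new_kv
```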

The more dynamic allocation scheme looks smart, but I'm not convinced it actually achieves anything. If you don't have enough VRAM for the full-length cache at the maximum batch size, all it seems to do is set you up for a crash when you eventually do need to use the full context/batch size. It seems to me that, especially for a server, you'd want all the space allocated up front so you can properly define the range of valid inputs to the forward pass.

If you're serving a bunch of clients that might make very similar requests, I can see how caching common prefixes could be beneficial. If every chat starts with "This is a chat between a user and a helpful AI assistant..." then re-using the keys and values resulting from that string makes sense to get you some very situational speedups. ExLlama's generator has a similar function, though it's only meant for a single sequence and I'm not actually sure what it does with a batch. I've also considered expanding upon it as a way to make memory-efficient beam search faster, but since I haven't had very good results with beam search in general it became a low priority.

nivibilla commented 1 year ago

That makes sense. It's quite useful for me in a batch use case where I use the same prompt prefix (where I explain what the rules are) over and over, with only a 10% difference (the actual sentence) for sentiment analysis. So the paged attention helps me here a lot.

nivibilla commented 1 year ago

And yes, they don't have quantisation, which is the only major drawback.

turboderp commented 1 year ago

> What does
>
> -- Inference, first pass. Time, Inference: 0.14 seconds Speed: 13786.21 tokens/second
>
> refer to? (code:)

This measures the inference speed for a long sequence. By default it's 1920 tokens through a single forward pass. This completes in 0.14 seconds so that works out to those ~14k tokens/second.

It's much faster than the individual token speed because it's a batched operation. All those 1920 tokens are processed in parallel, with the same weights applied at each step so you save a ton of VRAM bandwidth, which more than makes up for the slower process of first de-quantizing each matrix to do a regular (cuBLAS) GEMM, as opposed to using the matrix-vector optimized quantized kernel.

This is only useful for "prompt" processing, though. Once you start generating tokens, the input to the forward pass has a length of one. Although you can of course run batches of multiple sequences with a length of one, to get some of the benefit back.

SinanAkkoyun commented 1 year ago

@turboderp Thank you very much for responding despite me deleting my comment, the detailed answers really help! :) Do you think you will enhance batched performance in the near future? TGI seems to have almost zero penalty for batch sizes > 1. It would be super interesting for a variety of concepts like tree-of-thought implementations and speculative decoding.

SinanAkkoyun commented 1 year ago

Also, may I ask, what is the biggest bottleneck with batching at the moment? I tried to look into it, but I am not proficient enough in ML to get it.

turboderp commented 1 year ago

@SinanAkkoyun It's always hard to say why any given implementation performs better in situation X than in situation Y on hardware setup Z. TGI is quite slow in general compared to, say, vLLM, which in turn is quite a bit slower than ExLlama. One possible explanation why TGI performs the same with batch sizes 1 and 2 could simply be that it isn't properly optimized for a batch size of 1.

Another point to consider is this: Take a 33b model, let's say. A batch size of 1 will have you doing a whole bunch of [1, 6144] @ [6144, 6144] matmuls for your Q, K, V, O projections. The MLP matmuls are similar, just even larger matrices on the right hand side. Increase the batch size to two and those same matmuls take a shape of [2, 6144] @ [6144, 6144] instead. That's twice the amount of computation but only negligibly more VRAM access. So, the more memory-bound you are, the smaller the difference between those two operations. Any FP16 implementation is going to be about four times as heavy as GPTQ on the VRAM access to begin with.
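To make the memory-bound argument concrete, here is a rough micro-benchmark sketch in plain PyTorch (FP16 matmuls, not the GPTQ kernels; 6144 is just the example hidden size from above, and a CUDA device is required). On a memory-bound GPU the per-matmul times for bsz 1 and 2 should come out close to each other:

```python
import time
import torch

k = 6144  # example hidden size for a 33b-class model
w = torch.randn(k, k, device="cuda", dtype=torch.float16)

def bench(bsz, iters=200):
    # Time one [bsz, k] @ [k, k] matmul, like a Q/K/V/O projection in FP16.
    x = torch.randn(bsz, k, device="cuda", dtype=torch.float16)
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        y = x @ w
    torch.cuda.synchronize()
    return (time.time() - t0) / iters

for bsz in (1, 2, 4, 8):
    print(f"bsz {bsz}: {bench(bsz) * 1e6:.1f} us per matmul")
```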

So the fact that TGI is as fast (or perhaps rather, as slow) on batch sizes 1 and 2, doesn't mean ExLlama isn't already optimal in both cases. That's not to say it couldn't potentially be optimized further. But GEMM is very non-trivial to optimize, and you really want to treat it separately from GEMV. Batch sizes of around 2-8 are kind of a difficult terrain in-between the two. For the best results you're probably looking at switching between 3-4 completely different approaches. atomicAdd, reduction, computing multiple rows per thread (for very large k), various hybrid techniques etc. No guarantee that it will perform significantly better in any case, but probably you could squeeze out a little more performance, yes.

There are also changes you could make at a higher level, such as catering to batches of sequences with different lengths more intelligently than by just padding, especially relevant if you have sequences of wildly differing length, as you might see in CFG etc. Also, beam search and other techniques that run inference on batches where most of the sequence is identical across the batch could benefit from some logic that tries to eliminate redundancy.

I'm compiling a number of these ideas for ExLlama V2, though, rather than turning ExLlama into even more of a mess than it already is. I mean, it's still supposed to be an experimental work-in-progress, but apparently there's quite a few users now so I don't want to keep breaking it.

SinanAkkoyun commented 1 year ago

Thank you a million for your reply, that means a lot to me! It makes a lot of sense that TGI and the like are not optimized for a batch size of 1 to begin with, whereas Exllama already performs great. I really like your project and your attention to your userbase. Is there a way to support you a bit?

turboderp commented 1 year ago

I don't really need donations but you're not the first person to ask, so I set up a ko-fi profile just cause why not.

SinanAkkoyun commented 1 year ago

Hey!

> So the fact that TGI is as fast (or perhaps rather, as slow) on batch sizes 1 and 2, doesn't mean ExLlama isn't already optimal in both cases.

I just ran some vLLM and Exllama batching tests and found that while vLLM can handle batch sizes of 64 and more without OOMing on a single 4090, Exllama OOMs above a batch size of 18 (every sequence in the batch is "Once upon a time,", 200 max output tokens).

These are my findings:

runpod RTX 4090, exllama (bsz 18, "Once upon a time,"):
speed: 30.35 tps
effective speed: 546.23 tps
mem: 22,086.84 MB

runpod RTX 4090, vLLM:
(bsz 1): 59.38 TPS (13.921 GB)
(bsz 17): 48.54 TPS (22.529 GB)
(bsz 32): 41.28 TPS
(bsz 64): 32.53 TPS (22.675 GB)

I find both implementations fascinating: yours for being so extremely fast at bsz=1, and vLLM for handling so much throughput at reasonable speeds.

I wanted to ask (for Exllama V2) if there is room for improvement regarding batch size and speed? I don't really know the differences in detail, but maybe one can benefit from the other? That would be awesome! :)

nivibilla commented 1 year ago

@SinanAkkoyun that's interesting, could you share the code for that pls?

SinanAkkoyun commented 1 year ago

@nivibilla sure: https://gist.github.com/SinanAkkoyun/5bb69b0988231eb20896790b2d81e087

nivibilla commented 1 year ago

Thanks @SinanAkkoyun, I actually wanted to see the exllama one, as I'm struggling to get those speeds.

SinanAkkoyun commented 1 year ago

Oh lol, no problem, I just used the example_batch.py! I get such speeds because I rented a runpod instance with a 4090 and a 12900k.

turboderp commented 1 year ago

If you want a better apples-to-apples comparison you should run with a sequence length matching what you intend to generate. ExLlama in this case runs out of VRAM by pre-allocating a cache large enough for the full sequence length at the given batch size, whereas vLLM would only run out as you're actually starting to use the full sequence length. If I set config.max_seq_len to 250 I can easily run with a batch size of 64 and a little beyond as well:

bsz 1 : Speed: 154.426 t/s, effectively: 154.426 t/s
bsz 2 : Speed: 122.197 t/s, effectively: 244.394 t/s
bsz 4 : Speed: 75.634 t/s, effectively: 302.537 t/s
bsz 17 : Speed: 36.283 t/s, effectively: 616.818 t/s
bsz 32 : Speed: 32.193 t/s, effectively: 1030.160 t/s
bsz 64 : Speed: 24.872 t/s, effectively: 1591.807 t/s
bsz 96 : Speed: 20.346 t/s, effectively: 1953.244 t/s
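For reference, the test above corresponds roughly to the following condensed sketch of example_batch.py-style usage (the model path is a placeholder, and exact constructor arguments may differ between revisions of the repo):

```python
import glob
import os

from model import ExLlama, ExLlamaCache, ExLlamaConfig
from tokenizer import ExLlamaTokenizer
from generator import ExLlamaGenerator

model_directory = "/path/to/llama-7b-4bit-128g/"  # placeholder path
tokenizer_path = os.path.join(model_directory, "tokenizer.model")
model_config_path = os.path.join(model_directory, "config.json")
model_path = glob.glob(os.path.join(model_directory, "*.safetensors"))[0]

prompts = ["Once upon a time,"] * 64

config = ExLlamaConfig(model_config_path)
config.model_path = model_path
config.max_seq_len = 250  # pre-allocate the cache for short sequences only

model = ExLlama(config)
tokenizer = ExLlamaTokenizer(tokenizer_path)
cache = ExLlamaCache(model, batch_size=len(prompts))  # cache sized for the whole batch
generator = ExLlamaGenerator(model, tokenizer, cache)

output = generator.generate_simple(prompts, max_new_tokens=200)
```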

ExLlama still loses out on speed for large batches, for two reasons: First, vLLM as I understand it "cheats" a little with paged attention. I.e. if you run the same prompt multiple times in a batch, it may reuse the cache pages for identical prompts. It's questionable if that's a representative test, unless you're planning to run a back-end server which might process a lot of concurrent prompts that start the same way. Which is arguably fair, I guess.

Second reason is more straightforwardly fair. At large enough batch sizes, the most efficient thing for ExLlama to do is convert each matrix to FP16 and then perform a regular cuBLAS matmul, instead of doing the GPTQ matmul which is more efficient in terms of VRAM bandwidth. At this point it's doing pretty much the same thing as vLLM, just with the extra overhead for the conversion, so it's going to be strictly slower. It still needs less VRAM overall, since the converted matrices are temporary and the ones permanently kept in VRAM are still 4-bit GPTQ.

I would say, if you're actually planning to run with batch sizes of 64 for full-length sequences, you're looking at more serious server hardware anyway, and vLLM or the Transformers text generation server with its tensor parallelism and whatnot are much better suited for a task like that.

SinanAkkoyun commented 1 year ago

> If you want a better apples-to-apples comparison you should run with a sequence length matching what you intend to generate.

The thing is that I would need some automated, dynamic way of allocating (just like vLLM): if I fix max_seq_len to 250 but the model wants to continue, I am out of luck. Or maybe I am just too dumb to think of some dynamic implementation; it would mean the world to me if you could help me with that.

> each matrix to FP16 and then perform a regular cuBLAS matmul, instead of doing the GPTQ matmul

Ah okay I see

Yes, for deployment it would nonetheless make more sense to use vLLM, because those machines also have weak CPUs, which diminishes the benefit of exllama anyway. But for local usage, I am still super interested in dynamic batch sizes (for tree of thought and speculative decoding).

turboderp commented 1 year ago

Dynamic batch sizes have the problem that you're not properly defining any limits on your inputs. I don't see much sense in leaving it up to a probabilistic model to decide if and when you're going to run out of memory.

Maybe you could add extra logic to reshape the cache when it begins to grow too large, dynamically split the operation into smaller batches and swap some of those ongoing batches to system RAM to be finished sequentially. But it gets complex. ExLlama V2 might have a system like that, but in the meantime there isn't much of a penalty for just allocating a new cache for any given task.

SinanAkkoyun commented 1 year ago

Based on https://github.com/turboderp/exllama/issues/182#issuecomment-1646595220 (I also did a PR so you don't have to do it), all my needs are fulfilled. I can simply allocate the largest batch size that doesn't OOM and run lower numbers of batches easily (and when bsz > max_bsz, I would just process them in series, as in the sketch below). If I got it correctly, the generation speed of a single batch is not altered by max_seq_len, only the memory usage, right?
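The "process them in series" fallback can be as simple as the following sketch, where generate_batch is a hypothetical stand-in for whatever batched generate call is used (e.g. a wrapper around generator.generate_simple):

```python
def run_in_chunks(prompts, max_bsz, generate_batch):
    # Split an arbitrary number of prompts into chunks of at most max_bsz
    # and run each chunk through the batched generate call in series.
    outputs = []
    for i in range(0, len(prompts), max_bsz):
        outputs.extend(generate_batch(prompts[i:i + max_bsz]))
    return outputs
```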

#183 gives us one fundamental building block for implementing an API. But this is missing:

When a generation is currently running and a new API request comes in (and this is what I understand 'dynamic batching' to mean), it would be ideal if the new generation could somehow be injected in parallel into the current generation (in the sense of two batches). Before I waste 2 days misinterpreting your code because I'm not proficient enough, I wanted to ask if this is at least possible?

(A thought of mine: when injecting in the middle of a sequence, max_new_tokens can only be 1/2 of the original one?)

nivibilla commented 1 year ago

Hey @turboderp, is exllamav2 public, or do you have some sort of roadmap? Would love to try to contribute.

turboderp commented 1 year ago

No, it's currently just a bunch of loose experiments. It'll be a while yet before there's anything to make public. It also depends what happens in the meantime. Easy to get distracted with everything that's going on.

nivibilla commented 1 year ago

I see. No problem. Thanks everyone for your input. Closing this issue as I understand the difference now. Thanks

SinanAkkoyun commented 1 year ago

> I wanted to ask if this is at least possible?

Do you perhaps know if it's possible to easily inject generation of a new output mid-generation of another one?

turboderp commented 1 year ago

Not easily, no. You'd have to construct a new attention mask that takes into account this new generation starting at a later position than the previous ones. And you'd only have whatever space is left in the total sequence length, depending on when exactly the new generation is started relative to the others in the batch. You could pause the batch to generate as much as you need for the new sequence until it lines up with the rest of the batch, then combine the K/V caches... I mean, there's plenty of options, but none of them straightforward.
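For the first option, a rough sketch of what such an attention mask could look like (purely illustrative, not how ExLlama actually builds its masks): a row that joins the batch at a later position is treated as left-padded, and its padded prefix is masked out.

```python
import torch

def offset_causal_mask(seq_len, start_positions):
    # Additive attention mask for a batch where each row may have joined the
    # generation at a different position. Row i is assumed to be left-padded
    # up to start_positions[i]; those padded positions must not be attended to.
    bsz = len(start_positions)
    causal = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
    mask = causal.unsqueeze(0).repeat(bsz, 1, 1)
    for i, start in enumerate(start_positions):
        mask[i, :, :start] = float("-inf")  # block attention to the padded prefix
    # Let every position attend to itself so fully-masked (padding) rows don't
    # turn into NaN after softmax; their outputs get discarded anyway.
    idx = torch.arange(seq_len)
    mask[:, idx, idx] = 0.0
    return mask  # [bsz, seq_len, seq_len], added to attention scores before softmax

# Row 0 has been generating since position 0; row 1 was injected at position 5.
mask = offset_causal_mask(seq_len=8, start_positions=[0, 5])
```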

SinanAkkoyun commented 1 year ago

I see, thank you! Sounds really complicated

turboderp commented 1 year ago

Yeah, it comes down to producing an efficient parallelized operation for the GPU, and when adding steps to that the complexity tends to blow up.

Same with outliers, for instance. You can gain a significant improvement in perplexity by selecting just a few weights per matrix to include in full precision instead of quantizing them. But even if it's just a little bit of extra data, CUDA relies on synchronization between threads so you can't just insert logic like "if this weight is a special weight, treat it differently" without a lot of extra synchronization steps that slow everything to a crawl.

So these things end up being architectural decisions you really have to make quite early on. I think ExLlama V2 might end up having some design choices that could help here, though. E.g. it's pretty easy to dynamically size a batch for the purposes of all the linear layers, it's just attention that presents issues. So if the forward pass supported multiple contexts, you could combine those contexts most of the time and only process them separately where they need to be separate.