turboderp / exllamav2

A fast inference library for running LLMs locally on modern consumer-class GPUs
MIT License

What is the implementation of 'segmenting input' in exllamav2 called? #470

Closed laoda513 closed 1 month ago

laoda513 commented 1 month ago

Could you tell me what technique exllamav2 uses for segmenting input? I noticed that when the input is very long, say 16000 tokens, and the default maximum input length is 2048, exllamav2 processes the content in segments of 2048 at a time, updating the corresponding kv cache for each segment and looping until all 16000 tokens have been processed. After that, each new token is treated as an additional input of length 1, so attention only needs a q length of 1 against a kv length of 16000 (growing by one per generated token). When memory is limited, the maximum segment length of 2048 is reduced according to the available memory to prevent out-of-memory errors.
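To illustrate what I mean, here is a minimal sketch of the chunking loop (the `model.forward` and `cache` interfaces below are just placeholders, not ExLlamaV2's actual API):

```python
def chunked_prefill(model, cache, input_ids, max_input_len=2048):
    """Process a long prompt in fixed-size chunks, appending each chunk's
    keys/values to the cache, so q_len never exceeds max_input_len."""
    total_len = input_ids.shape[-1]
    for start in range(0, total_len, max_input_len):
        chunk = input_ids[:, start:start + max_input_len]
        # q covers only this chunk; k/v cover the cache plus this chunk,
        # which requires a bottom-right aligned causal mask.
        model.forward(chunk, cache=cache)
    # Generation then proceeds one token at a time: q_len = 1,
    # k_len = total_len + number of tokens generated so far.
```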

Currently, the implementation is mathematically equivalent to processing the entire input at once. Does this implementation require any additional operations? After checking the source code, I didn't find any special calculations at the Python level. In the attention forward, some calculations and the writes to the kv cache are done through C++, but I didn't see any operations there on past inputs' kv cache, only calculations of the current input's k and v. Is there anything different about this calculation compared to traditional methods?

This implementation brings significant benefits, both in memory and in performance. However, I haven't seen similar implementations elsewhere, such as in HF's Transformers. Is this technique original to exllamav2? And is it feasible during model training?

laoda513 commented 1 month ago

It seems I made some mistake here...

turboderp commented 1 month ago

I don't think it's an especially novel technique, and it isn't something you couldn't do in Transformers (though it isn't in the pipeline implementation as far as I know). I think most implementations have something like it by now, although it took a while to be able to fully use it with flash-attn, SDPA, xformers etc., since it requires a bottom-right aligned causal mask.

The reason for it is that there are diminishing returns to using longer input batches, while the memory required for attention continues to scale quadratically. So if you limit it to batches of, say, 2048 input tokens at a time, you trade off a small amount of prompt ingestion speed for a lot of VRAM savings. By enforcing a limit of k_len * q_len < C, you get a constant upper bound for the size of the attention weights matrix, so you can be "sure" that if attention works for a length of 2048 tokens, it will work for any length.
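As a rough illustration of that bound (the `budget` constant and the helper below are made up for the example, not the library's actual code):

```python
def bounded_q_len(cached_len, max_chunk=2048, budget=2048 * 2048):
    """Shrink the next chunk so the attention weights matrix
    (q_len x k_len, with k_len = cached_len + q_len) never exceeds
    `budget` elements, no matter how long the sequence grows."""
    q_len = max_chunk
    while q_len > 1 and q_len * (cached_len + q_len) > budget:
        q_len //= 2
    return q_len
```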

That's with regular matmul attention. flash-attn doesn't scale the same way, so you're better off just keeping q_len bounded instead of k_len * q_len. I'm not done evaluating xformers to say what the best strategy is there.

ExLlama allocates fixed tensors for anything it can, not just for attention. For instance, in the MLP you need somewhere to put the intermediate up and gate projections before applying the activation function, and ExLlama has a fixed tensor for this (from a pool shared with other modules). This only works when there's an upper bound on the size of the input batch.
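Roughly, the pattern looks like this (names and sizing here are illustrative only, not the actual allocator):

```python
import torch

class ScratchPool:
    """Preallocate one flat buffer for the worst case (the maximum chunk
    length) and hand out views of it, so the forward pass never allocates."""
    def __init__(self, max_tokens, intermediate_size,
                 device="cuda", dtype=torch.half):
        self.intermediate_size = intermediate_size
        self.buffer = torch.empty(max_tokens * intermediate_size * 2,
                                  device=device, dtype=dtype)

    def mlp_workspace(self, num_tokens):
        # Views for the up and gate projections of the current chunk;
        # only valid because num_tokens is bounded by max_tokens.
        n = num_tokens * self.intermediate_size
        up = self.buffer[:n].view(num_tokens, self.intermediate_size)
        gate = self.buffer[n:2 * n].view(num_tokens, self.intermediate_size)
        return up, gate
```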

I don't think chunking would work during training. I think the computational graph would get very messy, and whatever you saved during the forward pass you'd have to make up for in the backward pass. But I could be wrong.

laoda513 commented 1 month ago

Thank you, I understand.

I was thinking that chunking is essentially similar to breaking a single inference down into multiple consecutive inferences. However, if the kv cache is static, it seems there could be some issues. If we don't care about speed, copying the kv cache each time it is used, and ensuring the copies are identical when gradient checkpointing recomputes the forward pass, might be feasible.

Another way to think about it: if we don't expect to train on the entire content but only on the last 2000 tokens, then we could precompute the cache for the earlier tokens. That way, the precomputed cache wouldn't need to be part of the computation graph at all.
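For instance, with HF Transformers something like this might work (an untested sketch; it assumes a causal LM that accepts `past_key_values` and computes the loss from `labels`):

```python
import torch

def loss_on_last_window(model, input_ids, labels, window=2000):
    """Build the kv cache for the prompt prefix without gradients, then
    compute the loss (and gradients) only over the final `window` tokens."""
    prefix_ids = input_ids[:, :-window]
    suffix_ids = input_ids[:, -window:]

    with torch.no_grad():
        # The prefix cache is treated as a constant, so it never enters
        # the computation graph.
        prefix_out = model(prefix_ids, use_cache=True)

    out = model(suffix_ids,
                past_key_values=prefix_out.past_key_values,
                labels=labels[:, -window:])
    return out.loss
```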

For models with super long contexts, we mostly care about how the last few tokens respond to the preceding text anyway, so this also seems to make sense.