turboderp / exllamav2

A fast inference library for running LLMs locally on modern consumer-class GPUs

Batched flash attention #243

Open fahadh4ilyas opened 7 months ago

fahadh4ilyas commented 7 months ago

I'm sorry, I'm a newbie on GitHub. I just wanted to update my repo to follow your master branch, and it seems I accidentally closed pull request #240. I will make a new pull request here.

turboderp commented 6 months ago

So I was able to run it now. The problem was that it requires a boolean attention mask, whereas the mask used elsewhere is an additive bias (-inf to mask out tokens). That's not a big issue, but I'm still not quite sure this is the best way to go about batching in flash-attn.
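For context, converting between the two mask conventions is cheap. A minimal sketch, assuming a simple per-key padding bias of shape (batch, seq_len); the function name and threshold here are illustrative, not exllamav2's actual internals:

```python
import torch

def additive_bias_to_bool_mask(bias: torch.Tensor) -> torch.Tensor:
    # bias: (batch, seq_len) additive attention bias, 0 for visible tokens and
    # -inf (or a very large negative fill value) for masked/padded tokens.
    # Returns a boolean mask that is True where a token should be attended to,
    # which is the convention the flash-attn unpadding path expects.
    # The -1e4 threshold is an assumption; real code would compare against the
    # actual fill value used for masking.
    return bias > -1e4

# Example: batch of 2, right-padded to length 4
bias = torch.tensor([[0., 0., 0., float("-inf")],
                     [0., 0., float("-inf"), float("-inf")]])
print(additive_bias_to_bool_mask(bias))
# tensor([[ True,  True,  True, False],
#         [ True,  True, False, False]])
```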

Thing is, if I run your example with a 4-bit 7B model, bsz=1 and no attention mask, I get a speed of 142 tokens/second. With the attention mask, the speed drops to 69 t/s. Similarly, for bsz=4 it goes from 121*4 to 60*4 t/s when the attention mask is enabled. I'm not sure whether the bottleneck is copying + unpadding/re-padding the entire K/V cache for every layer, or whether flash_attn_varlen_func is just slower.
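For reference, this is roughly the varlen path being discussed: unpad the padded batch into a packed token stream, run flash_attn_varlen_func over it, then scatter the outputs back. A minimal sketch, assuming a right-padded boolean mask (True = keep) and equal query/key lengths (i.e. a prefill pass); the function name and shapes are illustrative, not exllamav2's actual code:

```python
import torch
from flash_attn import flash_attn_varlen_func

def varlen_attention(q, k, v, attention_mask, n_heads, head_dim):
    # q, k, v: (batch, seq_len, n_heads, head_dim); attention_mask: (batch, seq_len) bool
    bsz, seqlen = attention_mask.shape

    # Unpad metadata: per-sequence lengths, flat indices of real tokens,
    # and cumulative sequence offsets into the packed layout.
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
    max_seqlen = int(seqlens.max())

    # Unpad: gather only the real tokens into (total_tokens, n_heads, head_dim)
    def unpad(x):
        return x.reshape(bsz * seqlen, n_heads, head_dim)[indices]

    out = flash_attn_varlen_func(
        unpad(q), unpad(k), unpad(v),
        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
        causal=True,
    )

    # Re-pad: scatter the packed outputs back into the padded layout
    padded = torch.zeros(bsz * seqlen, n_heads, head_dim, dtype=out.dtype, device=out.device)
    padded[indices] = out
    return padded.view(bsz, seqlen, n_heads, head_dim)
```

The metadata computation at the top (seqlens, indices, cu_seqlens, max_seqlen) depends only on the mask, so it is the part that, as noted below, only needs to happen once per forward pass rather than once per layer; the per-layer gather/scatter of the K/V cache is the other candidate for the slowdown.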

If it's the former, perhaps there's a way to keep the cache in an unpadded state throughout the forward pass. Certainly the _get_unpad_data function only needs to be called once, but I don't think that's the main culprit. I'll have to give it some thought and experiment a little.

I'm also exploring the possibility of moving away from flash-attn altogether. It's really efficient in most cases, but the lack of attention masking is very limiting, not just in this case; it also holds back speculative decoding.

fahadh4ilyas commented 6 months ago

_get_unpad_data is actually copied from the flash-attention GitHub repo itself, from the function named unpad_input.
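To illustrate the metadata that kind of helper produces, under the same assumptions as the sketch above (a right-padded boolean mask, True = real token):

```python
import torch

mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]], dtype=torch.bool)

seqlens = mask.sum(dim=-1, dtype=torch.int32)       # tensor([3, 2], dtype=torch.int32)
indices = torch.nonzero(mask.flatten()).flatten()   # tensor([0, 1, 2, 4, 5])
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
print(cu_seqlens)                                   # tensor([0, 3, 5], dtype=torch.int32)
```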