Open fahadh4ilyas opened 7 months ago
So I was able to run it now. The problem was that it requires a boolean attention mask, whereas the mask used elsewhere is an additive bias (-inf to mask out tokens). That's not a big issue, but I'm still not quite sure this is the best way to go about batching in flash-attn.
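For anyone following along, a minimal sketch of the difference (shapes and values illustrative): converting the additive bias used elsewhere into the boolean padding mask that flash-attn's varlen path expects is just a comparison.

```python
import torch

# Additive bias mask as used elsewhere in the codebase:
# 0.0 where a token is visible, -inf where it is masked out.
additive_mask = torch.tensor([[0.0, 0.0, float("-inf")],
                              [0.0, float("-inf"), float("-inf")]])

# flash-attn's varlen path instead expects a boolean padding mask:
# True = real token, False = padding.
bool_mask = additive_mask == 0.0
# tensor([[ True,  True, False],
#         [ True, False, False]])
```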
The thing is, if I run your example with a 4-bit 7B model, bsz = 1 and no attention mask, I get 142 tokens/second. With the attention mask the speed drops to 69 t/s. Similarly, at bsz = 4 it goes from 121×4 to 60×4 t/s when the attention mask is enabled. I'm not sure if the bottleneck is copying + unpadding/re-padding the entire K/V cache for every layer, or if `flash_attn_varlen_func` is just slower.
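For reference, here's a rough, self-contained sketch of what I understand the masked path to be doing per layer. The helper names (`_get_unpad_data`, `index_first_axis`, `pad_input`) follow the flash-attn / HF-style integration, it's simplified to prompt processing with no K/V cache, and the exact signatures may differ between flash-attn versions:

```python
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input

def _get_unpad_data(attention_mask):
    # attention_mask: (batch, seqlen) bool; True marks real tokens.
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
    # Flat indices of the real tokens in the (batch * seqlen) layout.
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    # Cumulative sequence lengths with a leading 0, as flash-attn expects.
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    return indices, cu_seqlens, int(seqlens.max())

def attn_forward_masked(q, k, v, attention_mask):
    # q/k/v: (batch, seqlen, nheads, headdim).
    batch, seqlen, nheads, headdim = q.shape
    indices, cu_seqlens, max_seqlen = _get_unpad_data(attention_mask)

    # Unpad: copy only the real tokens into packed (total_tokens, ...) tensors.
    # Doing this for Q, K and V in every layer is the copying suspected above.
    q_u = index_first_axis(q.reshape(batch * seqlen, nheads, headdim), indices)
    k_u = index_first_axis(k.reshape(batch * seqlen, nheads, headdim), indices)
    v_u = index_first_axis(v.reshape(batch * seqlen, nheads, headdim), indices)

    out = flash_attn_varlen_func(
        q_u, k_u, v_u,
        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
        causal=True,
    )
    # Re-pad: scatter the packed output back to (batch, seqlen, ...).
    return pad_input(out, indices, batch, seqlen)
```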
If it's the former, perhaps there's a way to keep the cache in an unpadded state throughout the forward pass. Certainly the `_get_unpad_data` function only needs to be called once, but I don't think that's the main culprit. I'll have to give it some thought and experiment a little.
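A hypothetical sketch of that idea, reusing the imports and `_get_unpad_data` from the block above: unpad once before the layer stack, keep everything packed, and re-pad once at the end.

```python
def forward_unpadded(hidden_states, attention_mask, layers):
    # hidden_states: (batch, seqlen, dim); attention_mask: (batch, seqlen) bool.
    batch, seqlen, dim = hidden_states.shape

    # Computed once for the whole forward pass instead of once per layer.
    indices, cu_seqlens, max_seqlen = _get_unpad_data(attention_mask)

    x = index_first_axis(hidden_states.reshape(batch * seqlen, dim), indices)
    for layer in layers:
        # Each layer (a hypothetical module) would call flash_attn_varlen_func
        # on the packed tokens directly, with its K/V cache also kept packed.
        x = layer(x, cu_seqlens, max_seqlen)
    return pad_input(x, indices, batch, seqlen)
```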
I'm also exploring the possibility of just moving away from flash-attn altogether. It's really efficient in most cases, but the lack of attention masking is very limiting, not just here; it also holds back speculative decoding.
`_get_unpad_data` is actually copied from the flash-attention GitHub repo itself, from the method named `unpad_input`.
I'm sorry, I'm a newbie on GitHub. I just wanted to update my repo to follow your master branch, and it seems I accidentally closed pull request #240. I will open a new pull request here now.