Closed Ph0rk0z closed 2 weeks ago
Are you referencing this block here?
Yeah. I transposed K and V to get it to run inference, but no go, just gibberish. Same for SDP. Must be missing something. With xformers, inference also seemed faster; with SDP it was the same. These functions expect q/k/v to all be the same size; flash attention takes care of that itself.
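To illustrate the shape expectation (a minimal sketch, not the code from this repo): with grouped-query attention there are fewer K/V heads than Q heads, so K/V have to be expanded to matching head counts before calling a kernel like torch's SDP. All sizes below are made up.

```python
import torch
import torch.nn.functional as F

B, T, D = 1, 16, 64           # batch, sequence length, head dim (made-up sizes)
n_q_heads, n_kv_heads = 8, 2  # grouped-query attention: 4 Q heads per KV head

q = torch.randn(B, n_q_heads, T, D)
k = torch.randn(B, n_kv_heads, T, D)
v = torch.randn(B, n_kv_heads, T, D)

# Repeat each KV head so the head counts match; passing the un-expanded
# K/V to SDP would raise a shape error instead.
repeats = n_q_heads // n_kv_heads
k_expanded = k.repeat_interleave(repeats, dim=1)  # (B, 8, T, D)
v_expanded = v.repeat_interleave(repeats, dim=1)

out = F.scaled_dot_product_attention(q, k_expanded, v_expanded, is_causal=True)
print(out.shape)  # torch.Size([1, 8, 16, 64])
```

Flash attention accepts the un-expanded K/V directly, which is the difference being described above.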
Any thoughts on flash infer?
https://github.com/flashinfer-ai/flashinfer
Looks like it has the same expectation, just looking at the example.
Doesn't it have its own kernels? I dunno if it can be shoehorned in, but I didn't look too hard. The benefit of xformers is that it works on non-Ampere cards.
xformers support has been added. For what it's worth. (:
The latest git xformers changed where you import the mask from, so that will be coming up.
It's xformers.ops.fmha.attn_bias. Found out when I rebuilt it for torch 2.3.1.
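A compatibility shim along these lines (a sketch, assuming the older builds re-export the mask class from `xformers.ops`) keeps both old and new xformers working:

```python
# Try the new module path first (xformers.ops.fmha.attn_bias), then fall
# back to the older top-level re-export; leave it as None if xformers is
# not installed at all.
mask_cls = None
try:
    from xformers.ops.fmha.attn_bias import LowerTriangularMask as mask_cls  # new path
except ImportError:
    try:
        from xformers.ops import LowerTriangularMask as mask_cls  # older releases
    except ImportError:
        pass  # xformers not installed
```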
Tested out your xformers attention implementation. On a 2080 Ti 22G, I am fitting 1000+ more tokens on nous-capybara. It supports the P100 when compiled and would probably help those with non-FA cards versus having nothing. Haven't tried SDP yet, but I'm guessing it did worse? I only used Q8 cache; maybe more will fit with Q4.
edit:
I was able to test both SDP and xformers but wasn't paying attention to the outputs, just OOM behavior. For some reason I can't get the model to produce coherent output. It's probably due to having to reshape the tensors. Xformers was faster.
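For reference, a hypothetical sketch of the kind of reshape bug that produces coherent-looking shapes but scrambled output: `view()`-ing a `(B, T, H, D)` tensor straight to the target layout, instead of transposing first, silently mixes token and head dimensions.

```python
import torch

B, T, H, D = 1, 4, 2, 3
x = torch.arange(B * T * H * D).reshape(B, T, H, D)

# Correct: transpose to (B, H, T, D), make contiguous, then flatten.
good = x.transpose(1, 2).contiguous().view(B, H, T * D)

# Wrong: viewing straight to the target shape without the transpose
# reinterprets memory row-major, so values land in the wrong positions.
bad = x.view(B, H, T * D)

print(torch.equal(good, bad))  # False: same shape, different element order
```

The shapes match either way, which is why this kind of bug shows up as gibberish rather than a crash.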