pytorch / ao

PyTorch native quantization and sparsity for training and inference

[feat] int8 flash attention #952

Open felipemello1 opened 1 week ago

felipemello1 commented 1 week ago

hi all, I saw this tweet and thought of sharing it. The accuracy degradation doesn't look too good, but maybe the speedup makes it worth it?

https://x.com/papers_anon/status/1839131401322639805?s=46

To be clear: I am not requesting the feature, just sharing it. Thanks! :)

jcaip commented 1 week ago

cc @cpuhrsch @HDCharles I think we could do this with flexattention? Flagging just so you are aware there's interest.

cpuhrsch commented 3 days ago

@jcaip - Worth a try. Essentially you'd need to dequant within the score mod (before the softmax), and the inputs would have to be quantized. I think at this point only query and key could be quantized, because the values still need to be matmul'd with the output of the softmax.
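
A minimal sketch of that idea, assuming FlexAttention (`torch.nn.attention.flex_attention`, PyTorch >= 2.5): quantize query and key to int8, carry the integer values in a float dtype the kernel accepts, and dequantize the raw scores inside the `score_mod` before the softmax. The `quantize_per_tensor` helper and per-tensor scales below are hypothetical illustrations, not an existing torchao API.

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

def quantize_per_tensor(x: torch.Tensor, eps: float = 1e-12):
    """Symmetric per-tensor int8 quantization (illustrative helper, not a torchao API)."""
    scale = x.abs().amax().clamp(min=eps) / 127.0
    q = torch.clamp(torch.round(x / scale), -128, 127)
    return q, scale

B, H, S, D = 2, 8, 1024, 64
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)

# Quantize only query and key; value stays high precision because the
# softmax output still has to be matmul'd against it.
q_int, q_scale = quantize_per_tensor(q)
k_int, k_scale = quantize_per_tensor(k)

# Carry the integer values in bf16 so the attention kernel accepts them;
# the raw score is then dot(q_int, k_int), which score_mod dequantizes.
q_int = q_int.to(torch.bfloat16)
k_int = k_int.to(torch.bfloat16)

def dequant_score_mod(score, b, h, q_idx, kv_idx):
    # score is the dot product of the quantized q and k rows; multiplying by
    # the two per-tensor scales recovers an approximation of dot(q, k)
    # before the softmax is applied.
    return score * (q_scale * k_scale)

out = flex_attention(q_int, k_int, v, score_mod=dequant_score_mod)
```

To get any speed out of this you'd want `torch.compile(flex_attention)` so the score mod is fused into the kernel, and presumably finer-grained (e.g. per-head or per-row) scales to limit the accuracy hit; whether the int8 tensor cores can actually be used through this path is a separate question.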