pytorch-labs / gpt-fast

Simple and efficient pytorch-native transformer text generation in <1000 LOC of python.
BSD 3-Clause "New" or "Revised" License

Support of FlashDecoding #188

Open jianc99 opened 1 week ago

jianc99 commented 1 week ago

Hi, I am trying to use gpt-fast for inference with an input batch size larger than 1, and I believe in this case flash decoding might be a better choice for the attention kernel than the one torch.compile generates. Since flash decoding was released almost a year ago, is there any way to integrate flash decoding with gpt-fast or torch.compile now? Thanks! @Chillee

Chillee commented 1 week ago

@jianc99 The most straightforward way today is to integrate flash decoding as a custom op. However, stay tuned for something soon that will also make this straightforward :)
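
For reference, here is a minimal sketch of what "integrate it as a custom op" could look like, assuming the flash-attn package's `flash_attn_with_kvcache` entry point is used as the flash-decoding kernel. The library name `mylib`, the op name `flash_decode`, and the exact argument layout are illustrative assumptions, not the gpt-fast or PyTorch-endorsed integration:

```python
# Sketch: register a flash-decoding kernel as a custom op so torch.compile
# can call it as an opaque node instead of tracing into it.
import torch
from torch.library import Library

# Assumes the flash-attn package is installed and exposes this function.
from flash_attn import flash_attn_with_kvcache

lib = Library("mylib", "DEF")  # "mylib" is a placeholder namespace
lib.define(
    "flash_decode(Tensor q, Tensor k_cache, Tensor v_cache, Tensor cache_seqlens) -> Tensor"
)

def flash_decode_cuda(q, k_cache, v_cache, cache_seqlens):
    # q: (batch, 1, n_heads, head_dim) during decoding;
    # k_cache / v_cache: (batch, max_seq_len, n_kv_heads, head_dim).
    # Exact keyword arguments may differ across flash-attn versions.
    return flash_attn_with_kvcache(
        q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=True
    )

def flash_decode_meta(q, k_cache, v_cache, cache_seqlens):
    # Shape/dtype propagation so torch.compile can trace through the op.
    return torch.empty_like(q)

lib.impl("flash_decode", flash_decode_cuda, "CUDA")
lib.impl("flash_decode", flash_decode_meta, "Meta")

# In the attention forward pass, the SDPA call during decoding would then
# be replaced with something like:
#   y = torch.ops.mylib.flash_decode(q, k_cache, v_cache, cache_seqlens)
```

Registering a Meta implementation alongside the CUDA one is what lets the compiled graph reason about output shapes without running the kernel.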

jianc99 commented 1 week ago

Wow, I am really looking forward to that! I think almost all the other parts of gpt-fast are great except the attention kernel when dealing with large batch sizes and long context lengths.

As for the custom op method, I did see this issue: https://github.com/pytorch/pytorch/issues/120441#issue-2150006143

But I am not sure whether they finally solved the problem and successfully integrated flash decoding. Do you have any more information about that? @Chillee Thanks!

Chillee commented 1 week ago

@jianc99 Yes, that issue should be solved now.