lucidrains / performer-pytorch

An implementation of Performer, a linear attention-based transformer, in Pytorch
MIT License

Causal linear attention benchmark #64

Closed caffeinetoomuch closed 3 years ago

caffeinetoomuch commented 3 years ago

First, thanks for this awesome repo!!

Based on the T5 model classes from Huggingface's transformers, I was trying to use Performer attention instead of the original T5 attention. We fine-tuned t5-large on a summarization task and profiled both time and memory usage to compare the Performer attention with the original attention. I have only benchmarked with an input size of 1024.

The results clearly showed that the Performer attention uses a lot less memory than the original transformer. I know from the paper that the Performer outperforms the original transformer when the input size is bigger than 1024. However, fine-tuning and generation with the Performer actually took longer, so I profiled the forward call of both the original T5 attention and the Performer attention. The Performer forward took twice as long, and the main bottleneck was causal_dot_product_kernel from fast-transformers.
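(For what it's worth, by profiling the forward call I mean a simple CUDA-event timing of one attention module, roughly like the sketch below; this is not the exact harness we used, just to show the measurement.)

```python
import torch

def time_forward(module, *inputs, warmup=3, iters=20):
    """Average milliseconds per forward call, measured with CUDA events."""
    with torch.no_grad():
        for _ in range(warmup):
            module(*inputs)
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            module(*inputs)
        end.record()
        torch.cuda.synchronize()
    return start.elapsed_time(end) / iters
```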

Is this normal performance for the Performer's causal attention calculation, or will the Performer attention become faster at bigger input sizes?

lucidrains commented 3 years ago

@ice-americano ohh yes, so actually, for generation, I am not doing any caching. It should be way faster than T5

lucidrains commented 3 years ago

is that what you meant by it being slow?

caffeinetoomuch commented 3 years ago

Yes, that's what I was trying to ask. I just replaced the softmax attention of T5Attention from huggingface transformers with FastAttention from this repo. However, both fine-tuning and generation were slower with FastAttention, even though it was clearly using less memory. Any idea on what I might be doing wrong?
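To be concrete, the core of the swap was roughly the following (a simplified sketch: T5's relative position bias and mask handling are omitted, and the shapes are just an example for t5-large):

```python
import torch
from performer_pytorch import FastAttention

# q, k, v as T5Attention produces them after splitting heads:
# (batch, heads, seq_len, dim_per_head)
q = torch.randn(1, 16, 1024, 64)
k = torch.randn(1, 16, 1024, 64)
v = torch.randn(1, 16, 1024, 64)

# replaces softmax(q @ k^T) @ v inside the attention forward
attn_fn = FastAttention(
    dim_heads = 64,
    nb_features = 256,
    causal = True   # the causal path is the one that hits the EPFL CUDA kernel
)

out = attn_fn(q, k, v)  # (1, 16, 1024, 64), then merge heads as usual
```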

Thanks!

lucidrains commented 3 years ago

If you are working at context lengths of less than 2048, training will be slower. The benefits of Performers come at 4096 and beyond

As for generation, it's because I never built the caching portion. Once that's in place, it should be a lot faster

caffeinetoomuch commented 3 years ago

If we were to build caching, what would be cached? The projection matrices?

lucidrains commented 3 years ago

in linear attention, there are two tensors that are accumulated over the sequence, so you would just need to cache those: https://github.com/lucidrains/performer-pytorch/blob/main/performer_pytorch/performer_pytorch.py#L168-L169
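roughly, the recurrent view you'd cache for decoding looks like this (just a sketch, not the repo's API; q_p / k_p stand for the projected features phi(q) / phi(k) of a single new token):

```python
import torch

def decode_step(q_p, k_p, v, state = None):
    # q_p, k_p: (batch, heads, nb_features), v: (batch, heads, dim_head)
    # S accumulates phi(k) outer v over the sequence, z accumulates phi(k);
    # caching (S, z) makes each generated token O(1) instead of re-running the prefix
    if state is None:
        S = torch.zeros(*k_p.shape, v.shape[-1], device = v.device)
        z = torch.zeros_like(k_p)
    else:
        S, z = state
    S = S + torch.einsum('bhf,bhd->bhfd', k_p, v)
    z = z + k_p
    num = torch.einsum('bhf,bhfd->bhd', q_p, S)
    den = torch.einsum('bhf,bhf->bh', q_p, z).unsqueeze(-1).clamp(min = 1e-6)
    return num / den, (S, z)
```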

lucidrains commented 3 years ago

i'll get around to it this week!

lucidrains commented 3 years ago

@ice-americano the big problem with linear attention in pytorch is that everyone relies on this CUDA kernel written by EPFL. i need to write my own in numba so i can have more control over the changes

caffeinetoomuch commented 3 years ago

What is EPFL? Also, did you mean you are planning to rewrite causal_linear_attention in numba instead of using CausalDotProduct from fast_transformers.causal_product? What are the advantages of writing it in numba? Will it be faster?

Thanks for all the responses!

lucidrains commented 3 years ago

@ice-americano it's just so we can experiment more with linear attention https://developer.nvidia.com/cuda-python. i doubt it can get any faster than what EPFL already wrote; the code is just too much to build upon
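for reference, the gist of what that kernel computes (the numerator of causal linear attention) is just a prefix-sum loop; here is a naive numba sketch of it, assuming float32 inputs, to show what a rewrite would have to parallelize over batch and heads:

```python
import numpy as np
from numba import njit, prange

@njit(parallel = True)
def causal_dot_product_ref(q, k, v):
    # q, k: (batch, heads, seq, nb_features), v: (batch, heads, seq, dim_head)
    # out[b, h, t] = q[b, h, t] @ sum_{s <= t} outer(k[b, h, s], v[b, h, s])
    B, H, N, F = q.shape
    D = v.shape[3]
    out = np.zeros((B, H, N, D), dtype = np.float32)
    for b in prange(B):
        for h in range(H):
            S = np.zeros((F, D), dtype = np.float32)  # running sum of outer(k_t, v_t)
            for t in range(N):
                for i in range(F):
                    ki = k[b, h, t, i]
                    for j in range(D):
                        S[i, j] += ki * v[b, h, t, j]
                for j in range(D):
                    acc = 0.0
                    for i in range(F):
                        acc += q[b, h, t, i] * S[i, j]
                    out[b, h, t, j] = acc
    return out
```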

can you confirm the slowdown is when you try to generate from an autoregressive performer? i can fix it if so

caffeinetoomuch commented 3 years ago

Actually, installing from pip or building from source took a while, which was probably due to compiling the EPFL CUDA kernel (I only have shallow knowledge of CUDA kernels and libraries 😅).

We have fixed our code to use SelfAttention instead of FastAttention, and we might have been setting the wrong parameters before, since the performance and speed of the Performer now look similar to what the paper reports. So I think you can close this issue for now, and thanks for the responsive feedback!
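For reference, we now construct the layer through SelfAttention roughly as in this repo's README and feed it the hidden states directly (our exact dims and flags differ, so treat this as a sketch):

```python
import torch
from performer_pytorch import SelfAttention

attn = SelfAttention(
    dim = 1024,      # model dimension
    heads = 16,
    causal = False   # encoder side; the decoder side uses causal = True
)

x = torch.randn(1, 1024, 1024)  # (batch, seq_len, dim)
out = attn(x)                   # (1, 1024, 1024)
```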

lucidrains commented 3 years ago

ok! i'll work on the other issue (fast generation) - glad to hear the original issue is resolved!

lh-gt commented 1 year ago

@ice-americano hi, I have run into the same problem. I used SelfAttention from performer to replace BERT's self-attention, and eval is slower. Could you share your config?