ant-research / Pyraformer


Why the TVM implementation is memory efficient #16

Open jlidw opened 1 year ago

jlidw commented 1 year ago

Thanks for your excellent work!

Just want to discuss the memory reduction. It seems that the TVM implementation does not store fewer matrices (e.g., the Query, Key, and Value matrices). The number of Q-K pairs is smaller than in full attention, so the computation is faster, but why does the memory reduction follow a trend similar to the time reduction? The TVM kernel does not seem to use any special technique to save memory, and the padded zero values are also int32, yet the TVM implementation is memory efficient...

Looking forward to your reply.

Zhazhan commented 1 year ago

Hello, thanks for your interest in our work.

In fact, the number of Q-K pairs determines not only the computational complexity but also the memory consumption. For full attention, the attention score matrix $S = QK^{\top}$ holds $L^2$ entries, so its memory grows quadratically with the sequence length $L$. Reducing the number of attention scores that need to be stored therefore directly reduces memory consumption. The TVM implementation stores at most $A + C + 1$ attention scores per query, i.e., $O(L)$ scores in total, which is where the memory savings come from.
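
A minimal PyTorch sketch of this contrast (illustrative only: the variable names, the random `neighbor_idx`, and the chosen values of `A` and `C` are assumptions for demonstration, not the actual TVM kernel):

```python
import torch

L, d = 4096, 64   # sequence length and head dimension (illustrative values)
A, C = 3, 4       # adjacent and child counts in the pyramidal graph (assumed)

Q = torch.randn(L, d)
K = torch.randn(L, d)

# Full attention: the score matrix S = Q @ K^T holds L * L entries,
# so its memory grows quadratically with L.
S_full = Q @ K.T  # shape (L, L)

# Sparse attention: each query attends to at most A + C + 1 keys, so only
# L * (A + C + 1) scores are ever materialized. Here neighbor_idx[i] stands
# in for the key indices query i attends to; random indices are used purely
# for illustration.
neighbor_idx = torch.randint(0, L, (L, A + C + 1))    # (L, A+C+1)
K_gathered = K[neighbor_idx]                          # (L, A+C+1, d)
S_sparse = torch.einsum('ld,lnd->ln', Q, K_gathered)  # (L, A+C+1)

print(S_full.numel(), "vs", S_sparse.numel())  # L*L vs L*(A+C+1) scores
```

Note that a plain PyTorch gather like this still pays for intermediate copies such as `K_gathered`; the point of a custom TVM kernel is to compute the sparse scores directly, so the full $L \times L$ matrix is never formed.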