bilzard opened this issue 3 months ago
I might be wrong, but I believe it is due to the causal mask. Without the causal mask we can rewrite (QK^T)V as Q(K^T V), dropping the complexity to O(n x d x d). With the causal mask this cannot be done directly. There are other papers proposing chunkwise-parallel algorithms that implement this setup (linearized attention with a causal mask) efficiently at the CUDA level with complexity O(L x C x d), where C is the chunk size, but the authors do not seem to have implemented that.
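To make the complexity argument concrete, here is a minimal sketch (my own illustration, not code from the DiJiang repo): with a feature map already applied to q and k and no causal mask, the two association orders are mathematically identical, but only the second avoids materializing the n x n matrix.

```python
import torch

n, d = 1024, 64                                  # sequence length, head dimension
q = torch.randn(n, d, dtype=torch.float64)       # feature-mapped queries
k = torch.randn(n, d, dtype=torch.float64)       # feature-mapped keys
v = torch.randn(n, d, dtype=torch.float64)       # values

out_quadratic = (q @ k.T) @ v    # (QK^T)V: materializes an n x n matrix, O(n^2 * d)
out_linear = q @ (k.T @ v)       # Q(K^T V): only a d x d intermediate, O(n * d^2)

print(torch.allclose(out_quadratic, out_linear))  # True
```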
@juankost
> I might be wrong, but I believe it is due to the causal mask. Without the causal mask we can rewrite (QK^T)V as Q(K^T V), dropping the complexity to O(n x d x d). With the causal mask this cannot be done directly.
Thanks for pointing this out. Indeed, it doesn't seem straightforward to apply a causal mask to the Q x (K^T V) style of computation. Hmm... if they don't use that formulation in the original paper, I wonder how the authors reduced training time.
> There are other papers proposing chunkwise-parallel algorithms that implement this setup (linearized attention with a causal mask) efficiently at the CUDA level with complexity O(L x C x d), where C is the chunk size.
Interesting. Could you share the titles of these papers?
@bilzard
My understanding of their training-time comparison is that they compared the total pretraining time of Pythia with the time it takes to fine-tune Pythia into DiJiang attention (since they mention that they use existing Pythia checkpoints for their DiJiang fine-tuning), i.e. the claim is that with a little extra fine-tuning you get a linearized-attention transformer with performance comparable to the original transformer.
I would love to hear from the authors whether this understanding is correct.
Regarding the papers, as far as I know there are two that introduced this, since they came out at roughly the same time. Fortunately they provide efficient code implementations, so you could fairly easily adapt DiJiang to use them: after the DCT projection, you could pass the modified Q, K, and V to the flash linear attention code to compute this masked attention efficiently (a rough sketch of the chunkwise idea is given after the references below).
Yang, S., Wang, B., Shen, Y., Panda, R. and Kim, Y., 2023. Gated linear attention transformers with hardware-efficient training. arXiv preprint arXiv:2312.06635.
Qin, Z., Sun, W., Li, D., Shen, X., Sun, W. and Zhong, Y., 2024. Lightning attention-2: A free lunch for handling unlimited sequence lengths in large language models. arXiv preprint arXiv:2401.04658.
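For illustration, here is a rough, unoptimized PyTorch sketch of the chunkwise-parallel idea from those papers (simplified: no gating or decay, a single head, and plain PyTorch rather than a fused CUDA kernel; the function name and chunk size C are my own choices). Each chunk does O(C^2 * d) work for the intra-chunk masked part plus an O(d^2) state update, giving roughly O(L * C * d) overall instead of O(L^2 * d).

```python
import torch

def chunkwise_causal_linear_attn(q, k, v, C=64):
    L, d = q.shape
    assert L % C == 0, "sketch assumes L is a multiple of the chunk size"
    S = torch.zeros(d, d, dtype=q.dtype)           # running sum of k_j v_j^T over past chunks
    causal = torch.tril(torch.ones(C, C, dtype=torch.bool))
    out = torch.empty_like(v)
    for s in range(0, L, C):
        qc, kc, vc = q[s:s+C], k[s:s+C], v[s:s+C]
        inter = qc @ S                             # contribution from all previous chunks
        intra = ((qc @ kc.T) * causal) @ vc        # causal attention inside the chunk
        out[s:s+C] = inter + intra
        S = S + kc.T @ vc                          # update state for the next chunk
    return out

# Check against the naive masked quadratic form
L, d = 256, 32
q, k, v = (torch.randn(L, d, dtype=torch.float64) for _ in range(3))
mask = torch.tril(torch.ones(L, L, dtype=torch.bool))
naive = ((q @ k.T) * mask) @ v
print(torch.allclose(naive, chunkwise_causal_linear_attn(q, k, v)))  # True
```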
@juankost
> My understanding of their training-time comparison is that they compared the total pretraining time of Pythia with the time it takes to fine-tune Pythia into DiJiang attention (since they mention that they use existing Pythia checkpoints for their DiJiang fine-tuning).
That makes sense.
Thanks for sharing the papers. I also found an efficient third-party implementation of the causal matrix product [1], which was used in Performer [2], so I'm sharing it here for reference.
@bilzard
Thanks for the link. I am aware of this paper. I do not recommend using their CUDA code, since it relies on a cumsum operation that is still sequential along the sequence dimension; that is one of the motivations for the chunkwise-parallel implementations in the papers I linked (I recommend reading them, as they discuss the paper you linked).
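To illustrate what that formulation looks like (my own sketch of the prefix-sum recurrence, not the Performer CUDA kernel): the running state S_i = sum_{j<=i} k_j v_j^T is a cumulative sum along the sequence dimension, so it is either computed sequentially in i or, if materialized with torch.cumsum as below, costs O(n x d x d) memory.

```python
import torch

n, d = 256, 32
q, k, v = (torch.randn(n, d, dtype=torch.float64) for _ in range(3))

# Materialized prefix sum over outer products: states has shape (n, d, d)
states = torch.cumsum(k.unsqueeze(-1) * v.unsqueeze(1), dim=0)
out = torch.einsum('nd,nde->ne', q, states)        # out_i = q_i @ S_i

# Matches the naive masked quadratic form
mask = torch.tril(torch.ones(n, n, dtype=torch.bool))
naive = ((q @ k.T) * mask) @ v
print(torch.allclose(out, naive))                  # True
```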
@juankost I see. Thanks!
The provided code calculates the matrix product of q and k: https://github.com/YuchuanTian/DiJiang/blob/main/modeling/pythia-2.8B-dijiang/modeling_gpt_neox_dijiang.py#L286
That means it has computational complexity O(n x n x d). Is this a different implementation from the one used in the original paper?
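For reference, here is a minimal paraphrase of what that path computes (my own sketch, not a copy of the repo's code): the attention weights q k^T are materialized as an n x n matrix before masking and the product with v, which is where the O(n x n x d) cost comes from.

```python
import torch

n, d = 512, 64
q, k, v = (torch.randn(n, d) for _ in range(3))

attn = q @ k.transpose(-2, -1)                     # (n, n) matrix: quadratic in n
attn = attn.masked_fill(~torch.tril(torch.ones(n, n, dtype=torch.bool)), 0.0)
out = attn @ v                                     # O(n^2 * d) overall
```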