yoxu515 / aot-benchmark

An efficient modular implementation of Associating Objects with Transformers for Video Object Segmentation in PyTorch
BSD 3-Clause "New" or "Revised" License

About the implementation of multi-head attention in DeAOT #28

Closed. MUVGuan closed this issue 1 year ago.

MUVGuan commented 1 year ago

Hello, I have a question after reading your great work DeAOT. In the ablation study on the head number, you compare multi-head and single-head attention in DeAOT. As we all know, the common implementation of multi-head attention reshapes the Query (taking Query as an example; its shape is HW×batch_size×C): the channel dimension C is split into num_head groups of C/num_head channels, so Query is reshaped to HW×batch_size×num_head×(C/num_head). This implementation keeps the same computational complexity as single-head attention. However, the ablation study on the head number shows that multi-head attention significantly reduces speed. So I would like to know how multi-head attention is implemented in DeAOT. Is it the way I described above?
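A minimal NumPy sketch of the reshape-based multi-head attention described in the question, assuming Q has shape HW×B×C and K, V have shape THW×B×C (T memory frames). The function names, the toy sizes, the 1/sqrt(C/N) scaling (the usual Transformer choice) and the use of the same C for V are assumptions for illustration, not the DeAOT code. Splitting C into N heads leaves the projections unchanged, but the attention map gains a head axis of size N.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along `axis`.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(q, k, v, num_head):
    """Hypothetical sketch: q is (HW, B, C); k and v are (THW, B, C)."""
    hw, b, c = q.shape
    thw = k.shape[0]
    assert c % num_head == 0, "C must be divisible by num_head"
    d = c // num_head  # channels per head: C / num_head
    # Reshape HW x B x C -> HW x B x num_head x (C / num_head).
    qh = q.reshape(hw, b, num_head, d)
    kh = k.reshape(thw, b, num_head, d)
    vh = v.reshape(thw, b, num_head, d)
    # One correlation map per head: shape (B, num_head, HW, THW).
    logits = np.einsum("qbnd,kbnd->bnqk", qh, kh) / np.sqrt(d)
    attn = softmax(logits, axis=-1)
    # Aggregate values per head, then merge the heads back into C channels.
    out = np.einsum("bnqk,kbnd->qbnd", attn, vh)
    return out.reshape(hw, b, c), attn

# Usage with toy sizes: HW=16, B=2, C=64, THW=48 (T=3).
rng = np.random.default_rng(0)
q = rng.standard_normal((16, 2, 64))
k = rng.standard_normal((48, 2, 64))
out, attn = multi_head_attention(q, k, k, num_head=8)
print(out.shape, attn.shape)  # (16, 2, 64) (2, 8, 16, 48)
```

The output keeps the input shape, while the attention map holds N×HW×THW values per batch element instead of HW×THW.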

z-x-yang commented 1 year ago

It's implemented exactly the way you described. However, this implementation cannot keep the same run-time speed, because some operations are affected by the head number, such as Softmax and einsum/bmm.
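To see why bmm is affected, the same per-head correlation can be written as a batched matrix multiply (the kind of operation `torch.bmm` performs) over B·N independent products of shape (HW × C/N) @ (C/N × THW). The NumPy sketch below is hypothetical and only illustrates the shapes: as N grows there are more, thinner products and the output correlation tensor grows by a factor of N, while the total number of multiplications, B·HW·THW·C, stays the same.

```python
import numpy as np

def per_head_bmm(q, k, num_head):
    """Hypothetical sketch: q is (HW, B, C), k is (THW, B, C).
    Returns per-head logits of shape (B * num_head, HW, THW)."""
    hw, b, c = q.shape
    thw = k.shape[0]
    d = c // num_head
    # (HW, B, N, d) -> (B, N, HW, d) -> (B*N, HW, d)
    qb = q.reshape(hw, b, num_head, d).transpose(1, 2, 0, 3).reshape(b * num_head, hw, d)
    # (THW, B, N, d) -> (B, N, d, THW) -> (B*N, d, THW)
    kb = k.reshape(thw, b, num_head, d).transpose(1, 2, 3, 0).reshape(b * num_head, d, thw)
    # B*N independent (HW x d) @ (d x THW) products, like a bmm call.
    return np.matmul(qb, kb)

rng = np.random.default_rng(0)
q = rng.standard_normal((16, 2, 64))  # HW=16, B=2, C=64 (toy sizes)
k = rng.standard_normal((48, 2, 64))  # THW=48, i.e. T=3
for n in (1, 8):
    print(n, per_head_bmm(q, k, n).shape)
# 1 (2, 16, 48)
# 8 (16, 16, 48)
```

Summing the per-head logits over heads recovers the single-head dot product, since the heads partition the channels; the extra cost comes from the larger output and the N-times-larger softmax input, not from extra multiplications.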

MUVGuan commented 1 year ago

Thanks for your reply!

MUVGuan commented 1 year ago

I'm sorry to bother you again. When I compute the computational complexity of long-term multi-head attention, I get O(TH²W²C) + O(NTH²W²), but in DeAOT it is O(NTH²W²) and is mainly determined by Corr(Q, K). I think Corr(Q, K) consists of two steps: the bmm between Q and K, and the softmax. Counting the number of multiplications:

bmm between Q and K: HW × THW × C/N × N, so the complexity of bmm is O(TH²W²C).
softmax: N × HW × THW, so the complexity of softmax is O(NTH²W²).

Thus the complexity of Corr(Q, K) is O(TH²W²C) + O(NTH²W²). I don't understand why it is O(NTH²W²) in DeAOT. Could you please explain or point out my mistake? Thanks a lot!
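The tally in the question can be checked with a few lines of Python. This sketch follows the question's convention (C is the total channel number, split into N heads of C/N channels) and counts only the bmm multiplications and the elements normalized by softmax, for one batch element; the helper name and the sizes H=W=24, T=2, C=256 are hypothetical.

```python
def corr_op_counts(h, w, t, c, n):
    """Hypothetical helper: operation counts for Corr(Q, K), one batch element.
    c is the total channel number, split into n heads of c // n channels."""
    hw = h * w                            # query positions
    thw = t * hw                          # memory positions over T frames
    bmm_mults = hw * thw * (c // n) * n   # ~ T*H^2*W^2*C, independent of n
    softmax_elems = n * hw * thw          # ~ N*T*H^2*W^2, linear in n
    return bmm_mults, softmax_elems

for n in (1, 2, 4, 8):
    print(n, corr_op_counts(h=24, w=24, t=2, c=256, n=n))
# 1 (169869312, 663552)
# 2 (169869312, 1327104)
# 4 (169869312, 2654208)
# 8 (169869312, 5308416)
```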

z-x-yang commented 1 year ago

The computational complexity of bmm should be O(TH²W²C) + O(NTH²W²), because the output tensor of Corr(Q, K) has N × HW × THW = NTH²W² elements. Even if we ignore the computation for collecting the tensor, the run time for N=1 and N=8 can still differ, because they may require different numbers of CUDA cores to compute.
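Written as a single expression, the corrected bmm cost adds the size of the output correlation tensor (N × HW × THW entries) to the multiplication count. Here C is the total channel number, as in the question.

```latex
\begin{aligned}
\text{cost}_{\text{bmm}}
  &= \underbrace{HW \cdot THW \cdot \tfrac{C}{N} \cdot N}_{\text{multiplications}}
   + \underbrace{N \cdot HW \cdot THW}_{\text{output entries}} \\
  &= O(TH^2W^2C) + O(NTH^2W^2)
\end{aligned}
```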

I didn't say that the attention is O(NTH²W²) in DeAOT. If you try decreasing the head number in AOT or other common transformers, you will find a significant speed-up.

MUVGuan commented 1 year ago

Thank you for pointing out my mistake about bmm! But I cannot understand this sentence in Section 4.2 of DeAOT: "Concretely, the computational complexity of long-term attention is O(NTH²W²), which is proportional to the head number N since each head contains a correlation function Corr(Q, K)."


As you said before, the complexity of bmm is O(TH²W²C) + O(NTH²W²) and the complexity of softmax is O(NTH²W²), so the complexity of Corr(Q, K) is O(TH²W²C) + O(NTH²W²) + O(NTH²W²). What is the difference between the O(NTH²W²) in the sentence above and the complexity of Corr(Q, K)? Why do you write only O(NTH²W²) in that sentence?
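In the same notation (C is the total channel number), the sum in the question simplifies as follows; the two O(NTH²W²) terms merge into one, leaving one term that depends on the channel number and one that depends on the head number:

```latex
\begin{aligned}
\text{cost}_{\text{Corr}}
  &= \text{cost}_{\text{bmm}} + \text{cost}_{\text{softmax}} \\
  &= O(TH^2W^2C) + O(NTH^2W^2) + O(NTH^2W^2) \\
  &= O(TH^2W^2C + NTH^2W^2)
\end{aligned}
```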

I'm really confused. Thank you for answering so patiently.

z-x-yang commented 1 year ago

Oh, you are talking about our paper. I misunderstood what you meant.

In the paper, we describe our DeAOT method in its single-head form, so we use O(NTH²W²) to describe AOT by contrast.

According to our results, O(NTH²W²) is the major component of the complexity of long-term attention when we keep the total channel number (C×N) fixed. Changing the per-head channel number (C) has a less significant effect than changing the head number (N).

z-x-yang commented 1 year ago

By the way, in DeAOT the channel number of Q and K is 128, while that of V is 256. When multiple heads are used, these channels are divided among the heads accordingly.
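A tiny sketch of the resulting per-head channel numbers, assuming the 128 Q/K channels and 256 V channels are divided evenly across heads; the helper name `per_head_channels` is hypothetical and not taken from the aot-benchmark code.

```python
def per_head_channels(num_head, qk_channels=128, v_channels=256):
    """Hypothetical helper: channels per head under an even split."""
    if qk_channels % num_head or v_channels % num_head:
        raise ValueError("channel numbers must be divisible by num_head")
    return qk_channels // num_head, v_channels // num_head

for n in (1, 2, 4, 8):
    print(n, per_head_channels(n))
# 1 (128, 256)
# 2 (64, 128)
# 4 (32, 64)
# 8 (16, 32)
```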

MUVGuan commented 1 year ago

Thank you very much for your patient and detailed answer, now I understand.