LeapLabTHU / FLatten-Transformer

Official repository of FLatten Transformer (ICCV2023)

Is "focused" really true? #4

Closed: uestchjw closed this issue 1 year ago

uestchjw commented 1 year ago

Thank you for your excellent work. I have a question about whether the feature map is truly sharper with the focused function than with the standard Transformer. From the perspective of the "pull" effect you describe, it should be sharper, and this is what Figure 4 shows. However, for the ball image in Figure 3, the focused feature maps look smoother than those of the original softmax attention. In a simple experiment of my own, the focused function also gave a smoother distribution than standard softmax. This is confusing to me.

So I'm curious whether, in practice, focused attention produces a more concentrated feature map than softmax attention. I'd appreciate any insights you can provide.
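The "pull" effect the question refers to can be checked on toy vectors. This is a minimal sketch, assuming the focused function is a ReLU followed by an element-wise power p = 3 and a rescaling that restores the original norm (the thread's flatten helper is the same map without the ReLU, since its inputs are already positive). The helper name focused and the vectors a, b, c are hypothetical.

```python
import torch
import torch.nn.functional as F

def focused(x: torch.Tensor, p: float = 3.0, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical helper: ReLU, raise each element to the power p, then
    rescale so the output keeps (up to eps) the norm of the ReLU'd input."""
    x = F.relu(x)
    y = x ** p
    return y * x.norm(dim=-1, keepdim=True) / (y.norm(dim=-1, keepdim=True) + eps)

a = torch.tensor([1.0, 0.2])  # dominant in dim 0
b = torch.tensor([0.9, 0.3])  # also dominant in dim 0, i.e. similar to a
c = torch.tensor([0.2, 1.0])  # dominant in dim 1, i.e. dissimilar to a

def cos(u: torch.Tensor, v: torch.Tensor) -> float:
    return F.cosine_similarity(u, v, dim=0).item()

print(f"norm before/after: {a.norm().item():.4f} / {focused(a).norm().item():.4f}")
print(f"cos(a, b): {cos(a, b):.4f} -> {cos(focused(a), focused(b)):.4f}")
print(f"cos(a, c): {cos(a, c):.4f} -> {cos(focused(a), focused(c)):.4f}")
```

If the sketch is right, the norm is unchanged while the cosine similarity rises for the pair sharing a dominant dimension and drops for the pair that does not, which is the sense in which the map sharpens similarity scores.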


There is a little experiment:

import torch
torch.manual_seed(5)
d = 4
n = 5
Q = torch.tensor([[1.,2.,4.,9.]]) # Qi
K = torch.randint(1,9,(d,n)).to(torch.float)

att = torch.einsum('ik,kj->ij',Q,K)
att_softmax = torch.softmax(att,dim=1)
print(f'softmax-attention is: {att_softmax}')

def flatten(x):
    y = torch.pow(x,3) # x**3
    return y*torch.norm(x)/torch.norm(y)

Q = flatten(Q)
for i in range(n):
    K[:,i] = flatten(K[:,i])
att = torch.einsum('ik,kj->ij',Q,K)
att = att/torch.sum(att)
print(f'Flatten-attention is: {att}')

tian-qing001 commented 1 year ago

Hi @uestchjw, thank you very much for your close interest in our work. I'd like to offer some insights on your question. First, note that while the focused function enhances the focusing ability of vanilla linear attention, we do not claim that it always produces a more focused attention map than softmax attention. In Figure 3, particularly in the last two rows, we can see that the flatten approach yields significantly sharper attention weights than linear attention. Using your code example, we can compare the weights of linear attention and flatten attention:

import torch
torch.manual_seed(5)
d = 4
n = 5
Q = torch.tensor([[1.,2.,4.,9.]]) # Qi
K = torch.randint(1,9,(d,n)).to(torch.float)

att = torch.einsum('ik,kj->ij',Q,K)
linear_att = att/torch.sum(att)
print(f'Linear-attention is: {linear_att}')

def flatten(x):
    y = torch.pow(x,3) # x**3
    return y*torch.norm(x)/torch.norm(y)

Q_ = flatten(Q)
for i in range(n):
    K[:,i] = flatten(K[:,i])

att = torch.einsum('ik,kj->ij',Q_,K)
att = att/torch.sum(att)
print(f'Flatten-attention is: {att}')

The output is:

Linear-attention is: tensor([[0.1521, 0.2809, 0.1881, 0.2165, 0.1624]])
Flatten-attention is: tensor([[0.0389, 0.4174, 0.1098, 0.3957, 0.0382]])

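As expected, flatten attention is noticeably sharper than linear attention: the two largest weights rise from 0.2809 and 0.2165 to 0.4174 and 0.3957, while the two smallest fall below 0.04.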
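One way to make "sharper" and "smoother" concrete, not used in the thread itself, is to compare the Shannon entropy of each weight vector: lower entropy means a more concentrated distribution, and the maximum, log n, is reached by uniform weights. A minimal sketch using the two outputs printed above (the helper name entropy is illustrative):

```python
import math
import torch

def entropy(w: torch.Tensor) -> float:
    """Shannon entropy (in nats) of a 1-D probability vector; lower = sharper."""
    w = w / w.sum()  # renormalize to absorb rounding in the printed values
    return float(-(w * w.clamp_min(1e-12).log()).sum())

linear_att = torch.tensor([0.1521, 0.2809, 0.1881, 0.2165, 0.1624])
flatten_att = torch.tensor([0.0389, 0.4174, 0.1098, 0.3957, 0.0382])

print(f"max possible (uniform): {math.log(5):.4f}")
print(f"linear : {entropy(linear_att):.4f}")
print(f"flatten: {entropy(flatten_att):.4f}")
```

The same function can be applied to any softmax output in the thread, which gives a single number to compare instead of judging the vectors by eye.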

Additionally, note that in your example the vector elements are fairly large, averaging approximately 5.0. In real-world models, normalization layers usually keep element values much smaller than in this simplified example. As an illustration, here we scale both Q and K by 0.2:

import torch
torch.manual_seed(5)
d = 4
n = 5
Q = torch.tensor([[1.,2.,4.,9.]]) # Qi
K = torch.randint(1,9,(d,n)).to(torch.float)

Q = Q * 0.2
K = K * 0.2

att = torch.einsum('ik,kj->ij',Q,K)
softmax_att = att.softmax(dim=-1)
print(f'Softmax-attention is: {softmax_att}')

def flatten(x):
    y = torch.pow(x,3) # x**3
    return y*torch.norm(x)/torch.norm(y)

Q_ = flatten(Q)
for i in range(n):
    K[:,i] = flatten(K[:,i])

att = torch.einsum('ik,kj->ij',Q_,K)
att = att/torch.sum(att)
print(f'Flatten-attention is: {att}')

The output is:

Softmax-attention is: tensor([[0.0713, 0.5266, 0.1248, 0.1937, 0.0836]])
Flatten-attention is: tensor([[0.0389, 0.4174, 0.1098, 0.3957, 0.0382]])

Clearly, at this scale the attention weights of flatten and softmax attention are fairly close.
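The flatten output in this run is identical to the one in the previous run, and that is not a coincidence. For a positive scale c, the thread's flatten step satisfies flatten(c·x) = c·flatten(x): cubing multiplies the vector by c^3, and the norm ratio ||c·x|| / ||(c·x)^3|| divides c^2 back out. Every dot product therefore grows by c^2, and the final division by the sum cancels it. Softmax has no such normalization, so scaling Q and K changes its sharpness. A minimal sketch checking both facts, reusing the thread's flatten helper and seed; the wrapper names flatten_attention and softmax_attention are hypothetical:

```python
import torch

def flatten(x: torch.Tensor) -> torch.Tensor:
    # Same focusing step as in the thread: cube, then restore the original norm.
    y = torch.pow(x, 3)
    return y * torch.norm(x) / torch.norm(y)

def flatten_attention(Q: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    # Q: (1, d); K: (d, n). Focus Q and every key column, then normalize by the sum.
    K_ = torch.stack([flatten(K[:, i]) for i in range(K.shape[1])], dim=1)
    att = flatten(Q) @ K_
    return att / att.sum()

def softmax_attention(Q: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
    return (Q @ K).softmax(dim=-1)

torch.manual_seed(5)
d, n = 4, 5
Q = torch.tensor([[1., 2., 4., 9.]])
K = torch.randint(1, 9, (d, n)).to(torch.float)

for c in (1.0, 0.2):
    print(f"scale={c}: softmax={softmax_attention(c * Q, c * K)}, "
          f"flatten={flatten_attention(c * Q, c * K)}")
```

Unlike the thread's loop, this version builds a new focused K instead of overwriting K in place, so the same Q and K can be reused for every scale.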

Moreover, it's crucial to consider that the distribution of softmax attention weights might not necessarily represent an optimal solution. Extensive experimental evidence has demonstrated that our flatten model can consistently achieve better performance than softmax attention.

tian-qing001 commented 1 year ago

To add, softmax attention usually also multiplies the logits by scale = head_dim ** -0.5, which makes its weights even smoother.
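A minimal sketch of that effect on the second (0.2-scaled) experiment, assuming a single head whose dimension equals d = 4, so the scale is 4 ** -0.5 = 0.5. Multiplying the logits by a factor below 1 acts like raising the softmax temperature, which pushes the weights toward uniform. The helper name entropy is illustrative.

```python
import torch

torch.manual_seed(5)
d, n = 4, 5
Q = torch.tensor([[1., 2., 4., 9.]]) * 0.2
K = torch.randint(1, 9, (d, n)).to(torch.float) * 0.2

head_dim = d                # assumption: one head whose dimension equals d
scale = head_dim ** -0.5    # 0.5 here

logits = Q @ K
plain = logits.softmax(dim=-1)             # no scaling, as in the example above
scaled = (logits * scale).softmax(dim=-1)  # standard scaled dot-product softmax

def entropy(w: torch.Tensor) -> float:
    # Shannon entropy in nats; softmax outputs are strictly positive, so log is safe.
    return float(-(w * w.log()).sum())

print(f"unscaled: {plain}  entropy={entropy(plain):.4f}")
print(f"scaled  : {scaled}  entropy={entropy(scaled):.4f}")
```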

uestchjw commented 1 year ago

Thank you very much for your explanation; it's very helpful. First, as I understand it, without the depthwise convolution (DWC) module your method is similar to "Efficient Attention: Attention with Linear Complexities". That work applies softmax to Q and K, whereas you apply the focused function to Q and K, and your method turns out to work much better. Finding this function must have been difficult. I think your work is very meaningful and will inspire subsequent research.

Second, I quite agree with you that 'the distribution of softmax attention weights might not necessarily represent an optimal solution'. I also find your "Focused linear attention at different stages" experiment very interesting. It suggests that different stages may need to focus on different parts, because semantic features are constantly being condensed. Focused linear attention tends to emphasize similarity in a particular dimension or semantic subspace.

Third, I think the DWC module is a little like local attention, which may also be very useful. Might global + local attention work even better? That's also very interesting.

Finally, thank you again for your reply!
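The comparison above (a feature map applied to Q and K for linear attention, plus DWC as a local branch) can be made concrete with a rough single-head sketch: global focused linear attention plus a depthwise convolution over V. It is an illustration under stated assumptions, not the official FLatten implementation: the ReLU-then-power focusing with p = 3, the 3x3 kernel, the square token grid, and the names focus and FocusedLinearAttention are all hypothetical choices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def focus(x: torch.Tensor, p: float = 3.0, eps: float = 1e-6) -> torch.Tensor:
    # Hypothetical focusing map: ReLU, element-wise power p, restore per-token norm.
    # eps keeps the norm ratio finite when a token is all zeros after ReLU.
    x = F.relu(x) + eps
    y = x ** p
    return y * x.norm(dim=-1, keepdim=True) / y.norm(dim=-1, keepdim=True)

class FocusedLinearAttention(nn.Module):
    """Single-head sketch: global focused linear attention + local DWC(V)."""

    def __init__(self, dim: int, kernel_size: int = 3):
        super().__init__()
        self.qkv = nn.Linear(dim, 3 * dim)
        # groups=dim makes the convolution depthwise (one filter per channel).
        self.dwc = nn.Conv2d(dim, dim, kernel_size, padding=kernel_size // 2, groups=dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor, h: int, w: int) -> torch.Tensor:
        # x: (B, N, C) with N = h * w tokens laid out on an h x w grid.
        B, N, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q, k = focus(q), focus(k)
        # Linear attention: form K^T V first, so cost is O(N C^2) instead of O(N^2 C).
        kv = k.transpose(1, 2) @ v                                # (B, C, C)
        denom = q @ k.sum(dim=1, keepdim=True).transpose(1, 2)    # (B, N, 1)
        out = (q @ kv) / (denom + 1e-6)                           # (B, N, C)
        # Local branch: depthwise conv over V reshaped to its spatial grid.
        v_img = v.transpose(1, 2).reshape(B, C, h, w)
        out = out + self.dwc(v_img).reshape(B, C, N).transpose(1, 2)
        return self.proj(out)

x = torch.randn(2, 16, 8)  # batch of 2, a 4x4 grid of tokens, 8 channels
y = FocusedLinearAttention(dim=8)(x, h=4, w=4)
print(y.shape)  # torch.Size([2, 16, 8])
```

The global term mixes information across all N tokens at linear cost, while the DWC term only sees each token's 3x3 neighbourhood, which is the "global + local" split raised in the comment.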