luoyixi924208423 opened this issue 2 years ago
The Transformer [1] produces the same output regardless of the order of its inputs. For tasks such as NLP or vision the order of the inputs matters, so Absolute Positional Encoding (APE) [1] is used to indicate position. However, APE cannot capture the relative positions between input tokens. Relative Positional Encoding (RPE) [2] addresses this shortcoming and improves performance when applied. Applying RPE in the Swin Transformer [4], a member of the Vision Transformer (ViT) [3] family, also brought performance gains in computer vision. The Neighborhood Attention Transformer [5] applies RPE as well, but because of how Neighborhood Attention works it is implemented in a different form than in the Swin Transformer.
The RPE bias is added to attn_weight after it is computed (a sketch of this step follows the code below).
Notation: B = batch size, H = height of the input image, W = width of the input image, C = number of channels.
Window attention (Swin Transformer) splits the image into non-overlapping 7x7 windows when the kernel size is 7x7, and each window is treated as a separate batch element. When executing the attention mechanism, q = (B, 49, C), k = (B, C, 49), v = (B, 49, C); with a 7x7 kernel each window contains 49 tokens. attn_weight is obtained by matrix multiplication of q and k and has shape attn_weight = (B, 49, 49).
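For shape intuition, here is a minimal sketch with random tensors (the real Swin implementation also folds the number of windows and the attention heads into the batch dimension, which is simplified away here):

import torch

B, C, window_size = 2, 32, 7
N = window_size * window_size  # 49 tokens per 7x7 window

q = torch.rand(B, N, C)   # queries of one window
k = torch.rand(B, C, N)   # keys, already transposed
v = torch.rand(B, N, C)   # values

attn_weight = q @ k                 # (B, 49, 49)
out = attn_weight.softmax(-1) @ v   # (B, 49, C)
print(attn_weight.shape, out.shape)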
Neighborhood attention, by contrast, does not split the image. When executing the attention mechanism with a 7x7 kernel, q = (B, H, W, C), k = (B, H, W, C, 49), v = (B, H, W, 49, C); each query pixel attends to the 49 pixels of its 7x7 neighborhood. attn_weight is obtained by matrix multiplication of q and k and has shape attn_weight = (B, H, W, 49).
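The same kind of shape sketch for neighborhood attention, assuming k and v have already been gathered into each pixel's 7x7 neighborhood (the actual gathering is what the NAT code does; only the shapes are illustrated here):

import torch

B, H, W, C, window_size = 2, 14, 14, 32, 7
N = window_size * window_size  # 49 neighbors per pixel

q = torch.rand(B, H, W, C)      # one query per pixel
k = torch.rand(B, H, W, C, N)   # 49 neighboring keys per pixel
v = torch.rand(B, H, W, N, C)   # 49 neighboring values per pixel

attn_weight = (q.unsqueeze(-2) @ k).squeeze(-2)                 # (B, H, W, 49)
out = (attn_weight.softmax(-1).unsqueeze(-2) @ v).squeeze(-2)   # (B, H, W, C)
print(attn_weight.shape, out.shape)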
Because attn_weight has shape (B, 49, 49) in window attention but (B, H, W, 49) in neighborhood attention, the relative position bias has to be indexed and applied in different formats.
If you run the following Python code, you can intuitively see the difference between the two.
import torch
import torch.nn as nn

window_size = 3  # attention kernel size
H = 5            # height of input image
W = 5            # width of input image

def swin_rpb(window_size):  # Swin Transformer
    # Coordinates of every token inside one window
    coords_h = torch.arange(window_size)
    coords_w = torch.arange(window_size)
    coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # (2, ws, ws)
    coords_flatten = torch.flatten(coords, 1)                   # (2, ws*ws)
    # Pairwise relative offsets between every query token and every key token
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # (ws*ws, ws*ws, 2)
    # Shift offsets to start at 0 and flatten (dh, dw) into a single table index
    relative_coords[:, :, 0] += window_size - 1
    relative_coords[:, :, 1] += window_size - 1
    relative_coords[:, :, 0] *= 2 * window_size - 1
    relative_coords = relative_coords.sum(-1)  # (ws*ws, ws*ws)
    return relative_coords

def nat_rpb(window_size, H, W):  # Neighborhood Attention Transformer
    idx_h = torch.arange(0, window_size)
    idx_w = torch.arange(0, window_size)
    # Offsets of a ws x ws block inside the flattened (2*ws-1) x (2*ws-1) bias table
    idx_k = ((idx_h.unsqueeze(-1) * (2 * window_size - 1)) + idx_w).view(-1)
    # Interior pixels reuse the centered pattern; only the window_size//2
    # rows/columns nearest each border get their own shifted patterns
    num_repeat_h = torch.ones(window_size, dtype=torch.long)
    num_repeat_w = torch.ones(window_size, dtype=torch.long)
    num_repeat_h[window_size // 2] = H - (window_size - 1)
    num_repeat_w[window_size // 2] = W - (window_size - 1)
    # Per-pixel starting offset into the bias table, shape (H, W)
    bias_hw = (idx_h.repeat_interleave(num_repeat_h).unsqueeze(-1) * (2 * window_size - 1)) + idx_w.repeat_interleave(num_repeat_w)
    bias_idx = bias_hw.unsqueeze(-1) + idx_k  # (H, W, ws*ws)
    return bias_idx.view(H, W, window_size * window_size).permute(2, 0, 1)  # (ws*ws, H, W)

# Shared learnable bias table of size (2*ws - 1)^2
relative_bias = nn.Parameter(torch.rand((2 * window_size - 1) ** 2))

swin_rpb_idx = swin_rpb(window_size)
nat_rpb_idx = nat_rpb(window_size, H, W)

print('swin_rpb_idx')
print(swin_rpb_idx)
print('\n')
print('Results of RPE of swin transformer')
print(relative_bias[swin_rpb_idx])
print('\n')
print('nat_rpb_idx')
print(nat_rpb_idx)
print('Results of RPE of neighborhood attention transformer')
print(relative_bias[nat_rpb_idx])
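To connect this back to "the RPE bias is added to attn_weight", here is a rough sketch of how the gathered bias would broadcast onto the attention weights of each variant (shapes only, with a made-up batch size; not the actual model code):

B = 2  # hypothetical batch size for this demo

# Swin: the gathered bias is a (9, 9) matrix, broadcast over the batch
# (and, in the real model, the head) dimension.
swin_attn = torch.rand(B, window_size * window_size, window_size * window_size)
swin_attn = swin_attn + relative_bias[swin_rpb_idx]                 # (B, 9, 9)

# NAT: the gathered bias has shape (9, H, W); moving the neighbor axis last
# gives (H, W, 9), which broadcasts over the batch dimension.
nat_attn = torch.rand(B, H, W, window_size * window_size)
nat_attn = nat_attn + relative_bias[nat_rpb_idx].permute(1, 2, 0)   # (B, H, W, 9)

print(swin_attn.shape, nat_attn.shape)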
Excuse me, I noticed your work. It encourages me a lot to explore more work on the Neighborhood Attention Transformer. I read your Python file, but I have trouble understanding the relative position encoding. Could you explain the theory behind it? Thank you very much.