qwopqwop200 / Neighborhood-Attention-Transformer

NAT (Neighborhood Attention Transformer) implementation. This is an unofficial implementation of the paper: https://arxiv.org/pdf/2204.07143.pdf
MIT License

About relative position encoding #1

Open luoyixi924208423 opened 2 years ago

luoyixi924208423 commented 2 years ago

Excuse me. I noticed your work, and it gives me a lot of encouragement to explore more work on the Neighborhood Attention Transformer. I read your Python file, but I have a problem with the relative position encoding. Can you explain the theory behind it? Thank you very much.

qwopqwop200 commented 2 years ago

A Transformer [1] produces the same output regardless of the order of its inputs. For tasks such as NLP or vision the order of the inputs matters, so Absolute Position Encodings (APE) [1] were used to indicate position. However, APE cannot learn the relative position between the output value and the input value. RPE (Relative Positional Encoding) [2], which fixes this shortcoming, improved performance when applied. When RPE was applied in the Swin Transformer [4], one of the ViTs (Vision Transformers) [3], a performance improvement was also shown in computer vision. The Neighborhood Attention Transformer [5] applies RPE just like the Swin Transformer, but it is implemented in a different format because of the characteristics of neighborhood attention.
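As a toy illustration of the indexing trick (not code from this repository, just standard 1-D RPE): for a sequence of length L there are 2*L - 1 possible relative offsets, and each query/key pair (i, j) looks up one learned bias from a shared table via the index j - i + (L - 1).

import torch

L = 3                                            # sequence length
bias_table = torch.rand(2 * L - 1)               # one learned bias per relative offset
pos = torch.arange(L)
rel_idx = pos[None, :] - pos[:, None] + (L - 1)  # (L, L) indices in [0, 2L-2]
rpb = bias_table[rel_idx]                        # (L, L) bias added to the attention logits
print(rel_idx)

The 2-D versions below do the same thing with a table of size (2*window_size-1)**2, combining a vertical and a horizontal offset into a single index.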

The RPE bias is added after attn_weight is obtained (before the softmax).

B = batch size, H = height of the input image, W = width of the input image, C = channels

Window attention (Swin Transformer) splits the image into 7x7 blocks when the kernel size is 7x7. When executing the attention mechanism, q = (B, 49, C), k = (B, C, 49), v = (B, 49, C); since the kernel is 7x7, each window has 49 positions. attn_weight is obtained through matrix multiplication of q and k and has the shape attn_weight = (B, 49, 49).
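A minimal shape check for the window-attention case (a sketch with random tensors, not the actual Swin code) shows where the (49, 49) bias enters:

import torch

B, C, K = 2, 32, 7 * 7                 # B here stands for batch * number of windows, K = kernel area
q = torch.randn(B, K, C)               # queries of a 7x7 window
k = torch.randn(B, C, K)               # keys, already transposed
v = torch.randn(B, K, C)               # values

attn_weight = q @ k                    # (B, 49, 49)
rpb = torch.randn(K, K)                # relative position bias (random here; normally gathered from a (2*7-1)**2 table)
attn_weight = attn_weight + rpb        # RPE is added after attn_weight, before the softmax
out = attn_weight.softmax(dim=-1) @ v  # (B, 49, C)
print(attn_weight.shape, out.shape)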

Neighborhood attention, in contrast, does not split the image. With a 7x7 kernel, when executing the attention mechanism, q = (B, H, W, C), k = (B, H, W, C, 49), v = (B, H, W, 49, C); each query pixel attends to its 49 neighbors. attn_weight is obtained through matrix multiplication of q and k and has the shape attn_weight = (B, H, W, 49).
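The corresponding shape check for neighborhood attention (again a sketch with random tensors; gathering the 49 neighboring keys/values of each pixel is assumed to have already been done):

import torch

B, H, W, C, K = 2, 9, 9, 32, 7 * 7                                  # K = 49 neighbors per query pixel
q = torch.randn(B, H, W, C)                                         # one query per pixel
k = torch.randn(B, H, W, C, K)                                      # its 49 neighboring keys
v = torch.randn(B, H, W, K, C)                                      # its 49 neighboring values

attn_weight = (q.unsqueeze(-2) @ k).squeeze(-2)                     # (B, H, W, 49)
rpb = torch.randn(H, W, K)                                          # per-pixel bias (random here; the real one is gathered with an index like nat_rpb below)
attn_weight = attn_weight + rpb                                     # RPE is added after attn_weight, before the softmax
out = (attn_weight.softmax(dim=-1).unsqueeze(-2) @ v).squeeze(-2)   # (B, H, W, C)
print(attn_weight.shape, out.shape)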

Because the attn_weight of window attention, (B, 49, 49), and the attn_weight of neighborhood attention, (B, H, W, 49), have different shapes, the relative position bias is implemented in a different format in each.

For an easier and more detailed explanation of RPE, see this URL: https://medium.com/@_init_/how-self-attention-with-relative-position-representations-works-28173b8c245a

If you run the following Python code, you can intuitively see the difference between the two.


import torch
import torch.nn as nn

window_size = 3 #attention kernel size
H = 5 #height of input image
W = 5 #width of input image

def swin_rpb(window_size): # Swin Transformer: (window_size**2, window_size**2) relative position index
    coords_h = torch.arange(window_size)
    coords_w = torch.arange(window_size)
    coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  
    coords_flatten = torch.flatten(coords, 1)
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  
    relative_coords = relative_coords.permute(1, 2, 0).contiguous() 
    relative_coords[:, :, 0] += window_size - 1 
    relative_coords[:, :, 1] += window_size - 1
    relative_coords[:, :, 0] *= 2 * window_size - 1
    relative_coords = relative_coords.sum(-1)
    return relative_coords

def nat_rpb(window_size,H,W): # Neighborhood Attention Transformer: (window_size**2, H, W) relative position index
    idx_h = torch.arange(0,window_size)
    idx_w = torch.arange(0,window_size)
    idx_k = ((idx_h.unsqueeze(-1) * (2*window_size-1)) + idx_w).view(-1)
    num_repeat_h = torch.ones(window_size,dtype=torch.long)
    num_repeat_w = torch.ones(window_size,dtype=torch.long)
    num_repeat_h[window_size//2] = H-(window_size-1)
    num_repeat_w[window_size//2] = W-(window_size-1)
    bias_hw = (idx_h.repeat_interleave(num_repeat_h).unsqueeze(-1) * (2*window_size-1)) + idx_w.repeat_interleave(num_repeat_w)
    bias_idx = bias_hw.unsqueeze(-1) + idx_k
    return bias_idx.view(H,W,window_size*window_size).permute(2,0,1)

relative_bias = nn.Parameter(torch.rand((2*window_size-1)**2)) # shared bias table, one entry per 2-D relative offset

swin_rpb_idx = swin_rpb(window_size)
nat_rpb_idx = nat_rpb(window_size,H,W)

print('swin_rpb_idx')
print(swin_rpb_idx)
print('\n')
print('Results of RPE of swin transformer')
print(relative_bias[swin_rpb_idx])
print('\n')
print('nat_rpb_idx')
print(nat_rpb_idx)
print('Results of RPE of neighborhood attention transformer')
print(relative_bias[nat_rpb_idx])