lucidrains / axial-attention

Implementation of Axial attention - attending to multi-dimensional data efficiently
MIT License

Positional embeddings for different image sizes #7

Closed: PhilippMarquardt closed this issue 3 years ago

PhilippMarquardt commented 3 years ago

Hi, once again thanks for your great work! I want to use axial attention with positional embeddings on images whose size is not known in advance (though I do know the maximum size), so I was wondering whether you think that changing https://github.com/lucidrains/axial-attention/blob/master/axial_attention/axial_attention.py#L104 to

for cnt, param in enumerate(self.params):
    # crop each axial embedding to the input's extent along its own axis
    x = x + param[tuple([slice(None)] * (cnt + 2) + [slice(x.shape[cnt + 2])])]

does the right thing. With that change I can now do this:

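import torch
from axial_attention import AxialImageTransformer

# inputs smaller than 64x64 only work with the cropping change above
v = AxialImageTransformer(64, depth = 1, axial_pos_emb_shape = (64, 64), dim_index = 1)
t1 = torch.randn(2, 64, 17, 16)
t2 = torch.randn(2, 64, 13, 18)
t3 = torch.randn(2, 64, 64, 64)
print(v(t1).shape)
print(v(t2).shape)
print(v(t3).shape)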
Output:
torch.Size([2, 64, 17, 16])
torch.Size([2, 64, 13, 18])
torch.Size([2, 64, 64, 64])

I think this makes it easier to integrate into fully convolutional nets for multi-scale training.
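For illustration, here is a minimal standalone sketch of the cropping idea, outside the library. The class name CroppedAxialPositionalEmbedding is hypothetical, and it assumes a 2D, channels-first layout (batch, dim, height, width), i.e. dim_index = 1 as in the example above. It is not the library's AxialPositionalEmbedding, just the same slicing logic written out explicitly: one learned embedding per axis, allocated at the maximum size and cropped to the input size in forward().

```python
import torch
from torch import nn


class CroppedAxialPositionalEmbedding(nn.Module):
    """Hypothetical sketch: one learned embedding per spatial axis, allocated
    for the maximum expected shape and cropped to the input size."""

    def __init__(self, dim, max_shape):
        super().__init__()
        max_h, max_w = max_shape
        # assumes channels-first input (batch, dim, height, width), i.e. dim_index = 1
        self.row_emb = nn.Parameter(torch.randn(1, dim, max_h, 1))
        self.col_emb = nn.Parameter(torch.randn(1, dim, 1, max_w))

    def forward(self, x):
        _, _, h, w = x.shape
        if h > self.row_emb.shape[2] or w > self.col_emb.shape[3]:
            raise ValueError("input is larger than the maximum embedding shape")
        # keep only the first h rows and first w columns, then broadcast-add
        return x + self.row_emb[:, :, :h] + self.col_emb[:, :, :, :w]


emb = CroppedAxialPositionalEmbedding(64, (64, 64))
for size in [(17, 16), (13, 18), (64, 64)]:
    print(emb(torch.randn(2, 64, *size)).shape)
# torch.Size([2, 64, 17, 16])
# torch.Size([2, 64, 13, 18])
# torch.Size([2, 64, 64, 64])
```

With cropping, row i always receives the same embedding vector regardless of the image size, so a 17-row image reuses the first 17 rows learned for a 64-row one.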

lucidrains commented 3 years ago

@PhilippMarquardt Hey! If you read the Vision Transformer paper, they actually try to make the positional embeddings generalize to different sizes by interpolating them. It depends on what your goal is.
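As a rough sketch of what the interpolation alternative could look like for axial embeddings: the class name InterpolatedAxialPositionalEmbedding and the reference size are hypothetical, the layout is again assumed to be channels-first (dim_index = 1), and this is neither the library's API nor the exact ViT procedure (which interpolates a 2D grid of patch embeddings). Each per-axis embedding is learned at a reference length and resized with torch.nn.functional.interpolate (linear mode) to match each input.

```python
import torch
import torch.nn.functional as F
from torch import nn


class InterpolatedAxialPositionalEmbedding(nn.Module):
    """Hypothetical sketch: per-axis embeddings learned at a reference size,
    linearly resized to the input's height and width (ViT-style idea)."""

    def __init__(self, dim, base_shape):
        super().__init__()
        base_h, base_w = base_shape
        # stored as (1, dim, length) so F.interpolate treats length as the spatial axis
        self.row_emb = nn.Parameter(torch.randn(1, dim, base_h))
        self.col_emb = nn.Parameter(torch.randn(1, dim, base_w))

    def forward(self, x):
        # assumes channels-first input (batch, dim, height, width)
        _, _, h, w = x.shape
        rows = F.interpolate(self.row_emb, size=h, mode="linear", align_corners=False)
        cols = F.interpolate(self.col_emb, size=w, mode="linear", align_corners=False)
        # (1, dim, h) -> (1, dim, h, 1) and (1, dim, w) -> (1, dim, 1, w), then broadcast-add
        return x + rows.unsqueeze(-1) + cols.unsqueeze(-2)


emb = InterpolatedAxialPositionalEmbedding(64, (32, 32))
for size in [(17, 16), (13, 18), (64, 64)]:
    print(emb(torch.randn(2, 64, *size)).shape)
```

Unlike cropping, interpolation maps positions by relative location: the last row of a small image gets an embedding close to that of the last row of a large one. Which behaviour is preferable depends on the task, which is the point of the reply above.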