lucidrains / axial-attention

Implementation of Axial attention - attending to multi-dimensional data efficiently
MIT License

Hi, I have a problem #10

Open · meiguoofa opened this issue 3 years ago

meiguoofa commented 3 years ago

import torch
from axial_attention import AxialAttention

img = torch.randn(1, 3, 256, 256)

attn = AxialAttention(
    dim = 3,               # embedding dimension
    dim_index = 1,         # where is the embedding dimension
    dim_heads = 32,        # dimension of each head. defaults to dim // heads if not supplied
    heads = 1,             # number of heads for multi-head attention
    num_dimensions = 2,    # number of axial dimensions (images is 2, video is 3, or more)
    sum_axial_out = True   # whether to sum the contributions of attention on each axis, or to run the input through them sequentially. defaults to true
)

attn(img) # (1, 3, 256, 256)
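The `sum_axial_out` comment describes two ways of combining the per-axis attention outputs. As a toy sketch only, each per-axis attention is modeled here as a plain function on a number, and `combine_axial` is a hypothetical helper (not the library's code); it shows just the difference between the two modes the comment names.

```python
from typing import Callable, List

# Stand-in for "attention along one axis"; real layers act on tensors.
AxisOp = Callable[[float], float]


def combine_axial(x: float, axis_ops: List[AxisOp], sum_axial_out: bool = True) -> float:
    """Hypothetical helper illustrating the two modes named in the README comment."""
    if sum_axial_out:
        # Sum mode: every axis sees the same input and the results are added.
        return sum(op(x) for op in axis_ops)
    # Sequential mode: each axis receives the previous axis's output.
    for op in axis_ops:
        x = op(x)
    return x


# Two toy "axes": one doubles its input, the other adds 1.
ops = [lambda v: v * 2, lambda v: v + 1]
print(combine_axial(3, ops, sum_axial_out=True))   # 10, i.e. 3*2 + (3+1)
print(combine_axial(3, ops, sum_axial_out=False))  # 7, i.e. (3*2) + 1
```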

Thanks for this great project. I'd like to ask: if my image has only one channel, does that affect the value of num_dimensions?
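The README comments suggest that `num_dimensions` counts the axes attention runs along (2 for images, 3 for video), while the channel count is the embedding dimension given by `dim` at position `dim_index`. A minimal sketch under that assumption, using a hypothetical helper `axial_axes` (not part of the library), treats axis 0 as the batch axis and excludes the embedding axis. Under this reading a one-channel image of shape (1, 1, 256, 256) still has two axial dimensions; presumably only `dim` would change from 3 to 1.

```python
from typing import List, Sequence


def axial_axes(shape: Sequence[int], dim_index: int = 1) -> List[int]:
    """Hypothetical helper: axes attention would run along, assuming axis 0
    is the batch axis and `dim_index` marks the embedding (channel) axis."""
    ndim = len(shape)
    dim_index = dim_index % ndim  # turn e.g. -1 into ndim - 1
    if dim_index == 0:
        raise ValueError("axis 0 is assumed to be the batch axis")
    return [axis for axis in range(1, ndim) if axis != dim_index]


# RGB image (batch, channels, height, width): two axial dimensions
print(len(axial_axes((1, 3, 256, 256))))   # 2
# single-channel image: still two axial dimensions
print(len(axial_axes((1, 1, 256, 256))))   # 2
# video (batch, channels, frames, height, width): three axial dimensions
print(len(axial_axes((1, 3, 8, 64, 64))))  # 3
```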