import torch
from axial_attention import AxialAttention
img = torch.randn(1, 3, 256, 256)
attn = AxialAttention(
    dim = 3,               # embedding dimension
    dim_index = 1,         # where the embedding dimension is
    dim_heads = 32,        # dimension of each head; defaults to dim // heads if not supplied
    heads = 1,             # number of heads for multi-head attention
    num_dimensions = 2,    # number of axial dimensions (2 for images, 3 for video, or more)
    sum_axial_out = True   # whether to sum the contributions of attention on each axis, or to run the input through them sequentially; defaults to True
)
attn(img) # (1, 3, 256, 256)
Thanks for your great project! I want to ask: if my image is a one-channel image, will that influence the num_dimensions value?
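
If it helps, my understanding is that num_dimensions counts the spatial axes attention runs over (height and width for an image), while the channel count is what dim describes, so a one-channel image should keep num_dimensions = 2 and set dim = 1 instead. A minimal sketch of that reading (the dim = 1 / heads = 1 settings here are my assumption, not something the README confirms):

import torch
from axial_attention import AxialAttention

img = torch.randn(1, 1, 256, 256)  # single-channel image: 1 channel, 2 spatial axes

attn = AxialAttention(
    dim = 1,             # embedding dimension matches the single channel (my assumption)
    dim_index = 1,       # channel dimension still sits at index 1
    heads = 1,           # a single head, so that dim // heads stays at least 1
    num_dimensions = 2   # unchanged: height and width remain the two axial dimensions
)

attn(img) # (1, 1, 256, 256)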