Closed · yxchng closed this issue 2 years ago
Seems like it expects channels_last format. Will you support channels_first? Frequent permutation is a significant performance hit.
Thank you for your interest.
That's correct, the module expects channels-last inputs, similar to many other Attention implementations (although those are typically flattened across the spatial axes into 3D tensors).
Since the model is not a CNN, and convolutions are the only operations that require a channels-first layout, it makes sense to keep inputs channels-last and avoid frequent permutations.
Some permutation and reshape operations are unavoidable because of the multi-head split. The current structure is optimized for speed: all linear projections expect channels-last, and our NA kernel expects inputs of shape Batch x Heads x Height x Width x Dim, which is again the most efficient way of computing outputs.
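To illustrate the unavoidable reshaping, here is a minimal sketch (sizes and the `qkv` projection are hypothetical, not the library's actual code) of how a channels-last input is projected and split into the Batch x Heads x Height x Width x Dim layout described above:

```python
import torch

# Hypothetical sizes for illustration only.
B, H, W, C, heads = 1, 14, 14, 256, 8
head_dim = C // heads

x = torch.randn(B, H, W, C)          # channels-last input
qkv = torch.nn.Linear(C, 3 * C)      # linear projections expect channels-last
q, k, v = qkv(x).chunk(3, dim=-1)    # each: [B, H, W, C]

# Split channels into heads, then move the head axis forward:
# [B, H, W, C] -> [B, H, W, heads, head_dim] -> [B, heads, H, W, head_dim]
q = q.view(B, H, W, heads, head_dim).permute(0, 3, 1, 2, 4)
print(q.shape)  # torch.Size([1, 8, 14, 14, 32])
```

Note that only the head split needs a permute here; the spatial layout never has to round-trip through channels-first.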
To be clear, channels-last is usually the more efficient format, hence PyTorch's movement towards channels-last support since torch 1.11. That said, other factors also affect how much faster your operations get when switching to channels-last (to name a few: cuDNN kernels being called instead of naive ATen kernels, and architecture-specific kernels from NVIDIA packages).
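For context, torch's channels-last support works through memory formats: a tensor keeps its logical NCHW shape while its strides are rearranged so channels are innermost in memory. A minimal sketch:

```python
import torch

conv = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)

# The logical shape stays [N, C, H, W]; only the memory layout changes.
x = torch.randn(1, 3, 32, 32).to(memory_format=torch.channels_last)
y = conv(x)

print(x.is_contiguous(memory_format=torch.channels_last))  # True
print(y.shape)  # torch.Size([1, 16, 32, 32])
```

This is distinct from what the module here expects, which is a tensor whose last logical dimension is channels.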
But thank you for bringing this to our attention; shape requirements should be included in our documentation. We will update this in future versions.
Closing this due to inactivity. If you still have questions feel free to open it back up.
My input shape is `torch.Size([1, 256, 14, 14])`. Why am I getting this error?
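(For reference: an NCHW input like this can be permuted to the channels-last layout the module expects, assuming it wants `[B, H, W, C]`.)

```python
import torch

x = torch.randn(1, 256, 14, 14)  # NCHW, as in the shape above
x = x.permute(0, 2, 3, 1)        # -> [1, 14, 14, 256], channels-last
print(x.shape)
```

The permute is done once at the module boundary, so it need not be repeated per layer.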