Hi! I work on vision transformers and implemented window attention, where the feature map is divided into local windows and self-attention only happens among the tokens inside each window. Let's say the batch of feature maps has tensor size [B, C, H, W]; I've divided and reshaped it into [4B, C, H/2, W/2] and eventually [4B, H/2 * W/2, C].
The thing is, I found that the memory consumption of this window attention is pretty much the same as that of vanilla self-attention with token shape [B, H * W, C]. Is this expected? I thought reducing the sequence length would lower memory consumption, but maybe packing the windows into the batch dimension simply doesn't improve things.
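For concreteness, the partition step looks roughly like the sketch below (a minimal illustration, not my exact code; window_partition_2x2 is just a hypothetical helper and assumes H and W are even):

```python
import torch

def window_partition_2x2(x: torch.Tensor) -> torch.Tensor:
    """[B, C, H, W] -> [4B, (H/2)*(W/2), C] via a 2x2 grid of non-overlapping windows."""
    B, C, H, W = x.shape
    h, w = H // 2, W // 2
    # Split H and W into a 2x2 grid of windows and fold the grid into the batch dim.
    x = x.reshape(B, C, 2, h, 2, w).permute(0, 2, 4, 1, 3, 5)  # [B, 2, 2, C, h, w]
    windows = x.reshape(4 * B, C, h, w)                        # [4B, C, H/2, W/2]
    # Flatten each window's spatial positions into a token sequence.
    return windows.flatten(2).transpose(1, 2)                  # [4B, (H/2)*(W/2), C]

tokens = window_partition_2x2(torch.randn(2, 96, 56, 56))  # -> shape [8, 784, 96]
```

Thanks a lot!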
Hi,
With memory_efficient_attention, the memory used no longer scales quadratically with the sequence length, only linearly (mainly because we need to store the outputs anyway), so this behavior is expected. Windowing cuts the sequence length by 4x but multiplies the batch size by 4x, so the total number of tokens, and hence the linear part of the memory, stays the same.
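For example, a quick sanity check along these lines should show very similar peak memory for the two layouts (a rough sketch, assuming xformers' memory_efficient_attention on a CUDA device; the sizes are purely illustrative):

```python
import torch
from xformers.ops import memory_efficient_attention

# Illustrative sizes: B=8, H=W=64 spatial map, 4 heads of dim 32.
B, H, W, n_heads, head_dim = 8, 64, 64, 4, 32

def peak_mem_mib(q: torch.Tensor) -> float:
    torch.cuda.reset_peak_memory_stats()
    _ = memory_efficient_attention(q, q, q)  # self-attention: q = k = v
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20

# Vanilla attention over the full map: sequence length H*W.
full = torch.randn(B, H * W, n_heads, head_dim,
                   device="cuda", dtype=torch.float16)
# Window attention over a 2x2 grid: 4x the batch, 1/4 the sequence length.
windowed = torch.randn(4 * B, (H // 2) * (W // 2), n_heads, head_dim,
                       device="cuda", dtype=torch.float16)

print(f"full attention:   {peak_mem_mib(full):.1f} MiB")
print(f"window attention: {peak_mem_mib(windowed):.1f} MiB")
```

Both calls process the same total number of tokens (4B windows of length H*W/4 versus B sequences of length H*W), and since the full attention matrix is never materialized, peak memory is dominated by the inputs and outputs and ends up roughly the same.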