facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/

Efficiency of local window attention with memory_efficient_attention #776

Open DianCh opened 1 year ago

DianCh commented 1 year ago

❓ Questions and Help

Hi! I work on vision transformers and have implemented window attention, where the feature map is divided into local windows and self-attention happens only among tokens inside each window. Say a batch of feature maps has shape [B, C, H, W]; I divide and reshape it into [4B, C, H/2, W/2] and eventually [4B, H/2 * W/2, C].
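For reference, here is a minimal sketch of the windowing I mean, assuming a 2x2 grid of windows and single-head attention; `window_partition` and the concrete shapes are illustrative, not my actual model code:

```python
import torch
from xformers.ops import memory_efficient_attention

def window_partition(x, num_splits=2):
    """[B, C, H, W] -> [num_splits**2 * B, (H/num_splits) * (W/num_splits), C]"""
    B, C, H, W = x.shape
    h, w = H // num_splits, W // num_splits
    x = x.reshape(B, C, num_splits, h, num_splits, w)
    x = x.permute(0, 2, 4, 3, 5, 1)                # [B, ns, ns, h, w, C]
    return x.reshape(B * num_splits * num_splits, h * w, C)

B, C, H, W = 8, 64, 32, 32
feats = torch.randn(B, C, H, W, device="cuda", dtype=torch.float16)
tokens = window_partition(feats)                   # [4B, H/2 * W/2, C]

# memory_efficient_attention expects [batch, seq_len, num_heads, head_dim]
q = k = v = tokens.unsqueeze(2)                    # [4B, M, 1, C]
out = memory_efficient_attention(q, k, v)          # [4B, M, 1, C]
```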

The thing is, I found that the memory consumption of this window attention is pretty much the same as vanilla self-attention over tokens of shape [B, H * W, C]. Is this expected? I thought reducing the sequence length would lower memory consumption, but maybe packing the number of windows into the batch dimension doesn't effectively improve things.

Thanks a lot!

danthe3rd commented 1 year ago

Hi! With memory_efficient_attention, the memory used no longer scales quadratically with the sequence length, only linearly (mainly because we need to store the outputs anyway), so this behavior is expected.
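To make that concrete: packing windows into the batch dimension leaves the total token count unchanged (4B × (H·W)/4 = B × H·W), so the linear-in-tokens cost of storing inputs and outputs is the same in both layouts; only the quadratic attention-matrix cost would have shrunk, and memory_efficient_attention already avoids materializing that matrix. A rough way to check this on a CUDA device (the shapes here are illustrative):

```python
import torch
from xformers.ops import memory_efficient_attention

def peak_attention_mem_mb(batch, seq_len, num_heads=8, head_dim=64):
    """Run one attention call and report peak CUDA memory in MiB."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    q = torch.randn(batch, seq_len, num_heads, head_dim,
                    device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)
    memory_efficient_attention(q, k, v)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20

# Same total number of tokens in both cases (8 * 1024 == 32 * 256),
# so peak memory should come out roughly identical.
print("full attention  :", peak_attention_mem_mb(batch=8,  seq_len=1024), "MiB")
print("window attention:", peak_attention_mem_mb(batch=32, seq_len=256),  "MiB")
```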