ziplab / LITv2

[NeurIPS 2022 Spotlight] This is the official PyTorch implementation of "Fast Vision Transformers with HiLo Attention"
Apache License 2.0

Question about training speed #14

Open tanbuzheng opened 8 months ago

tanbuzheng commented 8 months ago

Hi, I tried to use the HiLo block for an image restoration task, but I found that HiLo trains much more slowly than Swin. Specifically, we mainly apply the HiLo block at 64x64 resolution and set the local window size to 8x8. What could be the reason for this? Is there any way to increase the training speed?

HubHop commented 8 months ago

Hi @tanbuzheng, thanks for using HiLo. Can you give more details or a brief script of your test, e.g., the tensor shape of your feature map and the concrete HiLo settings, including dim, num_heads, and alpha?

tanbuzheng commented 8 months ago

Thank you very much for your quick reply! I used 8 stacked HiLo Transformer blocks at 64x64 resolution, with head_num=8, head_dim=32, alpha=0.5, and a local window size of 8x8. The other module settings are the same as in Swin. Could it be the depthwise conv used in the FFN that affects the running speed? I wonder whether there are any practical reasons why HiLo is slower than Swin.
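
For reference, here is a minimal sketch of the kind of block I am testing (the HiLo layer is the standalone module from hilo.py in this repo; the block wiring and FFN below are my own simplified versions, not the exact LITv2 code):

import torch
import torch.nn as nn
from hilo import HiLo  # standalone HiLo module from this repo

class HiLoBlock(nn.Module):
    """Simplified Transformer block: HiLo attention + FFN with a depthwise conv."""
    def __init__(self, dim=256, num_heads=8, window_size=8, alpha=0.5):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = HiLo(dim, num_heads=num_heads, window_size=window_size, alpha=alpha)
        self.norm2 = nn.LayerNorm(dim)
        self.fc1 = nn.Linear(dim, dim * 4)
        self.dwconv = nn.Conv2d(dim * 4, dim * 4, 3, padding=1, groups=dim * 4)
        self.fc2 = nn.Linear(dim * 4, dim)
        self.act = nn.GELU()

    def forward(self, x, H, W):
        # x: (B, N, C) tokens with N = H * W
        x = x + self.attn(self.norm1(x), H, W)
        y = self.act(self.fc1(self.norm2(x)))
        B, N, C = y.shape
        y = y.transpose(1, 2).reshape(B, C, H, W)
        y = self.dwconv(y).flatten(2).transpose(1, 2)
        return x + self.fc2(y)

# 8 stacked blocks on a 64x64 feature map; head_dim = 256 / 8 = 32
blocks = nn.ModuleList([HiLoBlock() for _ in range(8)])
x = torch.randn(1, 64 * 64, 256)
for blk in blocks:
    x = blk(x, 64, 64)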

HubHop commented 8 months ago

Thanks for the additional information. If we only look at the attention layer itself, HiLo attention should theoretically be faster than local window attention:

from hilo import HiLo

H = W = 64

# all heads for Hi-Fi, i.e. pure local window attention
model = HiLo(dim=256, num_heads=8, window_size=8, alpha=0.)
print(f'All for Hi-Fi, FLOPs = {round(model.flops(H, W)/1e9, 2)}G')

# all heads for Lo-Fi
model = HiLo(dim=256, num_heads=8, window_size=8, alpha=1.)
print(f'All for Lo-Fi, FLOPs = {round(model.flops(H, W)/1e9, 2)}G')

# half of the heads for Hi-Fi, half for Lo-Fi
model = HiLo(dim=256, num_heads=8, window_size=8, alpha=0.5)
print(f'Half-Half, FLOPs = {round(model.flops(H, W)/1e9, 2)}G')

Output:

All for Hi-Fi, FLOPs = 1.21G
All for Lo-Fi, FLOPs = 0.68G
Half-Half, FLOPs = 0.81G
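
By the way, FLOPs do not always translate one-to-one into wall-clock speed, so it is also worth timing the attention layer directly on your hardware. A rough sketch (my own quick benchmark, assuming the standalone HiLo module and its forward(x, H, W) interface; adjust batch size and device to match your setup):

import time
import torch
from hilo import HiLo

device = 'cuda' if torch.cuda.is_available() else 'cpu'
H = W = 64
x = torch.randn(4, H * W, 256, device=device)  # (B, N, C) tokens for a 64x64 map

def bench(alpha, iters=50):
    model = HiLo(dim=256, num_heads=8, window_size=8, alpha=alpha).to(device).eval()
    with torch.no_grad():
        for _ in range(10):  # warm-up
            model(x, H, W)
        if device == 'cuda':
            torch.cuda.synchronize()
        start = time.time()
        for _ in range(iters):
            model(x, H, W)
        if device == 'cuda':
            torch.cuda.synchronize()
    return (time.time() - start) / iters * 1000  # ms per forward pass

print(f'All for Hi-Fi: {bench(0.0):.2f} ms')
print(f'All for Lo-Fi: {bench(1.0):.2f} ms')
print(f'Half-Half:     {bench(0.5):.2f} ms')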

Therefore, the attention layer is not the issue. Compared to Swin Transformer blocks, the only difference in the FFN is that LITv2 has one additional depthwise convolution layer. You can debug your code by removing this layer and checking whether the speed improves.
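
For example, a rough way to isolate the cost of that layer is to time an FFN with and without the depthwise conv (a sketch with a generic FFN, not the exact LITv2 module, using torch.utils.benchmark):

import torch
import torch.nn as nn
from torch.utils import benchmark

device = 'cuda' if torch.cuda.is_available() else 'cpu'
H = W = 64
dim, hidden = 256, 1024  # hidden = 4 * dim
x = torch.randn(4, H * W, dim, device=device)

class FFN(nn.Module):
    """Generic FFN sketch; use_dwconv=False mimics a plain Swin-style MLP."""
    def __init__(self, dim, hidden, use_dwconv=True):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden)
        self.fc2 = nn.Linear(hidden, dim)
        self.act = nn.GELU()
        self.dwconv = nn.Conv2d(hidden, hidden, 3, padding=1, groups=hidden) if use_dwconv else None

    def forward(self, x, H, W):
        x = self.act(self.fc1(x))
        if self.dwconv is not None:
            B, N, C = x.shape
            x = x.transpose(1, 2).reshape(B, C, H, W)
            x = self.dwconv(x).flatten(2).transpose(1, 2)
        return self.fc2(x)

for use_dwconv in (True, False):
    m = FFN(dim, hidden, use_dwconv).to(device).eval()
    t = benchmark.Timer(stmt='m(x, H, W)',
                        globals={'m': m, 'x': x, 'H': H, 'W': W})
    print(f'FFN dwconv={use_dwconv}: {t.timeit(50).mean * 1e3:.2f} ms')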

tanbuzheng commented 8 months ago

Thanks a lot! I will try it later.