hunto / LocalMamba

Code for paper LocalMamba: Visual State Space Model with Windowed Selective Scan
Apache License 2.0

Regarding FLOP calculation in Local ViM #36

Open saarthakk-insitro opened 1 month ago

saarthakk-insitro commented 1 month ago

Hi Authors,

Thanks for making the code available; it was super helpful!

I was trying to understand the FLOPs calculation and noticed that you commented out "flops += 9 * L * D * N + 2 * D * L" and used a different formula instead:

https://github.com/hunto/LocalMamba/blob/main/classification/lib/models/local_vim.py#L442

```python
for i in range(len(layer.mixer.multi_scan.choices)):
    # flops += 9 * L * D * N + 2 * D * L
    # A
    flops += D * L * N
    # B
    flops += D * L * N * 2
    # C
    flops += (D * N + D * N) * L
    # D
    flops += D * L
    # Z
    flops += D * L
```
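To compare the two formulas, you can sum the per-scan terms in the loop and set the total against the commented-out expression. The sketch below restates both formulas exactly as quoted in this issue. `L` (sequence length), `D` (inner dimension) and `N` (state size) follow the snippet. The function names and the sample sizes are hypothetical and chosen only for illustration.

```python
def current_scan_flops(L: int, D: int, N: int) -> int:
    """Per-scan FLOPs from the terms that are active in the quoted loop."""
    a = D * L * N              # A
    b = D * L * N * 2          # B
    c = (D * N + D * N) * L    # C
    d = D * L                  # D
    z = D * L                  # Z
    return a + b + c + d + z   # algebraically 5*L*D*N + 2*D*L


def commented_scan_flops(L: int, D: int, N: int) -> int:
    """Per-scan FLOPs from the commented-out line (the 9LDN estimate)."""
    return 9 * L * D * N + 2 * D * L


if __name__ == "__main__":
    # Hypothetical sizes, for illustration only.
    L, D, N = 197, 384, 16
    cur = current_scan_flops(L, D, N)
    old = commented_scan_flops(L, D, N)
    print(f"current:   {cur:,}")
    print(f"commented: {old:,}")
    print(f"gap per scan: {old - cur:,} (4*L*D*N = {4 * L * D * N:,})")
```

The active terms add up to 5LDN + 2DL. Each scan choice is therefore counted 4LDN FLOPs lower than the 9LDN + 2DL estimate.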

Can you please help me understand the reasoning behind this change in the FLOP formula? With the current formula, since ViM has 2 scan choices, the code gives 5.1 GFLOPs, which matches the FLOPs reported in your work. However, the Mamba authors posted that the selective scan should take 9LDN FLOPs, which matches your commented-out line (https://github.com/state-spaces/mamba/issues/110).
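You can make the effect of the 2 scan choices explicit by summing the per-scan cost over every choice and every block. The sketch below counts only the selective-scan part and leaves out the projections and convolutions, so it is not expected to reproduce the 5.1 GFLOPs total. `depth`, `num_choices`, `per_scan_coeff` and the sample sizes are hypothetical parameters. None of them are read from the repository.

```python
def scan_gflops(depth: int, num_choices: int, L: int, D: int, N: int,
                per_scan_coeff: int) -> float:
    """Selective-scan GFLOPs summed over all blocks and scan choices.

    per_scan_coeff is the LDN coefficient: 5 for the active formula,
    9 for the commented-out one. The 2*D*L term is shared by both.
    """
    per_scan = per_scan_coeff * L * D * N + 2 * D * L
    return depth * num_choices * per_scan / 1e9


if __name__ == "__main__":
    # Hypothetical configuration, for illustration only.
    depth, num_choices, L, D, N = 24, 2, 197, 384, 16
    active = scan_gflops(depth, num_choices, L, D, N, per_scan_coeff=5)
    legacy = scan_gflops(depth, num_choices, L, D, N, per_scan_coeff=9)
    print(f"active formula:    {active:.3f} GFLOPs")
    print(f"commented formula: {legacy:.3f} GFLOPs")
    print(f"difference:        {legacy - active:.3f} GFLOPs")
```

The gap between the two formulas grows linearly with both `depth` and `num_choices`. A model with more scan directions is therefore under-counted by proportionally more, relative to the 9LDN estimate.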

Thanks

saarthakk-insitro commented 1 month ago

Also, I believe there is an issue here: https://github.com/hunto/LocalMamba/blob/main/classification/lib/models/local_vim.py#L424

In lines 424, 426, and 428, the FLOPs need to be multiplied by 2 for ViM. ViM has two layers (forward and backward), and each has its own discretization and causal conv1d steps. I recalculated with the updated code and got 5.91 GFLOPs. Can you please confirm whether this analysis is correct?
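The proposed correction counts the direction-specific parts of each block once per direction. The sketch below assumes a Mamba-style mixer whose per-direction work consists of three parts:

- a depthwise causal conv1d with kernel size `d_conv`;
- the `x_proj` projection to `dt_rank + 2*N` outputs;
- the `dt_proj` projection back to `D` channels (discretization).

These term definitions are assumptions for illustration and are not a transcription of lines 424 to 428. So are the convention of counting one multiply-accumulate as one FLOP, the function names and all sample values.

```python
def per_direction_flops(L: int, D: int, N: int, dt_rank: int, d_conv: int) -> int:
    """Hypothetical FLOPs for the parts a block runs separately per direction.

    One multiply-accumulate is counted as one FLOP (an assumed convention).
    """
    conv1d = L * D * d_conv                # depthwise causal conv1d
    x_proj = L * D * (dt_rank + 2 * N)     # produces low-rank delta, B and C
    dt_proj = L * dt_rank * D              # expands delta back to D channels
    return conv1d + x_proj + dt_proj


def block_direction_flops(L: int, D: int, N: int, dt_rank: int, d_conv: int,
                          num_directions: int = 2) -> int:
    """Scale the per-direction cost by the number of directions (2 for ViM)."""
    return num_directions * per_direction_flops(L, D, N, dt_rank, d_conv)


if __name__ == "__main__":
    # Hypothetical sizes, for illustration only.
    L, D, N, dt_rank, d_conv = 197, 384, 16, 12, 4
    single = per_direction_flops(L, D, N, dt_rank, d_conv)
    both = block_direction_flops(L, D, N, dt_rank, d_conv)
    print(f"one direction:  {single:,}")
    print(f"two directions: {both:,}")
```

Exposing `num_directions` as a parameter lets the same helper count the cost for a unidirectional block as well as a bidirectional one.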