NVlabs / MambaVision

Official PyTorch Implementation of MambaVision: A Hybrid Mamba-Transformer Vision Backbone
https://arxiv.org/abs/2407.08083

Flops calculation #21

Closed AndyCao1125 closed 1 month ago

AndyCao1125 commented 1 month ago

Dear authors,

Thanks for your amazing work!

I have a small question, is it possible to provide the code of FLOPs calculation or explain how to compute the FLOPs of MambaVision. Thanks!

ahatamiz commented 1 month ago

Hi @AndyCao1125

You can use ptflops for this purpose; install it with pip install ptflops. The following snippet should be useful for FLOPs/params calculation:

from ptflops import get_model_complexity_info
from mamba_vision import *

resolution = 224
model = mamba_vision_T().cuda().eval()

Flops, params = get_model_complexity_info(model, tuple([3, resolution, resolution]), as_strings=False, print_per_layer_stat=False, verbose=False)

print(f"Model stats: GFlops: {Flops*1e-9}, and (M) params: {params*1e-6}")

Hope it helps

AndyCao1125 commented 1 month ago

Dear authors,

Thanks for your prompt reply! I've successfully tested the FLOPs and params of the Mamba models with no errors. However, I have a few more questions:

As discussed in the original Mamba GitHub issues (https://github.com/state-spaces/mamba/issues/303 and https://github.com/state-spaces/mamba/issues/110), the FLOPs of the SSM usually cannot be computed with ordinary libraries (e.g., thop) because of the special selective scan mechanism. The author Albert Gu gave a theoretical estimate of $9LN$ FLOPs for the selective scan (https://github.com/state-spaces/mamba/issues/110#issuecomment-1919470069), and MzeroMiko then provided code to compute it manually with fvcore and applied it in their work VMamba for FLOPs calculation (https://github.com/MzeroMiko/VMamba/blob/main/analyze/flops.py).

Although the ptflops library you suggested doesn't report errors when calculating the SSM, I would like to ask whether this calculation actually counts the special operations inside the selective scan. Many thanks!

ahatamiz commented 1 month ago

Hi @AndyCao1125

We did not use Albert's estimated formula. However, it should not change the numbers reported using ptflops, as the SSM part is not the most compute-intensive operation, even in stages 3 and 4, which are still dominated by the self-attention layers.

I assume L and N in the 9LN formula denote the sequence length and state dimension, which are 196 and 16 respectively for stage 3, and 49 and 16 for stage 4.

In mamba_vision_B, for example, we have 5 and 3 SSM layers in stages 3 and 4. So, assuming ptflops does not account for the SSM part, the total added FLOPs should be:

(5 * 196 * 16 + 3 * 49 * 16 ) * 9e-9 = 0.000162288 GFLOPs
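
For reference, a one-line check of this estimate in Python (using the same assumed values, L = 196 and 49, N = 16, and the 9·L·N-per-layer approximation) reproduces the number above:

# Rough SSM estimate from above: 9 * L * N per SSM layer
added_flops = (5 * 196 * 16 + 3 * 49 * 16) * 9
print(f"{added_flops * 1e-9} GFLOPs")  # 0.000162288 GFLOPs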

We have reported 15.0 GFLOPs for mamba_vision_B, so this still does not change the reported values.

Please feel free to let us know if the above calculations need to be modified to take into account anything that we did not consider.

Ali

AndyCao1125 commented 1 month ago

Hi @ahatamiz

Thanks! Sorry, I omitted part of the statement I just made. The computation of a complete selective scan should equal $9BLDN$ (https://github.com/state-spaces/mamba/issues/110#issuecomment-1972377311), where $B$ denotes the batch size, $D$ denotes d_model, $N$ denotes d_state (=16 by default in the Mamba setting), and $L$ denotes the sequence length. If the D matrix is used, as well as the Z matrix (the optional input to the selective scan kernel), as in https://github.com/NVlabs/MambaVision/blob/6fe3cdc6265a724d5a7bda31cfbe304a2d69cc46/mambavision/models/mamba_vision.py#L400 and https://github.com/NVlabs/MambaVision/blob/6fe3cdc6265a724d5a7bda31cfbe304a2d69cc46/mambavision/models/mamba_vision.py#L401, additional FLOPs of $BLD$ will be generated.

Thus, the revised FLOPs should be: (5 * 196 * 16 * 16 + 3 * 49 * 16 * 16) * 9 * 1e-9 + (additional FLOPs from D and Z) >= 0.00259 GFLOPs
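
A minimal sketch of this refined per-layer estimate, assuming B = 1 and plugging in the same illustrative values used in the expression above (both the channel factor and d_state set to 16), could look like:

# Refined selective-scan estimate: 9*B*L*D*N, plus B*L*D each for the optional
# D (skip) and z (gating) inputs, following the linked Mamba issue comment.
def selective_scan_flops(B, L, D, N, with_D=True, with_Z=True):
    flops = 9 * B * L * D * N
    if with_D:
        flops += B * L * D
    if with_Z:
        flops += B * L * D
    return flops

total = 5 * selective_scan_flops(1, 196, 16, 16) + 3 * selective_scan_flops(1, 49, 16, 16)
print(f"{total * 1e-9} GFLOPs")  # ~0.0026 GFLOPs, consistent with the bound above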

Although the overall GFLOPs for mamba_vision_B are unchanged since only a few significant digits are kept, there may be an effect on the integer part of the GFLOPs figure for mamba_vision_L/L2. Moreover, for larger images (e.g., super-resolution datasets), this effect becomes more significant.

Therefore, I think accurately calculating the FLOPs of a Mamba-based model is a relatively troublesome process, because the selective scan has to be counted manually. In addition to the selective scan operation, the Mamba module itself contains linear and conv1d modules that also need to be included in the FLOPs calculation.
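
For illustration, a rough sketch of that combined approach, adding a manual selective-scan correction on top of the hook-based ptflops count, might look like the following; the (num_layers, L, D) tuples are placeholder assumptions rather than values read from the model:

from ptflops import get_model_complexity_info
from mamba_vision import *

resolution = 224
model = mamba_vision_T().cuda().eval()

# Hook-based count covers linear/conv/attention layers but not the custom
# selective scan kernel.
base_flops, params = get_model_complexity_info(
    model, (3, resolution, resolution),
    as_strings=False, print_per_layer_stat=False, verbose=False)

# Manual correction for the selective scan: 9*B*L*D*N plus B*L*D each for D and z.
def ssm_layer_flops(B, L, D, N=16):
    return 9 * B * L * D * N + 2 * B * L * D

ssm_correction = sum(n_layers * ssm_layer_flops(1, L, D)
                     for n_layers, L, D in [(5, 196, 16), (3, 49, 16)])  # placeholders

print(f"Total: {(base_flops + ssm_correction) * 1e-9} GFLOPs")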

Could we try to implement an accurate FLOPs calculation for Mamba-based models? (If needed, I'm willing to contribute with you :)

ahatamiz commented 1 month ago

Thanks @AndyCao1125 for the clarification. I agree that calculating FLOPs for Mamba-based models is quite tricky. But since we report GFLOPs with only a few significant digits, the extra value won't change things much, at least in our configuration. Even if the sequence length increased by 100 times, the FLOPs added by the SSM part in the above example would still be only about 0.26 GFLOPs.

But other modules such as linear and conv1d are already taken into account -- imagine a hook-based implementation that registers hooks on all such layers.
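
As a rough illustration of that hook-based mechanism (not ptflops' actual internals), a toy counter for nn.Linear and nn.Conv1d layers could look like:

import torch
import torch.nn as nn

# Toy hook-based MAC counter: register a forward hook on every Linear/Conv1d
# and accumulate their multiply-accumulate counts during one forward pass.
def count_linear_conv1d_macs(model, example_input):
    total = 0

    def hook(module, inputs, output):
        nonlocal total
        if isinstance(module, nn.Linear):
            total += output.numel() * module.in_features
        elif isinstance(module, nn.Conv1d):
            total += output.numel() * (module.in_channels // module.groups) * module.kernel_size[0]

    handles = [m.register_forward_hook(hook)
               for m in model.modules() if isinstance(m, (nn.Linear, nn.Conv1d))]
    with torch.no_grad():
        model(example_input)
    for h in handles:
        h.remove()
    return total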

AndyCao1125 commented 1 month ago

Thanks for your reply!