
thop fails to count the flops and parameters of the custom operator Mamba #303

Open niuzehai opened 5 months ago

niuzehai commented 5 months ago

I am using the custom operator Mamba (from mamba_ssm.modules.mamba_simple) in my project, but when I use the thop library to count model parameters, thop does not seem to count the parameters of Mamba.

My model definition is as follows:

import torch.nn as nn
from mamba_ssm.modules.mamba_simple import Mamba

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.mamba = Mamba(...)
        ...

    def forward(self, x):
        x = self.mamba(x)
        ...
        return x

When I use thop to count the parameters, I use the following code:

import torch
from thop import profile

model = MyModel()
input = torch.randn(...)
flops, params = profile(model, inputs=(input,))
print(f'Params: {params}')
print(f'FLOPs: {flops}')

However, the output params do not seem to include the parameters of Mamba. I suspect this might be because Mamba is a custom operator and thop cannot automatically count its parameters.

I would like to ask:

  1. How can I make thop correctly count the parameters of a model that includes custom operators? Do I need to write specific hooks or handlers for Mamba?
  2. Apart from thop, are there any other libraries or tools that can be used to count the parameters of a model that includes custom operators?
  3. If it is temporarily not possible to automatically count the parameters using tools, are there any suggested methods to manually estimate the parameter count of Mamba?

I would greatly appreciate any suggestions or solutions! Please let me know if more information is needed.


tridao commented 5 months ago

Counting params should just be sum(p.numel() for p in model.parameters()). I'm not familiar with thop. I assume there's some way to specify how many flops a custom operation takes.
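A minimal sketch of both points (assuming thop's custom_ops mechanism as described in its README; the count_mamba counter and the input shape below are placeholders, not part of thop or mamba_ssm):

import torch
from thop import profile
from mamba_ssm.modules.mamba_simple import Mamba

model = MyModel()
x = torch.randn(1, 1024, 512)  # hypothetical (batch, length, d_model)

# 1) Parameters: iterates over every registered parameter, so Mamba's
#    weights are included automatically.
n_params = sum(p.numel() for p in model.parameters())
print(f'Params: {n_params}')

# 2) FLOPs: give thop a counter for the Mamba module so its fused kernel
#    is not silently skipped (fill in your own estimate, e.g. from #110).
def count_mamba(module, inputs, output):
    module.total_ops += torch.DoubleTensor([0])  # placeholder estimate

flops, params = profile(model, inputs=(x,), custom_ops={Mamba: count_mamba})

Depending on the thop version, a module with a custom handler may or may not have its submodules counted separately, so check the breakdown for double or missed counting.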

niuzehai commented 5 months ago

Thank you for your valuable input. You are correct: sum(p.numel() for p in model.parameters()) is a reliable way to calculate the total number of parameters, including those of custom operators like Mamba.

Regarding the calculation of FLOPs for Mamba, I noticed that MzeroMiko has provided a method to manually calculate Mamba's FLOPs in the linked issue (#110). This is indeed helpful for understanding the computational complexity of models involving Mamba.

However, I have a question about combining the manually calculated FLOPs of Mamba with the FLOPs calculated by thop for the rest of the model. Is it valid to simply add the FLOPs obtained from these two methods?

There are a few considerations:

  1. Overlapping computations: If there are any overlapping computations between Mamba and the native PyTorch operators, simply adding their FLOPs might lead to double-counting. We need to ensure that the manual calculation of Mamba's FLOPs only includes computations unique to Mamba.
  2. Interaction between operators: The computational complexity of a model is not always a simple sum of the complexities of its individual operators. There might be interactions or optimizations that happen when operators are used together, which could affect the total FLOPs. Simply adding the individually calculated FLOPs might not capture these interactions.
  3. Accuracy of manual calculation: We need to ensure that the manual calculation of Mamba's FLOPs, as provided in the linked issue, is accurate and complete. If it misses some computations or makes simplifying assumptions, the total FLOPs obtained by adding it to thop's results might not be precise.

Considering these points, while adding manually calculated FLOPs of Mamba to thop's results could serve as an approximation, it might not always give an exact measure of the model's total computational complexity.
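For concreteness, the additive approach I am asking about would look roughly like this (a sketch only; mamba_flops_manual is a hypothetical helper implementing the rough estimate from #110, the shapes are made up, and it assumes thop skips Mamba's fused kernel entirely rather than counting part of it):

from thop import profile

def mamba_flops_manual(B, L, D, N):
    # per-layer estimate following #110: associative scan plus the D skip term
    return 9 * B * L * D * N + B * D * L

flops_rest, params = profile(model, inputs=(x,))  # what thop recognizes on its own
flops_mamba = num_mamba_layers * mamba_flops_manual(B=1, L=1024, D=1536, N=16)
total_flops = flops_rest + flops_mamba  # only meaningful if nothing is double-counted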

I would be grateful for your thoughts on this matter. Is there a way to accurately calculate the total FLOPs of a model that includes Mamba, given the information we have from thop and the manual calculation method provided by MzeroMiko?

Thank you once again for your help in understanding the complexity of models with custom operators.

MzeroMiko commented 5 months ago

I am not familiar with thop, but I am familiar with fvcore, which may help.

A code snippet from thop which may help you understand what I am saying below:

import numpy as np

def counter_matmul(input_size, output_size):
    input_size = np.array(input_size)
    output_size = np.array(output_size)
    return np.prod(input_size) * output_size[-1]

In fact, there's no such thing as accurate FLOPs at all. Take a default supported operation in fvcore as an example:

    if equation == "abc,abd->acd":
        n, c, t = input_shapes[0]
        p = input_shapes[-1][-1]
        flop = n * c * t * p
        return flop

So here comes my solution, based on mamba issue #110, which is for VMamba:

def flops_selective_scan_fn(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_complex=False):
    """
    u: r(B D L)
    delta: r(B D L)
    A: r(D N)
    B: r(B N L)
    C: r(B N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32

    ignores:
        [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu] 
    """
    assert not with_complex 
    # https://github.com/state-spaces/mamba/issues/110
    flops = 9 * B * L * D * N
    if with_D:
        flops += B * D * L
    if with_Z:
        flops += B * D * L    
    return flops

def selective_scan_flop_jit(inputs, outputs, flops_fn=flops_selective_scan_fn):
    print_jit_input_names(inputs)  # optional debug print of the traced input names (helper from VMamba)
    B, D, L = inputs[0].type().sizes()
    N = inputs[2].type().sizes()[1]
    flops = flops_fn(B=B, L=L, D=D, N=N, with_D=True, with_Z=False)
    return flops

# register the handler with fvcore and run the count:
from functools import partial
from fvcore.nn import flop_count

supported_ops = {
    "prim::PythonOp.SelectiveScanMamba": partial(selective_scan_flop_jit, flops_fn=flops_selective_scan_fn),
}

Gflops, unsupported = flop_count(model=model, inputs=(input,), supported_ops=supported_ops)

Note that the code above does not fit every situation when using mamba modules (I may support more cases in the future, but for now it only supports the situation as in VMamba). If you are using mamba and know exactly which situation you are in, you can write another FLOPs-counting function and register it with fvcore as in the code above.
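One practical note on choosing the registration key (my reading of fvcore's API, not something from VMamba): flop_count also returns the ops it skipped, so you can run it once without supported_ops to see the exact prim::PythonOp.* name your mamba version produces, and then register your counter under that name.

from fvcore.nn import flop_count

# first pass with no custom handlers, just to discover what fvcore skipped
gflops, skipped = flop_count(model=model, inputs=(input,))
print(skipped)  # keys include the PythonOp name of the selective-scan autograd Function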

lth456321 commented 5 months ago

Do you mean I should calculate mamba's FLOPs with fvcore, adding the 9BLDN term, and then I get mamba's total FLOPs? Thank you.

MzeroMiko commented 5 months ago

@lth456321 Basically yes. But 9BLDN is just for the associative scan in the forward pass; you need to pay attention to extra calculations like adding D or multiplying by z, and also pay attention to the dimensions you use in practice.
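For a rough sense of scale (hypothetical shapes, not taken from this thread), the scan term dominates and the D and z corrections are small by comparison:

B, L, D, N = 1, 196, 768, 16   # hypothetical batch, length, d_inner, d_state
scan   = 9 * B * L * D * N     # 21,676,032  ~ 21.7 MFLOPs for the associative scan
plus_d = B * D * L             # 150,528     ~ 0.15 MFLOPs for the +D skip connection
mult_z = B * D * L             # 150,528     ~ 0.15 MFLOPs for the z gating multiply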

Sawyer117 commented 4 months ago

@MzeroMiko Thank you for the informative reply. I just wonder, for the "flops_selective_scan_fn" function, this should be the FLOPs for one SSM block, right? So for the whole model I can just use num_layers * flops_selective_scan_fn, right? Much appreciated.

MzeroMiko commented 3 months ago

@Sawyer117 Not really, this is only the FLOPs for the SSM; for the whole mamba block, you also need to take into account all the einsum or linear layers that are used to prepare the parameters for the SSM or to build the structure.
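As an illustration of that point, a rough per-block tally for the standard Mamba module might look like the sketch below (an approximation only; the hyperparameter names follow mamba_simple.Mamba defaults, d_inner = expand * d_model and dt_rank = ceil(d_model / 16), and each matmul entry is counted as one multiply-add to stay consistent with the 9BLDN convention above):

import math

def mamba_block_flops(B, L, d_model, d_state=16, d_conv=4, expand=2):
    d_inner = expand * d_model
    dt_rank = math.ceil(d_model / 16)
    flops = 0
    flops += B * L * d_model * (2 * d_inner)            # in_proj: x -> (x, z)
    flops += B * L * d_inner * d_conv                   # depthwise conv1d
    flops += B * L * d_inner * (dt_rank + 2 * d_state)  # x_proj: x -> (dt, B, C)
    flops += B * L * dt_rank * d_inner                  # dt_proj
    flops += 9 * B * L * d_inner * d_state              # selective scan (as above)
    flops += B * d_inner * L                            # + D skip term
    flops += B * L * d_inner                            # multiply by silu(z)
    flops += B * L * d_inner * d_model                  # out_proj
    return flops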