vasqu / mamba2-torch

MIT License

flops about mamba2 #3

Open dumpmemory opened 1 month ago

dumpmemory commented 1 month ago

Hi,

thanks for your work. I am interested in calculating the FLOPs of Mamba2's SSD part. My estimate for https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/ssd_minimal.py comes out too high to believe. I need your help.

vasqu commented 1 month ago

The original Mamba2 paper goes through this in detail: https://arxiv.org/pdf/2405.21060 --> Section 6.3 ("Total Cost")

Training FLOPs: O(TN^2), inference FLOPs: O(TN), where we assume state_size == head_dim == chunk_length (but it should be a good first estimate).

Otherwise, if you want exact calculations, you would have to go through all the exact computations (with their dimensions) that are elaborated in Section 6.3 for the different blocks involved (center, left, right).
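If you just want rough numbers for a given size, a minimal sketch of those two estimates (constant factors ignored, and assuming state_size == head_dim == chunk_length == N):

def ssd_flops_ballpark(T, N):
    # Leading-order estimates only, under state_size == head_dim == chunk_length == N;
    # constant factors and the per-block breakdown from Section 6.3 are ignored.
    training_flops = T * N * N   # O(T N^2), chunked (training) pass
    inference_flops = T * N      # O(T N), autoregressive pass over T steps
    return training_flops, inference_flops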

dumpmemory commented 1 month ago

Yes, I want exact calculations without big O. I used np.einsum_path on Listing 1, but the FLOPs result is too high.
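As an illustration of what I mean (the einsum and dimensions below are stand-ins, not the exact tensors from Listing 1), np.einsum_path reports a FLOP count for a contraction like this; as far as I can tell it counts both multiplications and additions, so it already comes out higher than a multiplies-only count:

import numpy as np

# Stand-in for one center-block style contraction (C @ B^T per chunk):
# (chunks, chunk_len, state) x (chunks, chunk_len, state) -> (chunks, chunk_len, chunk_len)
chunks, chunk_len, state = 8, 64, 128
C = np.zeros((chunks, chunk_len, state))
B = np.zeros((chunks, chunk_len, state))

path, info = np.einsum_path("cln,csn->cls", C, B, optimize="optimal")
print(info)  # the report includes naive and optimized FLOP counts for this contraction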

dumpmemory commented 1 month ago
def ssd_flops(T, Q, P, N):
    # T: sequence length, Q: chunk length, P: head dim, N: state size
    # Center (diagonal) blocks
    center_blocks_sma_compute = T * Q * N + T * Q * Q + T * P * N
    # Low-rank blocks, right factors (B terms)
    b_compute = T * N * P
    # Low-rank blocks, A terms (inter-chunk state recurrence)
    a_compute = T * N * P / Q
    # Low-rank blocks, left factors (C terms)
    c_compute = T * P * N
    return center_blocks_sma_compute + b_compute + a_compute + c_compute

Is this correct?
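For scale, this is what the formula gives with some example dimensions (the values below are arbitrary, just to show the order of magnitude; they are not from any particular config):

# Illustrative call; T = sequence length, Q = chunk length, P = head dim, N = state size.
# These dimensions are made up for the example.
total = ssd_flops(T=2048, Q=256, P=64, N=128)
print(total / 1e9, "GFLOPs for a single head (if the formula above is right)")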