Open dumpmemory opened 1 month ago
The original Mamba2 paper goes through this in detail: https://arxiv.org/pdf/2405.21060 --> section 6.3 ("Total Cost").
Training FLOPs: O(TN^2); inference FLOPs: O(TN), where we assume state_size == head_dim == chunk_length (but it should be a good first estimate).
Otherwise, if you want exact numbers, you would have to work through all of the individual computations (with their dimensions), which are elaborated in 6.3 for the different blocks involved (center, left, right).
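As a rough sanity check of those asymptotics, here is a minimal sketch, assuming state_size == head_dim == chunk_length == N and using hypothetical example sizes (T and N below are not values from the paper):

```python
# Order-of-magnitude sketch of the asymptotic costs above (constants dropped).
# T and N are hypothetical example sizes, not values from the paper.
T = 4096   # sequence length
N = 128    # state size == head dim == chunk length (assumed equal)

training_flops = T * N**2   # O(T N^2) training cost, up to constant factors
inference_flops = T * N     # O(T N) total inference cost, up to constant factors

print(f"training  ~O(TN^2): {training_flops/1e9:.3f} GFLOPs (up to constants)")
print(f"inference ~O(TN):   {inference_flops/1e6:.3f} MFLOPs (up to constants)")
```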
Yes, I want exact calculations without big O. I used np.einsum_path on Listing 1, but the FLOPs result is too high.
def ssd_flops(T, Q, P, N):
    # T = sequence length, Q = chunk length, P = head dim, N = state size
    # center blocks (quadratic intra-chunk SMA part)
    center_blocks_sma_compute = T*Q*N + T*Q*Q + T*P*N
    # low-rank blocks, right factors: B terms
    b_compute = T*N*P
    # low-rank blocks, right factors: A terms (one state per chunk of length Q)
    a_compute = T*N*P / Q
    # low-rank blocks, left factors: C terms
    c_compute = T*P*N
    return center_blocks_sma_compute + b_compute + a_compute + c_compute
Is this correct?
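For reference, here is the same counting with a sample invocation; the dimensions below are hypothetical, chosen only to produce a concrete number:

```python
def ssd_flops(T, Q, P, N):
    # Same counting as the function above: T = sequence length,
    # Q = chunk length, P = head dim, N = state size.
    center = T*Q*N + T*Q*Q + T*P*N   # center (intra-chunk) blocks
    b = T*N*P                        # right factors, B terms
    a = T*N*P / Q                    # right factors, A terms
    c = T*P*N                        # left factors, C terms
    return center + b + a + c

# Hypothetical example dimensions (not taken from the paper):
T, Q, P, N = 4096, 256, 64, 128
print(f"{ssd_flops(T, Q, P, N)/1e9:.3f} GFLOPs")  # ~0.503 GFLOPs for these sizes
```

Note that these formulas tally one operation per multiply-accumulate; tools that count multiplies and adds separately will report roughly 2x these numbers.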
Hi,
Thanks for your work. I am interested in calculating Mamba2's FLOPs for the SSD part. My calculation for https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/ssd_minimal.py comes out too high to believe. I need your help.
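One thing worth checking, assuming you are reading the "Optimized FLOP count" that np.einsum_path reports: NumPy counts a contraction's multiplies and adds separately, so a plain matrix multiply (m,k)@(k,n) is reported as roughly 2*m*k*n FLOPs, about twice a multiply-accumulate count like the one in ssd_flops. A minimal sketch:

```python
import numpy as np

# Hypothetical small matmul, just to illustrate how einsum_path counts FLOPs.
m, k, n = 64, 128, 32
a = np.random.rand(m, k)
b = np.random.rand(k, n)

path, info = np.einsum_path('mk,kn->mn', a, b, optimize='optimal')
print(info)  # the report includes "Naive FLOP count" / "Optimized FLOP count"

# For a two-operand contraction NumPy tallies ~2*m*k*n (one multiply + one add),
# so expect roughly twice a multiply-accumulate count:
print("multiply-accumulate count:", m * k * n)
```

That factor of 2 alone will not explain an arbitrarily large gap, but it is a common reason einsum-based counts look higher than hand-derived ones.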