Thanks for your quick reply!!!
```python
out_weight_shape = out_proj_weight.type().sizes()
assert out_proj_weight[1] == Dim
flops += Batch * Dim * L * out_proj_weight[0]
return flops
```
Is there an error here? It looks like `out_proj_weight` (the traced value) is being indexed where its shape `out_weight_shape` should be. I think the correct version is:
```python
out_weight_shape = out_proj_weight.type().sizes()
assert out_weight_shape[1] == Dim
flops += Batch * Dim * L * out_weight_shape[0]
return flops
```
Is that right?
Besides, after updating the mamba package, the following code now runs fine for me:
```python
def MambaInnerFn_jit(inputs, outputs):
    """
    conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
    x, z = xz.chunk(2, dim=1)
    conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
    x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
    delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
    B = x_dbl[:, delta_rank:delta_rank + d_state]
    B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
    C = x_dbl[:, -d_state:]
    C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
    out, scan_intermediates, out_z = selective_scan_cuda.fwd(
        conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
    )
    F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
    """
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, A, D, delta_bias = inputs[:]
    Batch, _, L = xz.type().sizes()
    CWidth = conv1d_weight.type().sizes()[-1]
    H = A.type().sizes()[-1]  # 16
    Dim, R = delta_proj_weight.type().sizes()
    assert tuple(xz.type().sizes()) == (Batch, 2 * Dim, L)
    assert tuple(conv1d_weight.type().sizes()) == (Dim, 1, CWidth)
    assert tuple(x_proj_weight.type().sizes()) == (R + H + H, Dim)
    assert tuple(A.type().sizes()) == (Dim, H)
    with_Z = True
    with_D = False
    if "D" in inputs[7].debugName():
        assert tuple(inputs[7].type().sizes()) == (Dim,)
        with_D = True
    flops = 0
    flops += Batch * (Dim * L) * CWidth  # causal_conv1d_cuda.causal_conv1d_fwd
    flops += Batch * (Dim * L) * (R + H + H)  # x_dbl = F.linear(...)
    flops += Batch * (Dim * R) * L  # delta_proj_weight @ x_dbl[:, :delta_rank]
    # https://github.com/state-spaces/mamba/issues/110
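    # (Assumed reading of the linked issue: the single 9 * Batch * L * Dim * H
    # term approximates the whole selective scan -- discretization, the
    # recurrent state update, and the output contraction with C -- instead of
    # counting each einsum inside the scan separately.)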
    flops += 9 * Batch * L * Dim * H  # selective scan
    if with_D:
        flops += Batch * Dim * L
    if with_Z:
        flops += Batch * Dim * L
    out_weight_shape = out_proj_weight.type().sizes()
    assert out_weight_shape[1] == Dim
    flops += Batch * Dim * L * out_weight_shape[0]
    return flops
```
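For anyone else counting FLOPs this way, here is a minimal sketch of how such a handle gets registered, assuming the counter is fvcore's `flop_count` (the `(inputs, outputs)` handle signature matches its convention, and a custom `torch.autograd.Function` like `MambaInnerFn` shows up in the traced graph as `prim::PythonOp.MambaInnerFn`); `build_model` and the input shape are placeholders:

```python
import torch
from fvcore.nn import flop_count

model = build_model().eval()         # placeholder: your Mamba-based model
dummy = torch.randn(1, 3, 224, 224)  # placeholder input shape

# Route the fused op to the handle defined above; ops fvcore already knows
# (linears, convs, ...) are still counted by its built-in handles.
supported_ops = {"prim::PythonOp.MambaInnerFn": MambaInnerFn_jit}
gflops, unsupported = flop_count(model=model, inputs=(dummy,), supported_ops=supported_ops)
print(sum(gflops.values()), "GFLOPs; unsupported ops:", dict(unsupported))
```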
Oh, thank you for correcting this.
I did not check the code of `MambaInnerFn_jit`, since it is not used by any of them (VMamba, Vim, or S4ND). But I did use `MambaInnerFnNoOutProj_jit` to get the FLOPs of Vim, and the only difference between `MambaInnerFnNoOutProj` and `MambaInnerFn` is the final projection layer. So the corrected version should work.
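To make that relationship concrete, here is a minimal sketch of the NoOutProj variant under that assumption; the input unpacking mirrors the handle above minus `out_proj_weight` and is assumed rather than copied from the repo:

```python
def MambaInnerFnNoOutProj_jit_sketch(inputs, outputs):
    # Assumed input order: same as MambaInnerFn_jit, just without out_proj_weight.
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, A, D, delta_bias = inputs[:]
    Batch, _, L = xz.type().sizes()
    CWidth = conv1d_weight.type().sizes()[-1]
    H = A.type().sizes()[-1]
    Dim, R = delta_proj_weight.type().sizes()
    flops = 0
    flops += Batch * (Dim * L) * CWidth       # causal conv1d
    flops += Batch * (Dim * L) * (R + H + H)  # x_proj linear
    flops += Batch * (Dim * R) * L            # delta projection
    flops += 9 * Batch * L * Dim * H          # selective scan (issue #110)
    flops += 2 * Batch * Dim * L              # D skip + z gating (assuming both present)
    # No final out_proj term -- that is the only difference from MambaInnerFn_jit.
    return flops
```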
I actually do have analysis tools implemented for Vim; check this link: https://github.com/MzeroMiko/VMamba/blob/8c9cd412b7e76c62d55497958f7f839403bff74b/analyze/analyze_for_vim.py#L421