MzeroMiko / VMamba

VMamba: Visual State Space Models, code is based on mamba
MIT License

Hi, is the operator prim::PythonOp.MambaInnerFn the same as selective_scan_flop_jit? #299

Closed: 924973292 closed this issue 1 month ago

MzeroMiko commented 1 month ago

I have actually implemented analysis tools for Vim; check this link: https://github.com/MzeroMiko/VMamba/blob/8c9cd412b7e76c62d55497958f7f839403bff74b/analyze/analyze_for_vim.py#L421

924973292 commented 1 month ago

Thanks for your quick reply!!!

924973292 commented 1 month ago
        out_weight_shape = out_proj_weight.type().sizes()
        assert out_proj_weight[1] == Dim
        flops += Batch * Dim * L * out_proj_weight[0]

        return flops

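Is there an error here? The code indexes out_proj_weight itself instead of its shape out_weight_shape. Should the correct version be:

out_weight_shape = out_proj_weight.type().sizes()
assert out_weight_shape[1] == Dim
flops += Batch * Dim * L * out_weight_shape[0]

return flops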
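The corrected lines count the output projection as a plain linear layer: each of the Batch * L tokens performs Dim multiply-accumulates per output channel, with one multiply-accumulate counted as one FLOP, as in the rest of the code. A minimal stdlib-only sketch of that term, assuming the weight shape follows the nn.Linear convention (out_features, in_features); the helper name out_proj_flops is hypothetical, and plain tuples stand in for the .type().sizes() values of the JIT graph:

```python
def out_proj_flops(batch: int, length: int, dim: int, weight_shape: tuple) -> int:
    """Hypothetical helper: FLOPs (as MACs) of F.linear on a (batch, length, dim) input.

    weight_shape is assumed to be (out_features, in_features), as for nn.Linear.
    """
    out_features, in_features = weight_shape
    # Check the shape tuple, not the weight tensor itself.
    assert in_features == dim, "out_proj input width must match the inner dim"
    return batch * length * dim * out_features


# Illustrative sizes: batch 2, 197 tokens, inner dim 384 projected to 192 channels.
print(out_proj_flops(2, 197, 384, (192, 384)))  # 29048832
```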

Besides, I updated the mamba package, and the following code now runs OK for me:

def MambaInnerFn_jit(inputs, outputs):
    """
    conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
    x, z = xz.chunk(2, dim=1)
    conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
    x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
    delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L)
    B = x_dbl[:, delta_rank:delta_rank + d_state]
    B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
    C = x_dbl[:, -d_state:]
    C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
    out, scan_intermediates, out_z = selective_scan_cuda.fwd(
        conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
    )
    F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
    """
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight, out_proj_weight, A, D, delta_bias = inputs[:]
    Batch, _, L = xz.type().sizes()
    CWidth = conv1d_weight.type().sizes()[-1]
    H = A.type().sizes()[-1]  # d_state, e.g. 16
    Dim, R = delta_proj_weight.type().sizes()  # R = delta_rank
    assert tuple(xz.type().sizes()) == (Batch, 2 * Dim, L)
    assert tuple(conv1d_weight.type().sizes()) == (Dim, 1, CWidth)
    assert tuple(x_proj_weight.type().sizes()) == (R + H + H, Dim)
    assert tuple(A.type().sizes()) == (Dim, H)

    with_Z = True
    with_D = False
    if "D" in inputs[7].debugName():
        assert tuple(inputs[7].type().sizes()) == (Dim,)
        with_D = True

    flops = 0
    flops += Batch * (Dim * L) * CWidth  # causal_conv1d_cuda.causal_conv1d_fwd
    flops += Batch * (Dim * L) * (R + H + H)  # x_dbl = F.linear(...
    flops += Batch * (Dim * R) * (L)  # delta_proj_weight @ x_dbl[:, :delta_rank]

    # selective scan, see https://github.com/state-spaces/mamba/issues/110
    # (+= rather than =, so the conv1d and projection terms above are kept)
    flops += 9 * Batch * L * Dim * H
    if with_D:
        flops += Batch * Dim * L
    if with_Z:
        flops += Batch * Dim * L

    # final output projection: weight shape is (out_features, Dim)
    out_weight_shape = out_proj_weight.type().sizes()
    assert out_weight_shape[1] == Dim
    flops += Batch * Dim * L * out_weight_shape[0]

    return flops
MzeroMiko commented 1 month ago

Oh, thank you for correcting this.

I did not check the code of MambaInnerFn_jit, since it is not used by any of the models (VMamba, Vim, or S4ND). However, I did use MambaInnerFnNoOutProj_jit to compute the FLOPs of Vim, and the only difference between MambaInnerFnNoOutProj and MambaInnerFn is the final output projection layer. So the corrected version should work.
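The argument that the two counters differ only by the projection term can be checked with a shape-only sketch of the same breakdown as MambaInnerFn_jit, written over plain integers instead of JIT values. mamba_inner_flops and its parameter names are hypothetical; the 9 * B * L * D * N scan term is the estimate linked to mamba issue #110 in the code, every term (the scan included) is accumulated with +=, and out_dim=None stands in for the no-output-projection variant:

```python
from typing import Optional


def mamba_inner_flops(batch: int, length: int, dim: int, d_state: int,
                      dt_rank: int, conv_width: int,
                      out_dim: Optional[int] = None,
                      with_D: bool = True, with_Z: bool = True) -> int:
    """Hypothetical shape-only FLOP count mirroring MambaInnerFn_jit.

    out_dim=None models the variant without the final output projection.
    """
    flops = 0
    flops += batch * dim * length * conv_width               # causal depthwise conv1d
    flops += batch * dim * length * (dt_rank + 2 * d_state)  # x_proj -> (delta, B, C)
    flops += batch * dim * dt_rank * length                  # delta_proj_weight @ x_dbl[:, :dt_rank]
    flops += 9 * batch * length * dim * d_state              # selective scan estimate (mamba issue #110)
    if with_D:
        flops += batch * dim * length                        # D skip term
    if with_Z:
        flops += batch * dim * length                        # gating with z
    if out_dim is not None:
        flops += batch * dim * length * out_dim              # out_proj linear
    return flops


# Illustrative, assumed sizes: d_inner 384 -> d_model 192, d_state 16, dt_rank 12, conv width 4.
full = mamba_inner_flops(1, 197, 384, 16, 12, 4, out_dim=192)
no_out_proj = mamba_inner_flops(1, 197, 384, 16, 12, 4)
print(full - no_out_proj == 1 * 384 * 197 * 192)  # True
```

Because the projection term is added last and independently, correcting it changes only the full count and leaves the no-projection count untouched.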