Thanks for the great work!
According to the paper, there should be a 'SiLU' activation operation in the Vim block. However, when I checked the following code in mamba-1p1p1/mamba_ssm/modules/mamba_simple.py, I didn't find the 'activation' operation.
elif self.bimamba_type == "v2":
A_b = -torch.exp(self.A_b_log.float())
out = mamba_inner_fn_no_out_proj(
xz,
self.conv1d.weight,
self.conv1d.bias,
self.x_proj.weight,
self.dt_proj.weight,
A,
None, # input-dependent B
None, # input-dependent C
self.D.float(),
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
)
out_b = mamba_inner_fn_no_out_proj(
xz.flip([-1]),
self.conv1d_b.weight,
self.conv1d_b.bias,
self.x_proj_b.weight,
self.dt_proj_b.weight,
A_b,
None,
None,
self.D_b.float(),
delta_bias=self.dt_proj_b.bias.float(),
delta_softplus=True,
)
# F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
if not self.if_devide_out:
out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias)
else:
out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d") / 2, self.out_proj.weight, self.out_proj.bias)
I further checked the mamba_inner_fn_no_out_proj method in the MambaInnerFnNoOutProj class in mamba-1p1p1/mamba_ssm/ops/selective_scan_interface.py and the CausalConv1dFn class in causal-conv1d/causal_conv1d/causal_conv1d_interface.py. Though an 'activation' operation occurs in CausalConv1dFn, it seems to get a 'None' value.
class MambaInnerFnNoOutProj(torch.autograd.Function):
def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
...
conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias,None, True)
...
class CausalConv1dFn(torch.autograd.Function):
    """Autograd wrapper around the CUDA causal 1-D convolution kernel.

    The underlying kernel optionally fuses a SiLU/Swish activation into the
    convolution; passing ``activation=None`` runs the convolution with no
    activation applied.
    """

    @staticmethod
    def forward(ctx, x, weight, bias=None, activation=None):
        # Only no activation or the fused SiLU/Swish is supported by the
        # CUDA kernel.
        if activation not in [None, "silu", "swish"]:
            raise NotImplementedError("activation must be None, silu, or swish")
        # Kernel expects a contiguous layout along (at least) one of the
        # last two dimensions; fall back to a contiguous copy otherwise.
        if x.stride(2) != 1 and x.stride(1) != 1:
            x = x.contiguous()
        bias = bias.contiguous() if bias is not None else None
        ctx.save_for_backward(x, weight, bias)
        # Collapse "silu"/"swish" into the boolean flag the kernel takes.
        ctx.activation = activation in ["silu", "swish"]
        out = causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, ctx.activation)
        return out
Thanks for the great work! According to the paper, there should be a 'SiLU' activation operation in the Vim block. However, when I checked the following code in mamba-1p1p1/mamba_ssm/modules/mamba_simple.py, I didn't find the 'activation' operation. I further checked the mamba_inner_fn_no_out_proj method in the MambaInnerFnNoOutProj class in mamba-1p1p1/mamba_ssm/ops/selective_scan_interface.py and the CausalConv1dFn class in causal-conv1d/causal_conv1d/causal_conv1d_interface.py. Though an 'activation' operation occurs in CausalConv1dFn, it seems to get a 'None' value.