hustvl / Vim

Vision Mamba: Efficient Visual Representation Learning with Bidirectional State Space Model

Activation operation of Vim Block #49

Open BBBBchan opened 3 months ago

BBBBchan commented 3 months ago

Thanks for the great work! According to the paper, there should be a SiLU activation in the Vim Block. However, when I checked the following code in mamba-1p1p1/mamba_ssm/modules/mamba_simple.py, I couldn't find the activation operation.

            elif self.bimamba_type == "v2":
                A_b = -torch.exp(self.A_b_log.float())
                out = mamba_inner_fn_no_out_proj(
                    xz,
                    self.conv1d.weight,
                    self.conv1d.bias,
                    self.x_proj.weight,
                    self.dt_proj.weight,
                    A,
                    None,  # input-dependent B
                    None,  # input-dependent C
                    self.D.float(),
                    delta_bias=self.dt_proj.bias.float(),
                    delta_softplus=True,
                )
                out_b = mamba_inner_fn_no_out_proj(
                    xz.flip([-1]),
                    self.conv1d_b.weight,
                    self.conv1d_b.bias,
                    self.x_proj_b.weight,
                    self.dt_proj_b.weight,
                    A_b,
                    None,
                    None,
                    self.D_b.float(),
                    delta_bias=self.dt_proj_b.bias.float(),
                    delta_softplus=True,
                )
                # F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
                if not self.if_devide_out:
                    out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d"), self.out_proj.weight, self.out_proj.bias)
                else:
                    out = F.linear(rearrange(out + out_b.flip([-1]), "b d l -> b l d") / 2, self.out_proj.weight, self.out_proj.bias)
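
For reference, the non-fused path in mamba_simple.py does apply SiLU explicitly: self.act = nn.SiLU() is applied right after the causal conv, and selective_scan_ref gates the output with out * F.silu(z). Here is a minimal sketch of that path, with conv1d and ssm standing in for self.conv1d and the selective scan (these names are placeholders, not the actual module API):

    import torch.nn.functional as F

    def vim_branch_sketch(x, z, conv1d, ssm, seqlen):
        # Simplified non-fused branch: SiLU shows up twice.
        x = F.silu(conv1d(x)[..., :seqlen])  # activation after the causal conv
        y = ssm(x)                           # selective scan over the sequence
        return y * F.silu(z)                 # SiLU gate on the z branch

In the fused v2 branch quoted above there is no such explicit call, so I assume the activations are folded into the CUDA kernels; that is what I tried to trace next.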

I further checked the mamba_inner_fn_no_out_proj function (backed by the MambaInnerFnNoOutProj class) in mamba-1p1p1/mamba_ssm/ops/selective_scan_interface.py, and the CausalConv1dFn class in causal-conv1d/causal_conv1d/causal_conv1d_interface.py. Although an 'activation' argument does appear in CausalConv1dFn, it seems to receive a None value.

class MambaInnerFnNoOutProj(torch.autograd.Function):
    def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
                C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
        ...
        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
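        # (my reading, not confirmed: None here would be seq_idx and the trailing True
        # the fused SiLU flag, per the causal-conv1d >= 1.1 binding signature)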
        ...

class CausalConv1dFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, bias=None, activation=None):
        if activation not in [None, "silu", "swish"]:
            raise NotImplementedError("activation must be None, silu, or swish")
        if x.stride(2) != 1 and x.stride(1) != 1:
            x = x.contiguous()
        bias = bias.contiguous() if bias is not None else None
        ctx.save_for_backward(x, weight, bias)
        ctx.activation = activation in ["silu", "swish"]
        out = causal_conv1d_cuda.causal_conv1d_fwd(x, weight, bias, ctx.activation)
        return out
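
For what it's worth, the same file also ships a pure-PyTorch reference, causal_conv1d_ref, which shows what that boolean toggles: a fused SiLU on the conv output. Roughly (my simplified sketch, not the CUDA kernel):

    import torch.nn.functional as F

    def causal_conv1d_sketch(x, weight, bias=None, activation=None):
        # x: (batch, dim, seqlen); weight: (dim, width) depthwise filter
        dim, width = weight.shape
        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
        out = out[..., :x.shape[-1]]  # truncate padding so the conv stays causal
        return F.silu(out) if activation in ("silu", "swish") else out

So if the activation flag really were True in the fused call, the SiLU would happen inside the conv kernel rather than as a separate op. Could you confirm where the SiLU of the Vim Block is actually applied?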