state-spaces / mamba

Mamba SSM architecture

question about Parameter+delta? #410

Open liang00fan opened 2 weeks ago

liang00fan commented 2 weeks ago

[image: the $\Delta$ parameterization from the Mamba paper, $\Delta = \tau_\Delta(\mathrm{Parameter} + s_\Delta(x))$]

What is the reference code for the "$\mathrm{Parameter}+$" part? I found the following code in mamba_simple.py:

self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True)  # projects dt back up from dt_rank to d_inner

# split the x_proj output into the dt (dt_rank), B (d_state), and C (d_state) parts
dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
dt = self.dt_proj.weight @ dt.t()                 # weight only; the bias is not applied here
dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)  # (d_inner, b*l) -> (b, d_inner, l)

Question 1: why does it need to add "$\mathrm{Parameter}+$"?
Question 2: why does it use self.dt_proj.weight @ dt.t()?
Question 3: could those lines become delta = F.softplus(self.dt_proj(delta)); is that the same?

albertfgu commented 1 week ago
  1. This adds a bias term to ensure $\Delta$ is of the right magnitude. See previous SSM papers for details.
  2. This is motivated in the paper; the low-rank factorization saves parameters and is a generalization of a "down projection followed by broadcast" (see the sketch after this list).
  3. That's ultimately what happens in spirit.
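
For concreteness, a minimal sketch of the low-rank factorization in point 2 (the sizes and the standalone to_dt_rank layer below are illustrative assumptions; in the repo the down projection is the dt slice of x_proj):

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical sizes for illustration only.
d_inner, dt_rank = 1536, 48

to_dt_rank = nn.Linear(d_inner, dt_rank, bias=False)   # down projection (plays the role of x_proj's dt slice)
dt_proj = nn.Linear(dt_rank, d_inner, bias=True)       # up projection, as in mamba_simple.py

x = torch.randn(2, 10, d_inner)              # (batch, length, d_inner)
delta = F.softplus(dt_proj(to_dt_rank(x)))   # (batch, length, d_inner), positive step sizes

# Parameter count: low-rank factorization vs. a single full d_inner -> d_inner projection.
low_rank_params = d_inner * dt_rank + dt_rank * d_inner + d_inner   # ~0.15M
full_params = d_inner * d_inner + d_inner                           # ~2.36M
print(low_rank_params, full_params)

Setting dt_rank = 1 and broadcasting that single channel across d_inner recovers the "down projection followed by broadcast" special case that this factorization generalizes.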
liang00fan commented 1 week ago

For question 1: in my opinion, $\Delta$ acts like the discretization step size, and to make it input-dependent we add $s_\Delta(x_t)$. But in the paper it says [image]; why is the "$\mathrm{Parameter}+$" part dropped there? It doesn't seem to match the code dt = self.dt_proj.weight @ dt.t().

tridao commented 1 week ago

Parameter here is the dt_bias. Linear(x_t) means self.dt_proj.weight @ dt.t() + dt_bias. In the code we separately do self.dt_proj.weight @ dt.t(), and then the dt_bias is added in a separate step (in the CUDA kernel).
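
A minimal sketch (with hypothetical shapes) of the equivalence, in plain PyTorch rather than the fused CUDA path:

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

batch, seqlen, dt_rank, d_inner = 2, 5, 8, 16          # hypothetical shapes
dt_proj = nn.Linear(dt_rank, d_inner, bias=True)
dt = torch.randn(batch * seqlen, dt_rank)              # flattened (b*l, dt_rank), like the dt slice of x_dbl

# What the code does: the matmul without the bias...
dt_split = dt_proj.weight @ dt.t()                     # (d_inner, b*l)
dt_split = rearrange(dt_split, "d (b l) -> b d l", l=seqlen)
# ...then dt_bias is added in a separate step (in the real code, inside the CUDA kernel) before the softplus.
delta_split = F.softplus(dt_split + dt_proj.bias[None, :, None])

# The "all at once" version from question 3: nn.Linear applies the weight and the bias together.
delta_direct = rearrange(F.softplus(dt_proj(dt)), "(b l) d -> b d l", l=seqlen)

print(torch.allclose(delta_split, delta_direct, atol=1e-5))  # True

So the two forms compute the same delta; the split form just defers the bias addition so it can be handled inside the kernel along with the softplus.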

YicongHong commented 7 hours ago

Hello @tridao @albertfgu, I am very interested in this point as well. Could you please elaborate a bit more on how adding this learnable dt_bias to the input-dependent $s_\Delta(x_t)$ ensures that $\Delta$ is at the right magnitude? May I ask which previous SSM paper discussed this?

Much appreciated!

tridao commented 3 hours ago

The S4 code has a parameter dt that should be in the range 1e-3 to 1e-1 (these are hyperparameters that you can change). In Mamba's case we want softplus(x @ weight + dt_bias) to be around that range. We can assume that x @ weight has zero mean at initialization, so we initialize dt_bias so that softplus(dt_bias) is in the range 1e-3 to 1e-1. Of course this does not guarantee it will stay in this range as the model is trained; it only holds at initialization. https://github.com/state-spaces/mamba/blob/28b1435eb56c3082a243d23253ee7676ad737c09/mamba_ssm/modules/mamba_simple.py#L91
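
A minimal sketch of that initialization (following the linked line; dt_min, dt_max, and d_inner below are just the stated range and an example size):

import math
import torch
import torch.nn.functional as F

d_inner = 16                      # example size
dt_min, dt_max = 1e-3, 1e-1       # desired range for dt at initialization

# Sample the target dt log-uniformly in [dt_min, dt_max].
dt = torch.exp(
    torch.rand(d_inner) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
)

# Invert the softplus: if softplus(x) = log(1 + exp(x)), then x = dt + log(1 - exp(-dt)).
dt_bias = dt + torch.log(-torch.expm1(-dt))

# Assuming x @ weight is roughly zero-mean at init, softplus(x @ weight + dt_bias)
# starts out close to softplus(dt_bias) = dt, i.e. inside [dt_min, dt_max].
print(torch.allclose(F.softplus(dt_bias), dt, atol=1e-6))  # True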