johnma2006 / mamba-minimal

Simple, minimal implementation of the Mamba SSM in one file of PyTorch.
Apache License 2.0
2.54k stars · 188 forks

About selective_scan #19

Open · ZK-Zhou opened this issue 7 months ago

ZK-Zhou commented 7 months ago

Hi, great work! Could you please explain why selective_scan sets `x = torch.zeros((b, d_in, n), device=deltaA.device)`? In addition, I am confused about u and x.

Thanks.

Ykiiii commented 3 months ago

I have the same question as you! Why is "x" reset to 0 in the selective_scan function?

About u and x, here is my understanding. In the selective_scan function the state-space model is

x(t + 1) = A x(t) + B u(t)
y(t) = C x(t) + D u(t)

Here "u" is the incoming input (what the outer code calls x), and "x" is the hidden state variable; it can be understood as h.
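A minimal PyTorch sketch of that recurrence, assuming deltaA and deltaBu have already been formed with shapes (b, l, d_in, n) as in selective_scan (the function name and exact layout here are illustrative, not the repo's verbatim code):

```python
import torch

def ssm_scan_sketch(deltaA, deltaBu, C, D, u):
    # u: (b, l, d_in)          input sequence ("u" in the equations, "x" outside)
    # deltaA, deltaBu: (b, l, d_in, n)  discretized A_bar and B_bar * u
    # C: (b, l, n); D: (d_in,)
    b, l, d_in, n = deltaA.shape
    x = torch.zeros((b, d_in, n), device=deltaA.device)  # hidden state h, zero-initialized
    ys = []
    for i in range(l):
        x = deltaA[:, i] * x + deltaBu[:, i]                # x(t) = A_bar x(t-1) + B_bar u(t)
        ys.append(torch.einsum('bdn,bn->bd', x, C[:, i]))   # y(t) = C x(t)
    return torch.stack(ys, dim=1) + u * D                   # plus the skip term D u(t)
```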

ZhangXG001 commented 3 weeks ago

"x = torch.zeros((b, d_in, n), device=deltaA.device)" is out of the loop, I think it is init the hidden state with 0(x = torch.zeros((b, d_in, n), device=deltaA.device)) @Ykiiii @ZK-Zhou

Ykiiii commented 3 days ago

It is outside the scan loop, but not outside the Mamba training loop: the hidden state is re-initialized to 0 for every batch, just like an RNN. My confusion is, when training on a long time series, why not keep carrying the hidden state forward? That would be more in line with the idea of the state-space equations, wouldn't it? @ZhangXG001
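For what it's worth, a hypothetical stateful variant (not in this repo) would just accept and return the hidden state, so consecutive chunks of a long sequence could share it, typically detaching between chunks to truncate backprop:

```python
import torch

def stateful_scan_sketch(deltaA, deltaBu, C, x0=None):
    # deltaA, deltaBu: (b, l, d_in, n); C: (b, l, n); x0: (b, d_in, n) or None
    b, l, d_in, n = deltaA.shape
    # start from the carried-over state instead of always resetting to zeros
    x = x0 if x0 is not None else torch.zeros((b, d_in, n), device=deltaA.device)
    ys = []
    for i in range(l):
        x = deltaA[:, i] * x + deltaBu[:, i]
        ys.append(torch.einsum('bdn,bn->bd', x, C[:, i]))
    return torch.stack(ys, dim=1), x  # final state, to seed the next chunk

# chunked training over a long series (truncated-BPTT style):
# y1, x = stateful_scan_sketch(dA1, dBu1, C1)
# y2, x = stateful_scan_sketch(dA2, dBu2, C2, x0=x.detach())
```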