Open buttercutter opened 6 months ago
I'm sorry, what's your question?
The question is: in the papers, A needs to be diagonal, but your code in https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py has:
# S4D real initialization
A = repeat(
    torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
    "n -> d n",
    d=self.d_inner,
).contiguous()
A_log = torch.log(A) # Keep A_log in fp32
A = -torch.exp(self.A_log.float()) # (d_inner, d_state)
Which is not diagonal.
Thanks!
Wei
A is technically a batch of d_inner diagonal matrices, each of size d_state x d_state. Since it's diagonal, we don't need to store all the d_state x d_state entries, we just need to store d_state entries. So here we're storing (d_inner, d_state) entries.
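To make the storage argument concrete, here is a minimal sketch, not the repository's code. It uses NumPy as a stand-in for the torch tensors, and the sizes and the names `A_diag`, `A_full`, and `h` are made up for illustration. It builds the compact (d_inner, d_state) array, expands it into an explicit batch of diagonal matrices, and checks that both give the same result when applied to a state vector.

```python
import numpy as np

d_inner, d_state = 4, 3  # small illustrative sizes (assumed), not the model's defaults

# Compact storage, mirroring the S4D-real init after A = -exp(log(A)):
# row d holds the diagonal entries of channel d's A matrix.
A_diag = -np.tile(np.arange(1, d_state + 1, dtype=np.float32), (d_inner, 1))  # (d_inner, d_state)

# Explicit equivalent: a batch of d_inner diagonal matrices, each d_state x d_state.
A_full = np.stack([np.diag(row) for row in A_diag])  # (d_inner, d_state, d_state)

# A hypothetical hidden state, one d_state-vector per channel.
h = np.random.default_rng(0).standard_normal((d_inner, d_state)).astype(np.float32)

# Multiplying by a diagonal matrix is the same as elementwise scaling by its diagonal.
dense = np.einsum("dij,dj->di", A_full, h)
compact = A_diag * h
assert np.allclose(dense, compact)
```

The dense form stores d_inner * d_state * d_state numbers, of which only d_inner * d_state are nonzero, which is why keeping just the diagonals loses nothing.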
@tridao Thank you for the clear explanation. You may want to add this to the comments in the file so others can benefit from it as well. Great work!
I have another question on the strict requirement of a diagonal A arising from some mathematical relationship with Glorot/Xavier initialization for A.
References: [1] Resurrecting Recurrent Neural Networks for Long Sequences; [2] HiPPO: Recurrent Memory with Optimal Polynomial Projections