state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

Strict requirement of a **diagonal** `A` #71

Open buttercutter opened 6 months ago

buttercutter commented 6 months ago

I have another question about the strict requirement of a diagonal A, which arises from some mathematical relationship with Glorot/Xavier initialization of A.


References:
[1] Resurrecting Recurrent Neural Networks for Long Sequences
[2] HiPPO: Recurrent Memory with Optimal Polynomial Projections

albertfgu commented 6 months ago

I'm sorry, what's your question?

houghtonweihu commented 5 months ago

The question is: in the papers, A needs to be diagonal, but in your code at https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py, you have:

    # S4D real initialization (in __init__)
    A = repeat(
        torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
        "n -> d n",
        d=self.d_inner,
    ).contiguous()
    A_log = torch.log(A)  # Keep A_log in fp32
    self.A_log = nn.Parameter(A_log)

    # ... later, in forward():
    A = -torch.exp(self.A_log.float())  # (d_inner, d_state)

This A has shape (d_inner, d_state), which is not diagonal.

Thanks!

Wei

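To see concretely what the quoted snippet produces, here is a dependency-free sketch that mirrors it with plain Python lists instead of torch/einops. It assumes small illustrative sizes `d_inner=3` and `d_state=4` (the module takes these from its own configuration), and the helper name `s4d_real_A` is hypothetical. Every row comes out as approximately `[-1, -2, ..., -d_state]`: one row per inner channel, each holding `d_state` values.

```python
import math


def s4d_real_A(d_inner: int, d_state: int) -> list[list[float]]:
    """Hypothetical helper mirroring the S4D-real init in the quoted snippet."""
    # Equivalent of repeat(arange(1, d_state + 1), "n -> d n", d=d_inner):
    # every one of the d_inner rows is [1, 2, ..., d_state].
    A = [[float(n) for n in range(1, d_state + 1)] for _ in range(d_inner)]
    # The log is what gets stored as the learnable parameter (A_log).
    A_log = [[math.log(a) for a in row] for row in A]
    # The forward pass recovers a strictly negative real A = -exp(A_log).
    return [[-math.exp(x) for x in row] for row in A_log]


A = s4d_real_A(d_inner=3, d_state=4)
print(len(A), len(A[0]))  # 3 4
print(A[0])               # approximately [-1.0, -2.0, -3.0, -4.0]
```

Storing `A_log` rather than `A` keeps `-exp(A_log)` negative no matter how the parameter is updated, which is presumably why the snippet goes through the log.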
tridao commented 5 months ago

A is technically a batch of d_inner diagonal matrices, each of size d_state x d_state. Since each one is diagonal, we don't need to store all d_state x d_state entries; we only need the d_state diagonal entries. So here we store (d_inner, d_state) entries.
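A small dependency-free sketch of this point follows. It uses illustrative sizes (`d_inner=3`, `d_state=4`), a hypothetical step size `delta`, and a hypothetical state `h`. It compares the compact `(d_inner, d_state)` storage, where row `d` is the diagonal of one `d_state x d_state` matrix, against explicitly building those dense diagonal matrices. Because the matrix exponential of a diagonal matrix is the elementwise exponential of its diagonal, an update of the form `exp(delta * A_d) @ h_d` can be computed elementwise on each compact row. The `exp(delta * A)` form is assumed here for illustration; this is not the repository's kernel.

```python
import math

Matrix = list[list[float]]


def diag_matrix(entries: list[float]) -> Matrix:
    """Expand a length-n vector into a dense n x n diagonal matrix."""
    n = len(entries)
    return [[entries[i] if i == j else 0.0 for j in range(n)] for i in range(n)]


def matmul(X: Matrix, Y: Matrix) -> Matrix:
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y))) for j in range(len(Y[0]))]
            for i in range(len(X))]


def matvec(M: Matrix, v: list[float]) -> list[float]:
    return [sum(M[i][j] * v[j] for j in range(len(v))) for i in range(len(M))]


def expm_taylor(M: Matrix, terms: int = 30) -> Matrix:
    """General dense matrix exponential via a truncated Taylor series (small-norm M)."""
    n = len(M)
    result = diag_matrix([1.0] * n)  # M^0 / 0! = identity
    power = diag_matrix([1.0] * n)
    for k in range(1, terms):
        power = matmul(power, M)
        result = [[result[i][j] + power[i][j] / math.factorial(k) for j in range(n)]
                  for i in range(n)]
    return result


d_inner, d_state = 3, 4  # illustrative sizes
delta = 0.1              # hypothetical step size
# Compact storage, shape (d_inner, d_state): row d is the diagonal of A_d.
A = [[-float(n) for n in range(1, d_state + 1)] for _ in range(d_inner)]
h = [[float(n) for n in range(1, d_state + 1)] for _ in range(d_inner)]  # hypothetical state

# Compact path: elementwise exp and elementwise multiply, O(d_state) per channel.
h_compact = [[math.exp(delta * a) * x for a, x in zip(A[d], h[d])] for d in range(d_inner)]

# Dense path: materialize each d_state x d_state matrix and use a general expm.
h_dense = []
for d in range(d_inner):
    dA = diag_matrix([delta * a for a in A[d]])
    h_dense.append(matvec(expm_taylor(dA), h[d]))

print(d_inner * d_state, "stored vs", d_inner * d_state * d_state, "dense")  # 12 stored vs 48 dense
```

Both paths give the same updated state. The compact form stores `d_inner * d_state` numbers instead of `d_inner * d_state**2`, and its per-step update shrinks by the same factor.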

houghtonweihu commented 5 months ago

@tridao Thank you for the clear explanation. You might add this to the comments in the file so others can benefit from it as well. Great work!