johnma2006 / mamba-minimal

Simple, minimal implementation of the Mamba SSM in one file of PyTorch.
Apache License 2.0

Is use of LogA intentional? #16

Closed · billxbf closed this issue 7 months ago

billxbf commented 7 months ago

Nice implementation! I noticed a small detour and wonder if it's necessary: you define `self.A_log = nn.Parameter(torch.log(A))`, which is only used in `A = -torch.exp(self.A_log.float())` to exponentiate it back. What's the reason for not defining A directly as the class parameter?

johnma2006 commented 7 months ago

It’s a good question! More generally, your question is about parameterizations. Different parameterizations can have the same behaviour in the forward pass but very different behaviour in the backward pass. For example,

`W = Parameter(Normal(0, 1)); y = W * x`

and

`W2 = Parameter(Normal(0, 0.1)); W = 10 * W2; y = W * x`

behave the same in the initial forward pass, but very differently in the backward pass; W2 effectively has a much higher learning rate.
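To make that concrete, here is a small standalone sketch (not code from the repo) of the two parameterizations: the loss and the initial forward pass match in scale, but the gradient that reaches W2 is 10x larger, so under a fixed SGD learning rate the effective weight 10 * W2 moves a hundred times farther per step.

```python
import torch

torch.manual_seed(0)
x = torch.randn(16)

# Parameterization 1: learn W directly, W ~ Normal(0, 1).
W = torch.nn.Parameter(torch.randn(()))
(W * x).sum().backward()

# Parameterization 2: learn W2 ~ Normal(0, 0.1) and use W = 10 * W2,
# so the effective weight has the same scale at init.
W2 = torch.nn.Parameter(0.1 * torch.randn(()))
(10 * W2 * x).sum().backward()

# Same forward-pass scale, but dL/dW2 = 10 * dL/dW, so one SGD step
# moves the effective weight 10 * W2 a hundred times farther.
print(W.grad, W2.grad)
```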

See https://www.inference.vc/neural-tangent-kernels-some-intuition-for-kernel-gradient-descent/ Example 3 for another simple example; see NTK parameterization and maximal update parameterization for more advanced and practically useful parameterizations.

billxbf commented 7 months ago

Thanks for the clarification and links! I understand that the gradient of A differs in backprop under this reparameterization. But I'm still confused about why the gradient of A needs to be amplified while those of B, C, and D don't. Is there a specific reason?

johnma2006 commented 7 months ago

Oh sure, please see *On the Parameterization and Initialization of Diagonal State Space Models*, Section 3.3, "Parameterization of A".

In this case, A is parameterized such that it's always negative (so the reason is a bit different from the one I gave above).
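For reference, a minimal standalone sketch of this negativity-through-log pattern (sizes are illustrative, and the init only roughly mirrors the S4D-real initialization used in the repo): because the stored parameter is log A and the forward pass computes `-exp(A_log)`, the effective A can never cross zero during training.

```python
import torch
import torch.nn as nn

d_inner, d_state = 4, 16  # illustrative sizes, not necessarily the repo's defaults

# S4D-real-style init: A starts as positive values 1..d_state per channel,
# and what gets stored as the learnable parameter is log(A).
A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(d_inner, 1)
A_log = nn.Parameter(torch.log(A))

# Forward pass: exponentiate and negate. Whatever value A_log drifts to
# during training, the effective A is strictly negative, so the discretized
# transition exp(delta * A) has magnitude < 1 and the recurrence stays stable.
A_eff = -torch.exp(A_log.float())
assert (A_eff < 0).all()
```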

billxbf commented 7 months ago

Got it. Thanks a lot!