alxndrTL / mamba.py

A simple and efficient Mamba implementation in pure PyTorch and MLX.
MIT License

Why use element-wise multiplication rather than matrix multiplication in the function `selective_scan_seq` #17

cszhbo closed this issue 6 months ago

cszhbo commented 6 months ago

Hello. In the function `selective_scan_seq`, there are two points that confuse me:

However, in the paper, the recurrence is $$h_t = \bar{A} h_{t-1} + \bar{B} x_t$$ Both terms on the right side of the equation are computed with matrix multiplication.

I am curious whether these two lines of code use some trick to convert the matrix multiplications into element-wise ones?

alxndrTL commented 6 months ago

Hello, regarding the first multiplication $\bar{A} h_{t-1}$: $\bar{A}$ is $(N, N)$ and $h_{t-1}$ is $(N,)$, so computing this product would indeed require a matmul. But $\bar{A}$ is assumed to stay diagonal throughout training, so in the code it is represented only as an $(N,)$ vector (the elements on the diagonal). Multiplying this vector element-wise with $h_{t-1}$ therefore gives exactly the same result as the full matmul.
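To see why the diagonal assumption makes this work, here is a minimal sketch (the variable names are mine, not from the repo):

```python
import torch

N = 16
A_bar = torch.randn(N)    # the (N, N) diagonal matrix, stored as its diagonal
h_prev = torch.randn(N)

# matmul with the explicit diagonal matrix ...
full = torch.diag(A_bar) @ h_prev
# ... equals the element-wise product with the stored diagonal
elementwise = A_bar * h_prev

assert torch.allclose(full, elementwise)
```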

Concerning the second multiplication $\bar{B} x_t$: $\bar{B}$ is an $(N,)$ vector and $x_t$ is just a scalar, so the code simply uses an element-wise multiplication (with an unsqueeze to avoid any surprising broadcasting).
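For this term, a scalar times a vector is already element-wise by definition; a toy check (again with illustrative names):

```python
import torch

N = 16
B_bar = torch.randn(N)     # (N,)
x_t = torch.tensor(0.5)    # scalar input at time t

# same result as viewing B_bar as an (N, 1) matrix and x_t as a (1,) vector
out = B_bar * x_t          # (N,)
assert out.shape == (N,)
```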

Note that in the code, in the two lines you quoted, the equation $\bar{A} h_{t-1} + \bar{B} x_t$ is computed for all E*D channels in parallel, but that does not change the math: if you understand it for a single channel (as explained above), PyTorch simply does the batched multiplication for you, as in the sketch below.
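Putting it together, a hedged sketch of one step of the recurrence batched over all channels (shapes follow the explanation above; the tensor names are illustrative, not necessarily those of `selective_scan_seq`):

```python
import torch

batch, ED, N = 2, 8, 16
deltaA = torch.randn(batch, ED, N)    # per-channel diagonal of A_bar, stored as vectors
deltaBx = torch.randn(batch, ED, N)   # per-channel B_bar * x_t
h = torch.zeros(batch, ED, N)

# one step of h_t = A_bar h_{t-1} + B_bar x_t, done element-wise for
# all batch entries and all ED channels at once
h = deltaA * h + deltaBx
```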

Hope that is clear enough!

cszhbo commented 6 months ago

Ok, I see. Thanks for your detailed and clear explanation.