Closed cszhbo closed 6 months ago
Hello, regarding the first multiplication $\bar{A} h_{t-1}$: we indeed have $\bar{A}$ of shape $(N, N)$ and $h_{t-1}$ of shape $(N,)$, so the product would normally be computed with a matmul. However, $\bar{A}$ is assumed to stay diagonal throughout training, so in the code we only store its diagonal as an $(N,)$ vector. Element-wise multiplication of this vector with $h_{t-1}$ therefore gives exactly the same result as the full matrix product.
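To make the diagonal trick concrete, here is a minimal sketch (using NumPy purely for illustration; the repo itself uses PyTorch, and the variable names here are made up) showing that a matmul with the full $(N, N)$ diagonal matrix equals an element-wise product with its diagonal vector:

```python
import numpy as np

N = 4
rng = np.random.default_rng(0)
A_diag = rng.standard_normal(N)   # (N,) diagonal entries of A_bar
h_prev = rng.standard_normal(N)   # (N,) hidden state h_{t-1}

full = np.diag(A_diag) @ h_prev   # explicit matmul with the (N, N) matrix
fast = A_diag * h_prev            # element-wise, as done in the code

assert np.allclose(full, fast)
```

The off-diagonal entries of $\bar{A}$ are all zero, so the matmul only ever pairs each diagonal entry with the matching component of $h_{t-1}$, which is precisely what the element-wise product computes.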
Concerning the second multiplication $\bar{B} x_t$: $\bar{B}$ is an $(N,)$ vector and $x_t$ is only a scalar, so in the code we simply use an element-wise multiplication (with an unsqueeze to avoid any surprising broadcasting).
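The unsqueeze can be illustrated on the batched shapes used in the code (again a NumPy sketch with assumed shapes, not the repo's actual implementation): adding a trailing axis of size 1 to `x` lets it broadcast cleanly over the $N$ state dimension.

```python
import numpy as np

B_, L, ED, N = 2, 3, 5, 4
rng = np.random.default_rng(1)
deltaB = rng.standard_normal((B_, L, ED, N))  # (B, L, ED, N)
x = rng.standard_normal((B_, L, ED))          # (B, L, ED), one scalar per channel

# x[..., None] plays the role of x.unsqueeze(-1): shape (B, L, ED, 1),
# which broadcasts against the last (N) axis of deltaB
BX = deltaB * x[..., None]
assert BX.shape == (B_, L, ED, N)
```

Without the added axis, NumPy/PyTorch would try to align `x`'s last dimension (`ED`) against `deltaB`'s last dimension (`N`), which is exactly the "surprising broadcasting" the unsqueeze avoids.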
Note that in the code, in the two lines you quoted, the equation $\bar{A} h_{t-1} + \bar{B} x_t$ is computed for all E*D channels in parallel, but that does not change the math behind it: if you understand it for a single channel (as explained just above), PyTorch simply does the batched multiplication for you.
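The batched recurrence can be sketched as follows (a NumPy illustration with hypothetical shapes, mirroring the two quoted lines), with a cross-check that one channel of the batched update matches the single-channel recurrence:

```python
import numpy as np

B_, L, ED, N = 2, 6, 3, 4
rng = np.random.default_rng(2)
deltaA = rng.standard_normal((B_, L, ED, N)) * 0.1  # discretized A, per channel
deltaB = rng.standard_normal((B_, L, ED, N))        # discretized B, per channel
x = rng.standard_normal((B_, L, ED))                # scalar input per channel

BX = deltaB * x[..., None]                          # (B, L, ED, N)
h = np.zeros((B_, ED, N))
for t in range(L):
    # element-wise update for all E*D channels at once
    h = deltaA[:, t] * h + BX[:, t]

# cross-check: rerun the recurrence for one (batch, channel) pair alone
b, c = 0, 1
h_ref = np.zeros(N)
for t in range(L):
    h_ref = deltaA[b, t, c] * h_ref + BX[b, t, c]
assert np.allclose(h_ref, h[b, c])
```

Each channel has its own independent diagonal $\bar{A}$ and vector $\bar{B}$, so broadcasting over the `ED` axis is just running E*D copies of the scalar-input recurrence side by side.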
Hope that is clear enough!
Ok, I see. Thanks for your detailed and clear explanation.
Hello. In the function
selective_scan_seq
, there are two points I am confused about:
BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)
h = deltaA[:, t] * h + BX[:, t]
These two lines of code seem to be element-wise multiplications. However, in the paper, the equation is $$h_t = \bar{A} h_{t-1} + \bar{B} x_t$$ and both terms on the right side of the equation are matrix multiplications.
I am curious whether the two lines of code use some trick to convert the matrix multiplications into element-wise ones?