Aidenzich / road-to-master

A repo to store our research footprint on AI
MIT License

Competitors to Transformer #56

Aidenzich opened this issue 5 months ago

Aidenzich commented 5 months ago
Aidenzich commented 4 months ago

Mamba


Selective SSM

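As a rough sketch of what "selective" means, the pure-Python function below runs the SSM recurrence for a single input channel with an `n`-dimensional state, recomputing $\Delta_t$, $B_t$ and $C_t$ from the current input at every step. The weights `w_delta`, `w_B` and `w_C` are hypothetical stand-ins for Mamba's learned linear projections (which read the full `d_in`-dimensional input, not a scalar). Softplus keeps $\Delta$ positive, and the discretization follows the `selective_scan` code in this comment: $\bar{A} = \exp(\Delta A)$ and $\bar{B}x = \Delta B x$.

```python
import math


def softplus(z: float) -> float:
    # log(1 + e^z): keeps the step size delta positive
    return math.log1p(math.exp(z))


def selective_ssm_1ch(xs, A, w_delta, w_B, w_C, D=0.0):
    """Selective SSM over one input channel with an n-dim diagonal state.

    xs:  sequence of scalar inputs x_t
    A:   list of n (negative) diagonal entries of the continuous state matrix
    w_delta, w_B, w_C: HYPOTHETICAL projection weights that make
         delta_t, B_t and C_t functions of the current input x_t
    """
    n = len(A)
    h = [0.0] * n
    ys = []
    for x in xs:
        # Selection: the SSM parameters depend on the current input
        delta = softplus(w_delta * x)
        B = [w * x for w in w_B]
        C = [w * x for w in w_C]
        # Discretize (A: zero-order hold, B: simplified Euler) and update the state
        h = [math.exp(delta * A[k]) * h[k] + delta * B[k] * x for k in range(n)]
        # Readout plus skip connection
        ys.append(sum(C[k] * h[k] for k in range(n)) + D * x)
    return ys


print(selective_ssm_1ch([1.0, 0.0, 2.0], A=[-1.0, -0.5], w_delta=0.5, w_B=[1.0, 0.5], w_C=[1.0, 1.0]))
```

Unlike a time-invariant SSM, the same `A` produces a different $\bar{A}_t$ at every step, because $\Delta_t$ changes with the input.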

What are the benefits of the selective SSM?
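One benefit described in the Mamba paper is that an input-dependent step size $\Delta$ lets the model choose, token by token, whether to keep its state or overwrite it: a small $\Delta$ persists the state and largely ignores the current input, while a large $\Delta$ resets the state and focuses on the current input. The scalar sketch below illustrates this with zero-order-hold discretization, $\bar{A} = \exp(\Delta a)$ and $\bar{B} = (\exp(\Delta a) - 1)/a \cdot b$; the values `a = -1`, `b = 1` and the two $\Delta$ values are arbitrary assumptions.

```python
import math


def discretize_zoh(delta: float, a: float, b: float):
    """Zero-order hold for a scalar SSM: A_bar = exp(delta*a), B_bar = (A_bar - 1) / a * b."""
    a_bar = math.exp(delta * a)
    return a_bar, (a_bar - 1.0) / a * b


def step(h: float, x: float, delta: float, a: float = -1.0, b: float = 1.0) -> float:
    """One recurrence step h_t = A_bar * h_{t-1} + B_bar * x_t (requires a != 0)."""
    a_bar, b_bar = discretize_zoh(delta, a, b)
    return a_bar * h + b_bar * x


h_prev, x_new = 5.0, 1.0
# Small delta: A_bar is close to 1 and B_bar close to 0, so the old state persists
print(step(h_prev, x_new, delta=1e-3))
# Large delta: A_bar is close to 0, so the old state is reset and h ends up close to x_new
print(step(h_prev, x_new, delta=10.0))
```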

Parallel Scan


NOTE: In Maarten's diagram, $H_n$ should be read as $H_n'$, because before the sweep-up phase it is still an intermediate value, not the final output. In the complete computation, $H_0$, like the other intermediate states, must still be processed further during the sweep-up phase.
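The parallel scan works because one SSM step $h_t = \bar{A}_t h_{t-1} + \bar{B}_t x_t$ is an affine map, and composing affine maps is associative. Below is a sketch, assuming a scalar state, of Blelloch's work-efficient scan over pairs `(a, b)` that stand for $h \mapsto a h + b$. The names `combine`, `blelloch_scan` and `ssm_states` are illustrative, not functions from the Mamba code, and the phase names follow Blelloch's convention (up-sweep first, then down-sweep), which may not line up with the labels in Maarten's diagram.

```python
from typing import List, Tuple

# An element (a, b) stands for the affine map h -> a * h + b,
# i.e. one step h_t = A_bar_t * h_{t-1} + (B_bar x)_t of a scalar SSM.
Elem = Tuple[float, float]
IDENTITY: Elem = (1.0, 0.0)


def combine(earlier: Elem, later: Elem) -> Elem:
    """Apply `earlier`, then `later` (associative but not commutative)."""
    a1, b1 = earlier
    a2, b2 = later
    return (a1 * a2, a2 * b1 + b2)


def blelloch_scan(elems: List[Elem]) -> List[Elem]:
    """Exclusive scan: up-sweep (reduce), then down-sweep.

    Within one tree level the loop iterations touch disjoint indices,
    so a GPU could run them in parallel; here they run sequentially.
    """
    n = 1
    while n < len(elems):
        n *= 2
    tree = list(elems) + [IDENTITY] * (n - len(elems))  # pad to a power of two
    # Up-sweep: tree[i] accumulates the combination of its subtree
    d = 1
    while d < n:
        for i in range(2 * d - 1, n, 2 * d):
            tree[i] = combine(tree[i - d], tree[i])
        d *= 2
    # Down-sweep: push prefixes back down the tree
    tree[n - 1] = IDENTITY
    d = n // 2
    while d >= 1:
        for i in range(2 * d - 1, n, 2 * d):
            left = tree[i - d]
            tree[i - d] = tree[i]
            tree[i] = combine(tree[i], left)
        d //= 2
    return tree[: len(elems)]


def ssm_states(a_bar: List[float], bx: List[float]) -> List[float]:
    """Hidden states h_0..h_{L-1}, starting from h_{-1} = 0."""
    elems = list(zip(a_bar, bx))
    exclusive = blelloch_scan(elems)
    # Each exclusive prefix must still be combined with its own element;
    # with h_{-1} = 0, the state h_t is the b part of the inclusive prefix.
    return [combine(p, e)[1] for p, e in zip(exclusive, elems)]
```

The last list comprehension is analogous to the point in the note above: the down-sweep only yields exclusive prefixes, so every position, including the first, needs one more `combine` with its own element before it is a final hidden state.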

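Code implementation:

```python
import torch
from einops import einsum


def selective_scan(self, u, delta, A, B, C, D):
    # Shapes: u, delta (b, l, d_in); A (d_in, n); B, C (b, l, n); D (d_in,)
    (b, l, d_in) = u.shape
    n = A.shape[1]

    # Discretize continuous parameters (A, B):
    # A with zero-order hold, B with the simplified Euler rule (delta * B)
    deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
    deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')

    # Initialize the hidden state x
    x = torch.zeros((b, d_in, n), device=deltaA.device)
    ys = []

    # Sequential scan over time (a plain loop; a parallel scan computes the same states):
    # x_t = deltaA_t * x_{t-1} + deltaB_u_t,  y_t = C_t x_t
    for i in range(l):
        x = deltaA[:, i] * x + deltaB_u[:, i]
        y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
        ys.append(y)

    # Stack per-step outputs into shape (b, l, d_in)
    y = torch.stack(ys, dim=1)

    # Skip connection: D * u is added directly to the SSM output
    y = y + u * D
    return y
```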
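To see why Mamba needs a scan at all, the loop in `selective_scan` can be unrolled. For each (batch, channel, state) entry, `deltaA[:, i]` plays the role of `a_bar[t]` below and `deltaB_u[:, i]` the role of `bx[t]`. Earlier SSMs such as S4 are time-invariant, so the weight on input `k` at time `t` is a fixed kernel value and the whole sequence can be computed as one convolution; once $\bar{A}_t$ varies with the input, that fixed kernel no longer exists. The pure-Python sketch below, with arbitrary example numbers, checks both cases for one scalar state.

```python
import math
from typing import List


def scan_sequential(a_bar: List[float], bx: List[float]) -> List[float]:
    """h_t = a_bar[t] * h_{t-1} + bx[t] with h_{-1} = 0, one scalar state."""
    h, out = 0.0, []
    for a, b in zip(a_bar, bx):
        h = a * h + b
        out.append(h)
    return out


def scan_unrolled(a_bar: List[float], bx: List[float]) -> List[float]:
    """Same states, written as h_t = sum_k (prod_{j=k+1..t} a_bar[j]) * bx[k]."""
    return [
        sum(math.prod(a_bar[k + 1 : t + 1]) * bx[k] for k in range(t + 1))
        for t in range(len(bx))
    ]


# Selective case: a_bar changes per step, so the weight on bx[k] depends on the whole window
a_sel, bx = [0.9, 0.2, 0.7, 0.5], [1.0, -1.0, 2.0, 0.5]
print(scan_sequential(a_sel, bx))
print(scan_unrolled(a_sel, bx))

# Time-invariant case: with a constant a_bar the weight is a**(t - k), a fixed convolution kernel
a = 0.8
kernel = [a**i for i in range(len(bx))]
conv = [sum(kernel[t - k] * bx[k] for k in range(t + 1)) for t in range(len(bx))]
print(conv)
```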

Hardware-aware

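A back-of-the-envelope calculation shows why the hardware-aware design matters. The `selective_scan` code in this comment materializes `deltaA` and `deltaB_u` with shape `(b, l, d_in, n)`, which is `n` times larger than the input and output. According to the Mamba paper, the fused kernel instead loads $\Delta$, $A$, $B$, $C$ into on-chip SRAM, performs the discretization and the scan there, and writes only `y` back to GPU HBM, recomputing intermediate states in the backward pass. The sizes below are assumed purely for illustration and are not taken from this post.

```python
def tensor_mib(*shape: int, bytes_per_elem: int = 4) -> float:
    """Size in MiB of a dense tensor with the given shape (fp32 by default)."""
    count = 1
    for dim in shape:
        count *= dim
    return count * bytes_per_elem / 2**20


# Illustrative, assumed sizes: batch 1, 2048 tokens, d_in = 1536, state size n = 16
b, l, d_in, n = 1, 2048, 1536, 16

expanded = tensor_mib(b, l, d_in, n)  # deltaA or deltaB_u, shape (b, l, d_in, n)
output = tensor_mib(b, l, d_in)       # y, shape (b, l, d_in)
print(f"expanded tensor: {expanded:.0f} MiB, output: {output:.0f} MiB")
# prints "expanded tensor: 192 MiB, output: 12 MiB"
```

Each expanded tensor is `n` times the size of `y`, so keeping it out of HBM removes most of the memory traffic in the naive version.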

Reference

| Resource | Description |
| --- | --- |
| mamba | The original repository. Its selective scan is implemented as low-level (CUDA) code, which makes it harder to follow. |
| mamba-minimal | A simplified, readable, annotated implementation of Mamba; not optimized for speed. |
| Maarten's blog | A comprehensive introduction by Maarten, the author of BERTopic. The best article I have found 💯 |