alxndrTL / mamba.py

A simple and efficient Mamba implementation in pure PyTorch and MLX.
MIT License

pscan speed compared to simple for loop #57

Closed AnFreTh closed 2 weeks ago

AnFreTh commented 2 weeks ago

Great repo!

I was wondering about the speed of pscan compared to a simple for loop. I assumed pscan would be faster in every scenario. However, on CPU the selective scan step is faster with the for loop from the selective_scan_seq variant than with pscan.

Simulate the data (sequence length 320):

import torch
from mambapy.pscan import pscan

# Define the variables
B, L, ED, N = 128, 320, 128, 64  # Example dimensions

x = torch.rand(B, L, ED)
delta = torch.rand(B, L, ED)
A = torch.rand(ED, N)
B = torch.rand(B, L, N)  # note: the name B (batch size above) is reused for the input-dependent B matrix
D = torch.rand(ED)

# Compute deltaA, deltaB and BX, the shared first step of both the pscan and sequential variants
deltaA = torch.exp(delta.unsqueeze(-1) * A)  # (B, L, ED, N)
deltaB = delta.unsqueeze(-1) * B.unsqueeze(2)  # (B, L, ED, N)
BX = deltaB * (x.unsqueeze(-1))  # (B, L, ED, N)

With pscan

hs = pscan(deltaA, BX)

Around 10 seconds on CPU.

With seq (for loop)

_, L, _ = x.shape

h = torch.zeros(x.size(0), ED, N) # (B, ED, N)
hs = []
for t in range(0, L):
    h = deltaA[:, t] * h + BX[:, t]
    hs.append(h)

hs = torch.stack(hs, dim=1) # (B, L, ED, N)

Around 2 seconds.
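
For reference, a minimal sketch of how the two variants could be timed (assuming the tensors defined above and a single wall-clock run with time.perf_counter; not necessarily how the numbers above were measured):

import time

# Time the parallel scan (single wall-clock run on CPU)
start = time.perf_counter()
hs_pscan = pscan(deltaA, BX)
print(f"pscan: {time.perf_counter() - start:.2f} s")

# Time the sequential for loop
start = time.perf_counter()
h = torch.zeros(x.size(0), ED, N)  # (B, ED, N)
hs_list = []
for t in range(L):
    h = deltaA[:, t] * h + BX[:, t]
    hs_list.append(h)
hs_seq = torch.stack(hs_list, dim=1)  # (B, L, ED, N)
print(f"sequential: {time.perf_counter() - start:.2f} s")

# Note: with torch.rand inputs deltaA > 1, so the state grows along the sequence;
# that is fine for a timing comparison, but the values themselves are not meaningful.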

Is this expected, and do the speed advantages only come into play during training (backward passes)?

alxndrTL commented 2 weeks ago

Hello, yes it's expected because you're running on the CPU. CPUs are good at doing small sequential computations very fast. On the GPU, it's preferable to do bigger computations but fewer of them (sequentially speaking). That's exactly what the parallel scan is for: instead of L small sequential steps, it does only log2(L) bigger ones (given enough parallelism).
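
To make the log2(L) picture concrete, here is a minimal Hillis-Steele-style sketch of the same recurrence h_t = deltaA_t * h_{t-1} + BX_t (an illustration only, not the repo's pscan implementation): each pass merges partial results that are `step` positions apart, so only ceil(log2(L)) sequential passes are needed.

def naive_parallel_scan(deltaA, BX):
    # Illustrative scan for h_t = deltaA_t * h_{t-1} + BX_t with h_{-1} = 0.
    # ceil(log2(L)) sequential passes; each pass is one big element-wise op.
    A = deltaA.clone()
    X = BX.clone()
    L = A.size(1)
    step = 1
    while step < L:
        A_new = A.clone()
        X_new = X.clone()
        # position t absorbs the partial result ending at t - step
        X_new[:, step:] = X[:, step:] + A[:, step:] * X[:, :-step]
        A_new[:, step:] = A[:, step:] * A[:, :-step]
        A, X = A_new, X_new
        step *= 2
    return X  # X[:, t] now equals h_t

On a GPU each pass is a large element-wise operation that parallelizes well, which is why log2(L) passes can beat L tiny steps; on a CPU the extra work and allocations per pass can outweigh that.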

So if you're on CPU, yes, it may be a good idea to stay with the sequential scan! I should mention that in the README. Eventually though, with bigger dimensions, pscan should be better even on the CPU.

Hope this helps

AnFreTh commented 2 weeks ago

Thanks for the reply! That answers it. I will close the issue.