alxndrTL / mamba.py

A simple and efficient Mamba implementation in pure PyTorch and MLX.

scan output is different between sequential and parallel versions #26

Closed: llmexperiment closed this issue 5 months ago

llmexperiment commented 5 months ago

Dear Alex,

Thank you for the repo. I noticed that the outputs of the parallel and sequential scans are different: I compared the results of selective_scan and selective_scan_seq from here: https://github.com/alxndrTL/mamba.py/blob/main/mamba.py#L258

I know the parallel version of the scan uses the following: https://github.com/alxndrTL/mamba.py/blob/main/pscan.py#L37

Do you have any suggestions on how to debug this?

alxndrTL commented 5 months ago

Hello, can you please be more precise about the tests you've done (notably, what value of L)?

This is the script I used:

```python
# https://github.com/alxndrTL/mamba.py/issues/26

import torch

import sys
sys.path.append('..')
from mamba import MambaBlock, MambaConfig

Bs, L, D, N = 2, 64, 32, 16

config = MambaConfig(d_model=D, n_layers=0, use_cuda=True)
model = MambaBlock(config).to("cuda")

# API for selective_scan() and selective_scan_seq()
# x : (Bs, L, ED)
# Δ : (Bs, L, ED)
# A : (ED, N)
# B : (Bs, L, N)
# C : (Bs, L, N)
# D : (ED)

# y : (Bs, L, ED)

x = torch.randn(Bs, L, 2*D).to("cuda")  # x.requires_grad = True
delta = torch.randn(Bs, L, 2*D).to("cuda")
A = torch.randn(2*D, N).to("cuda")
B = torch.randn(Bs, L, N).to("cuda")
C = torch.randn(Bs, L, N).to("cuda")
D = torch.randn(2*D,).to("cuda")

y_pscan = model.selective_scan(x, delta, A, B, C, D)
y_seq = model.selective_scan_seq(x, delta, A, B, C, D)

print(torch.allclose(y_seq, y_pscan, rtol=0.01))
```

And I get the same results for the two functions. Of course, for larger values of L (like 512) there are some differences; for example, we get values on the order of 1e32 in one and inf in the other, but with normalization the difference should be less obvious. To test the difference with normalization taken into account, you can compare two MambaBlocks, one with pscan toggled on and the other off (see the sketch below).
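For reference, a minimal sketch of that comparison, assuming MambaConfig exposes a `pscan` flag (true for this repo at the time of writing) and copying the weights so both blocks are identical:

```python
# Compare two MambaBlocks sharing weights: one parallel scan, one sequential.
# Assumption: MambaConfig has a `pscan` flag toggling the scan implementation.
import torch
from mamba import MambaBlock, MambaConfig

D = 32
block_par = MambaBlock(MambaConfig(d_model=D, n_layers=1, pscan=True))
block_seq = MambaBlock(MambaConfig(d_model=D, n_layers=1, pscan=False))
block_seq.load_state_dict(block_par.state_dict())  # identical weights

x = torch.randn(2, 512, D)
with torch.no_grad():
    y_par = block_par(x)
    y_seq = block_seq(x)

print((y_par - y_seq).abs().max())  # should stay small with normalization
```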

Anyway, since the two functions don't do the computation the same way, the results are bound to diverge for very large values of L. But in practice I don't think this is a problem (the Mamba paper trains models with contexts containing millions of tokens).

llmexperiment commented 5 months ago

> And I get the same results for the two functions. [...]

OK, I can see that it matches for this specific case, but for larger L I am getting NaN for some values in the pscan output.

I'd like to understand how you implemented pscan. I started reading your pscan doc here (https://github.com/alxndrTL/mamba.py/blob/main/docs/pscan.ipynb), but you have only the reduction step documented. Could you please help me understand the down-sweep step?

The reduction step works as follows (for a 1D array):

```python
import math
import torch

X = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])  # input array
L = X.size(0)

Xa = X

for k in range(int(math.log2(L))):
    T = 2 * (Xa.size(0) // 2)

    # split into pairs of elements
    Xa = Xa.view(T//2, 2)

    # for each pair, add the first to the second
    Xa[:, 1].add_(Xa[:, 0])

    # change the view for the next iteration
    Xa = Xa[:, 1]
```
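(For the input above, these iterations leave X as [1, 3, 3, 10, 5, 11, 7, 36]: each position holds the partial sum of a binary-tree segment, with the grand total 36 in the last slot.)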

Do you have the code for the down-sweep for the 1D version?

alxndrTL commented 5 months ago

Hello,

> but you have only the reduction step documented.

The down-sweep step is actually explained (with diagrams) in the doc. For the down-sweep, you just do:

```python
for k in range(int(math.log2(L))-1, -1, -1):
    # select the correct sub-array
    Xa = X[2**k-1:L:2**k]

    # split into pairs
    T = 2 * (Xa.size(0) // 2)
    Xa = Xa.view(T//2, 2)

    # for each pair, add to the first element the second element of the previous pair (see diagram)
    Xa[1:, 0].add_(Xa[:-1, 1])
```

Running this after the code you gave, starting with X = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]), will give you tensor([ 1, 3, 6, 10, 15, 21, 28, 36]), which is indeed the cumsum of X.
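Putting the two snippets from this thread together gives a self-contained sketch of the 1D cumulative sum, checked against torch.cumsum:

```python
import math
import torch

X = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
L = X.size(0)  # must be a power of 2 for this version

# up-sweep (reduction): build partial sums bottom-up
Xa = X
for k in range(int(math.log2(L))):
    T = 2 * (Xa.size(0) // 2)
    Xa = Xa.view(T//2, 2)
    Xa[:, 1].add_(Xa[:, 0])
    Xa = Xa[:, 1]

# down-sweep: propagate the partial sums back down the tree
for k in range(int(math.log2(L))-1, -1, -1):
    Xa = X[2**k-1:L:2**k]
    T = 2 * (Xa.size(0) // 2)
    Xa = Xa.view(T//2, 2)
    Xa[1:, 0].add_(Xa[:-1, 1])

print(X)  # tensor([ 1,  3,  6, 10, 15, 21, 28, 36])
assert torch.equal(X, torch.cumsum(torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]), dim=0))
```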

The differences between this and the one in pscan.py are:
- the B, D and N dims;
- the one in pscan.py is optimized: the first steps are unrolled, i.e. done manually outside the for loop rather than inside it;
- (the original one in pscan.py handled sequences of arbitrary lengths, but as of now it only handles power-of-2 lengths, like this one).
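For completeness, here is a minimal sketch (not the repo's code) of how the same two-sweep skeleton extends from a plain cumsum to the first-order recurrence h_t = a_t * h_{t-1} + x_t that pscan.py parallelizes; the function name pscan_1d is hypothetical, and it keeps the 1D, power-of-2 restrictions of the snippets above:

```python
import math
import torch

def pscan_1d(A, X):
    # Hypothetical 1D sketch: in place, X[t] becomes h_t where
    # h_t = A[t] * h_{t-1} + X[t] (with h_{-1} = 0).
    # No B/D/N dims, no unrolling, L must be a power of 2.
    L = X.size(0)

    # up-sweep: combine pairs, keeping running products of A alongside the sums
    Aa, Xa = A, X
    for _ in range(int(math.log2(L))):
        T = 2 * (Xa.size(0) // 2)
        Aa = Aa.view(T//2, 2)
        Xa = Xa.view(T//2, 2)
        Xa[:, 1].add_(Aa[:, 1] * Xa[:, 0])
        Aa[:, 1].mul_(Aa[:, 0])
        Aa = Aa[:, 1]
        Xa = Xa[:, 1]

    # down-sweep: propagate results back down, weighted by the A products
    for k in range(int(math.log2(L))-1, -1, -1):
        Aa = A[2**k-1:L:2**k]
        Xa = X[2**k-1:L:2**k]
        T = 2 * (Xa.size(0) // 2)
        Aa = Aa.view(T//2, 2)
        Xa = Xa.view(T//2, 2)
        Xa[1:, 0].add_(Aa[1:, 0] * Xa[:-1, 1])
        Aa[1:, 0].mul_(Aa[:-1, 1])

# quick check against the naive sequential recurrence
A, X = torch.rand(8), torch.rand(8)
h, ref = 0.0, []
for a, x in zip(A.tolist(), X.tolist()):
    h = a * h + x
    ref.append(h)
pscan_1d(A, X)
print(torch.allclose(X, torch.tensor(ref)))  # True (up to float error)
```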