alxndrTL / mamba.py

A simple and efficient Mamba implementation in pure PyTorch and MLX.

Changing the pscan from in-place to out-of-place? #27

Closed llmexperiment closed 5 months ago

llmexperiment commented 5 months ago

Hi @alxndrTL ,

I am trying to generate onnx file for the forward pass here: https://github.com/alxndrTL/mamba.py/blob/main/mamba.py#L69

The issue is that the ONNX export cannot handle in-place computation, and will skip ops or generate incorrect output when an in-place computation is present. I am wondering how to modify the pscan so that all the computations happen in an out-of-place manner?
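For reference, this is roughly how the export attempt looks (a sketch only: the model construction follows the repo README, and the opset version and input shapes are placeholders I picked):

import torch
from mamba import Mamba, MambaConfig

# build a small model as in the README; sizes here are placeholders
config = MambaConfig(d_model=16, n_layers=2)
model = Mamba(config).eval()

x = torch.randn(1, 64, config.d_model)  # (B, L, D)

# this trips on the in-place writes inside pscan
torch.onnx.export(model, (x,), "mamba.onnx",
                  input_names=["x"], output_names=["y"],
                  opset_version=17)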

alxndrTL commented 5 months ago

Hello, the goal of the ONNX export is to do inference, right? If so, the pscan shouldn't be used at inference. Its goal is to speed up the forward pass when the whole sequence is known. This is not the case during inference, as inputs are generated one at a time. You should thus consider either the step function (which uses a cache) or the selective_scan_seq. I would go with the step function, as the cache greatly speeds up the computations and reduces memory usage. Hope this helps.
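To give an idea, token-by-token generation with the step function looks roughly like this (a sketch only: the per-layer cache layout (h, inputs) and its initialization are taken from memory of the repo's generate code, so double-check the shapes against mamba_lm.py):

import torch
from mamba import Mamba, MambaConfig

config = MambaConfig(d_model=16, n_layers=2)
model = Mamba(config).eval()

B = 1
# one cache per layer: h (the SSM hidden state) starts empty, inputs buffers
# the last d_conv-1 inputs for the depthwise convolution
caches = [(None, torch.zeros(B, config.d_inner, config.d_conv - 1))
          for _ in range(config.n_layers)]

x = torch.randn(B, config.d_model)  # one token embedding at a time
with torch.no_grad():
    for _ in range(16):
        x, caches = model.step(x, caches)  # x : (B, D)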

llmexperiment commented 5 months ago

Hi @alxndrTL ,

That is right! Inference consists of prompt processing (which goes through the forward pass) plus the step function. Currently, the pscan is implemented in-place, which prevents generating a proper ONNX file. Any reason to implement it in-place?

alxndrTL commented 5 months ago

Ok, I understand better now why you want to use the pscan. But have you first tried the ONNX export with the step function? That seems the most important.

For a pscan without in-place modifications of the inputs, have you simply tried this?

import math

def pscan(A, X):
    # A : (B, D, L, N)
    # X : (B, D, L, N)
    # returns the scan result without ever writing into A or X
    # (assumes L is a power of 2)

    B, D, L, _ = A.size()
    num_steps = int(math.log2(L))

    # clone the inputs once and do all the work on the clones,
    # so the input tensors are never modified
    A_ = A.clone()
    X_ = X.clone()

    Aa = A_
    Xa = X_

    # up sweep (last 2 steps unfolded)
    for _ in range(num_steps-2):
        T = Xa.size(2)
        Aa = Aa.view(B, D, T//2, 2, -1)
        Xa = Xa.view(B, D, T//2, 2, -1)

        Xa[:, :, :, 1] = Xa[:, :, :, 1] + Aa[:, :, :, 1] * Xa[:, :, :, 0]
        Aa[:, :, :, 1] = Aa[:, :, :, 0] * Aa[:, :, :, 1]

        Aa = Aa[:, :, :, 1]
        Xa = Xa[:, :, :, 1]

    # we are left with 4, 2 or 1 nodes after the loop
    if Xa.size(2) == 4:
        Xa[:, :, 1] = Xa[:, :, 1] + Aa[:, :, 1] * Xa[:, :, 0]
        Aa[:, :, 1] = Aa[:, :, 0] * Aa[:, :, 1]

        Xa[:, :, 3] = Xa[:, :, 3] + Aa[:, :, 3] * (Xa[:, :, 2] + Aa[:, :, 2] * Xa[:, :, 1])
    elif Xa.size(2) == 2:
        Xa[:, :, 1] = Xa[:, :, 1] + Aa[:, :, 1] * Xa[:, :, 0]
        return X_
    else:
        return X_

    # down sweep (first 2 steps unfolded)
    # note: take slice views of the clones rather than fresh clones of A and X,
    # so the partial results written during the up sweep are kept
    Aa = A_[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
    Xa = X_[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
    Xa[:, :, 2] = Xa[:, :, 2] + Aa[:, :, 2] * Xa[:, :, 1]
    Aa[:, :, 2] = Aa[:, :, 1] * Aa[:, :, 2]

    for k in range(num_steps-3, -1, -1):
        Aa = A_[:, :, 2**k-1:L:2**k]
        Xa = X_[:, :, 2**k-1:L:2**k]

        T = Xa.size(2)
        Aa = Aa.view(B, D, T//2, 2, -1)
        Xa = Xa.view(B, D, T//2, 2, -1)

        Xa[:, :, 1:, 0] = Xa[:, :, 1:, 0] + Aa[:, :, 1:, 0] * Xa[:, :, :-1, 1]
        Aa[:, :, 1:, 0] = Aa[:, :, :-1, 1] * Aa[:, :, 1:, 0]

    return X_

Just clone the tensors once before the computations and do all the work on the clones (and on views of them): that way you avoid any in-place operations on the input tensors X and A. I don't have time to test it, so please check that it gives the correct result. Best
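If you do try it, a quick sanity check against a naive sequential recurrence (the pscan computes H[t] = A[t] * H[t-1] + X[t]) would be something like:

import torch

def sequential_scan(A, X):
    # reference recurrence, computed step by step
    H = torch.zeros_like(X)
    H[:, :, 0] = X[:, :, 0]
    for t in range(1, X.size(2)):
        H[:, :, t] = A[:, :, t] * H[:, :, t-1] + X[:, :, t]
    return H

B, D, L, N = 2, 3, 64, 8  # L must be a power of 2 for this pscan
A = torch.randn(B, D, L, N)
X = torch.randn(B, D, L, N)
A_saved, X_saved = A.clone(), X.clone()

out = pscan(A, X)
assert torch.allclose(out, sequential_scan(A, X), atol=1e-4)
assert torch.equal(A, A_saved) and torch.equal(X, X_saved)  # inputs untouched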