Closed: llmexperiment closed this 5 months ago
Hello, the goal of the ONNX export is to do inference, right? If so, `pscan` shouldn't be used at inference time: its purpose is to speed up the forward pass when the whole sequence is known in advance, which is not the case during inference, where inputs are generated one at a time.

You should thus consider either the `step` function (which uses a cache) or `selective_scan_seq`. I would recommend the `step` function, as the cache greatly speeds up the computation and reduces memory usage.

Hope this helps.
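For context, both the parallel scan and the sequential path compute the same first-order linear recurrence, h_t = A_t * h_{t-1} + X_t, which a step-by-step decoder evaluates one t at a time. A minimal NumPy sketch of that recurrence (the function name is illustrative, not the actual mamba.py API):

```python
import numpy as np

def selective_scan_ref(A, X):
    """Sequential reference for the recurrence h_t = A_t * h_{t-1} + X_t.

    A, X: arrays of shape (L, N). Returns H of shape (L, N) holding every h_t.
    """
    L, N = X.shape
    H = np.empty_like(X)
    h = np.zeros(N)
    for t in range(L):
        h = A[t] * h + X[t]  # out-of-place: a fresh h each step
        H[t] = h
    return H

A = np.full((4, 1), 0.5)
X = np.ones((4, 1))
# h_t = 1, 1.5, 1.75, 1.875
H = selective_scan_ref(A, X)
```

The parallel scan produces exactly the same H, just in O(log L) combining steps instead of L sequential ones.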
Hi @alxndrTL ,

That is right! Inference consists of prompt processing (which is part of the forward pass) plus the `step` function. Currently, `pscan` is implemented with in-place operations, which prevents generating a proper ONNX file. Is there any reason it is implemented in-place?
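For illustration, the usual way to eliminate a slice assignment for export is to rebuild the tensor from pieces instead of writing into it (shown here in NumPy for brevity; with PyTorch the same pattern uses `torch.cat`). The function names are made up for this example:

```python
import numpy as np

def prefix_add_in_place(x):
    # the pattern exporters struggle with: writing into a slice of a tensor
    y = x.copy()
    y[1:] = y[1:] + y[:-1]
    return y

def prefix_add_out_of_place(x):
    # same result, but the output is assembled from pieces;
    # no element of any existing array is ever overwritten
    return np.concatenate([x[:1], x[1:] + x[:-1]])

x = np.arange(5.0)
assert np.array_equal(prefix_add_in_place(x), prefix_add_out_of_place(x))
```

Applying this rewrite to every slice assignment in `pscan` is more verbose, but it gives the tracer a purely functional graph.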
Ok, I understand better now why you want to use `pscan`. But have you first tried the ONNX export with the `step` function? That seems the most important point.

For a `pscan` that does not modify its inputs in place, have you simply tried this?
```python
import math

def pscan(A, X):
    # Clone the inputs once: all writes below go into the clones,
    # so the caller's A and X are never modified.
    A = A.clone()
    X = X.clone()

    B, D, L, _ = A.size()
    num_steps = int(math.log2(L))

    Aa = A
    Xa = X

    # up sweep (last 2 steps unfolded)
    for _ in range(num_steps - 2):
        T = Xa.size(2)
        Aa = Aa.view(B, D, T // 2, 2, -1)
        Xa = Xa.view(B, D, T // 2, 2, -1)
        Xa[:, :, :, 1] = Xa[:, :, :, 1] + Aa[:, :, :, 1] * Xa[:, :, :, 0]
        Aa[:, :, :, 1] = Aa[:, :, :, 0] * Aa[:, :, :, 1]
        Aa = Aa[:, :, :, 1]
        Xa = Xa[:, :, :, 1]

    # handling the remaining nodes based on the size after the loop
    if Xa.size(2) == 4:
        Xa[:, :, 1] = Xa[:, :, 1] + Aa[:, :, 1] * Xa[:, :, 0]
        Aa[:, :, 1] = Aa[:, :, 0] * Aa[:, :, 1]
        Xa[:, :, 3] = Xa[:, :, 3] + Aa[:, :, 3] * (Xa[:, :, 2] + Aa[:, :, 2] * Xa[:, :, 1])
    elif Xa.size(2) == 2:
        Xa[:, :, 1] = Xa[:, :, 1] + Aa[:, :, 1] * Xa[:, :, 0]
        return X
    else:
        return X

    # down sweep (first 2 steps unfolded)
    # slice the clones, which already hold the up-sweep results;
    # re-cloning A and X here would discard those results
    Aa = A[:, :, 2**(num_steps - 2) - 1:L:2**(num_steps - 2)]
    Xa = X[:, :, 2**(num_steps - 2) - 1:L:2**(num_steps - 2)]
    Xa[:, :, 2] = Xa[:, :, 2] + Aa[:, :, 2] * Xa[:, :, 1]
    Aa[:, :, 2] = Aa[:, :, 1] * Aa[:, :, 2]

    for k in range(num_steps - 3, -1, -1):
        Aa = A[:, :, 2**k - 1:L:2**k]
        Xa = X[:, :, 2**k - 1:L:2**k]
        T = Xa.size(2)
        Aa = Aa.view(B, D, T // 2, 2, -1)
        Xa = Xa.view(B, D, T // 2, 2, -1)
        Xa[:, :, 1:, 0] = Xa[:, :, 1:, 0] + Aa[:, :, 1:, 0] * Xa[:, :, :-1, 1]
        Aa[:, :, 1:, 0] = Aa[:, :, :-1, 1] * Aa[:, :, 1:, 0]

    # the full scanned result lives in the cloned X, not in the last Xa view
    return X
```
Just clone the tensors before the computations and do all the computations on the clones, so you avoid any in-place operations on the input tensors X and A. I don't have time to test it, so do check that it gives you the correct result. Best
Hi @alxndrTL ,

I am trying to generate an ONNX file for the forward pass here: https://github.com/alxndrTL/mamba.py/blob/main/mamba.py#L69

The issue is that the ONNX export cannot handle in-place computation, and will skip it or produce incorrect output when one is present. I am wondering how to modify `pscan` so that all the computations happen out-of-place?