Open ToaTao opened 1 month ago
Hello, I'm not very familiar with ONNX, but it seems that it doesn't like PyTorch models that are written in Triton (see https://github.com/state-spaces/mamba/issues/33).
And as of now, the mamba2 implementation in this repo is exactly the same as the one in mamba_ssm
(contrary to mamba1, which has been completely torchified). If you look into the closed issues, you will see people have successfully converted mamba1 models to ONNX.
But for mamba2, I don't know how to provide a solution for your problem. One of the next updates for mamba.py
is to actually torchify mamba2, so at that point I think the ONNX conversion will work.
Hi, I would like to ask how to convert a mamba2 model to ONNX for inference. When I try to convert, I encounter an error at `cols = tl.arange(0, BLOCK_N)`, where `BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))`, in the file layer_norm.py. Greatly looking forward to your reply, thanks.