alxndrTL / mamba.py

A simple and efficient Mamba implementation in pure PyTorch and MLX.
MIT License

Question about converting mamba2 to onnx #56

Open ToaTao opened 1 month ago

ToaTao commented 1 month ago

Hi, I would like to ask how to convert a mamba2 model to ONNX for inference. When I try to convert, I encounter an error at `cols = tl.arange(0, BLOCK_N)`, where `BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))`, in the file layer_norm.py. Greatly looking forward to your reply, thanks.

alxndrTL commented 1 month ago

Hello, I'm not very familiar with ONNX, but it seems that it doesn't like PyTorch models whose layers are written in Triton (see https://github.com/state-spaces/mamba/issues/33). As of now, the mamba2 implementation in this repo is exactly the same as the one in mamba_ssm (contrary to mamba1, which has been completely torchified). If you look into the closed issues, you will see people have successfully converted mamba1 models to ONNX. But for mamba2, I don't know how to provide a solution for your problem. One of the next updates for mamba.py is to actually torchify mamba2, and at that point I think the ONNX conversion will work.
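To illustrate what "torchifying" means here: ONNX export traces plain tensor operations, but it cannot see inside a compiled Triton kernel, so the fused mamba2 kernel breaks the export. A torchified version unrolls the same selective-scan recurrence with ordinary ops. The sketch below is NOT the repo's actual Mamba-2 code; it is a toy NumPy illustration of the core SSM recurrence (shapes and names are invented for the example) that an export-friendly implementation would express step by step:

```python
import numpy as np

def sequential_ssm_scan(A, B, C, x):
    """Toy sequential state-space scan (illustrative, not the real mamba2 kernel).

    Recurrence: h_t = A * h_{t-1} + B_t * x_t,  y_t = C_t . h_t
    Shapes (hypothetical): A (d,), B (L, d), C (L, d), x (L,).
    """
    L, d = B.shape
    h = np.zeros(d)
    ys = []
    for t in range(L):
        h = A * h + B[t] * x[t]  # elementwise recurrence, one step per token
        ys.append(C[t] @ h)      # readout for this timestep
    return np.array(ys)
```

Because every operation above is a plain elementwise multiply, add, or dot product, a tracer can record the whole computation graph; the fused Triton kernel computes the same thing but appears to the tracer as an opaque call, which is why the export fails today and why a torchified mamba2 should fix it.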