state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

Convert the Mamba model to TorchScript #287

Open · tzmzm opened this issue 6 months ago

tzmzm commented 6 months ago

@tridao Hello!

I am trying to convert the Mamba model to TorchScript (saved as "XXX.pt") with torch.jit.trace(), but I ran into the error below. It seems that torch.jit.trace() cannot convert Mamba. What should I do to export it?

Traceback (most recent call last):
  File "/home/a/anaconda3/envs/mamba/lib/python3.9/site-packages/torch/jit/_script.py", line 714, in save
    return self._c.save(str(f), **kwargs)
RuntimeError:
Could not export Python function call 'MambaInnerFn'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to constants:
/home/a/anaconda3/envs/mamba/lib/python3.9/site-packages/mamba_ssm/ops/selective_scan_interface.py(317): mamba_inner_fn
/home/a/anaconda3/envs/mamba/lib/python3.9/site-packages/mamba_ssm/modules/mamba_simple.py(146): forward
/home/a/anaconda3/envs/mamba/lib/python3.9/site-packages/torch/nn/modules/module.py(1178): _slow_forward
/home/a/anaconda3/envs/mamba/lib/python3.9/site-packages/torch/nn/modules/module.py(1190): _call_impl

/home/a/anaconda3/envs/mamba/lib/python3.9/site-packages/torch/nn/modules/module.py(1178): _slow_forward
/home/a/anaconda3/envs/mamba/lib/python3.9/site-packages/torch/nn/modules/module.py(1190): _call_impl
/home/a/anaconda3/envs/mamba/lib/python3.9/site-packages/torch/jit/_trace.py(976): trace_module
/home/a/anaconda3/envs/mamba/lib/python3.9/site-packages/torch/jit/_trace.py(759): trace
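The error message suggests that MambaInnerFn is a Python-level torch.autograd.Function wrapping fused kernels: torch.jit.trace records such a call as an opaque Python op, and saving then fails because TorchScript cannot serialize Python code. The sketch below is a hypothetical illustration, not Mamba's code. ToyFusedFn stands in for the fused Function and reproduces the save failure. SelectiveScanRef is an assumed, simplified selective scan written only with tensor ops (sequential recurrence, no z gating, no delta bias or softplus), which torch.jit.script can compile and save. Its argument names and shapes are assumptions chosen for this example.

```python
import io
from typing import List

import torch
import torch.nn as nn


class ToyFusedFn(torch.autograd.Function):
    """Hypothetical stand-in for a Python autograd.Function such as MambaInnerFn."""

    @staticmethod
    def forward(ctx, x):
        return x * 2.0

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out * 2.0


class UsesAutogradFunction(nn.Module):
    def forward(self, x):
        return ToyFusedFn.apply(x)


def try_trace_and_save(module: nn.Module, example: torch.Tensor) -> bool:
    """Return True if the module can be traced and serialized to TorchScript."""
    try:
        # check_trace=False skips the re-trace consistency check; the failure
        # of interest here happens at save time.
        traced = torch.jit.trace(module, example, check_trace=False)
        torch.jit.save(traced, io.BytesIO())
        return True
    except RuntimeError:
        return False


class SelectiveScanRef(nn.Module):
    """Simplified selective scan using only tensor ops (assumed shapes).

    u, delta: (batch, dim, length); A: (dim, state);
    B, C: (batch, state, length); D: (dim,).
    Recurrence: x_t = exp(delta_t * A) * x_{t-1} + delta_t * B_t * u_t,
                y_t = sum_n C_t[n] * x_t[:, n] + D * u_t.
    """

    def forward(self, u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor,
                B: torch.Tensor, C: torch.Tensor, D: torch.Tensor) -> torch.Tensor:
        batch = u.size(0)
        dim = u.size(1)
        length = u.size(2)
        state = A.size(1)
        # Discretized transition and input terms, both (batch, dim, length, state).
        delta_A = torch.exp(delta.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(2))
        delta_B_u = (delta * u).unsqueeze(-1) * B.transpose(1, 2).unsqueeze(1)
        x = torch.zeros(batch, dim, state, dtype=u.dtype, device=u.device)
        ys: List[torch.Tensor] = []
        for t in range(length):
            x = delta_A[:, :, t] * x + delta_B_u[:, :, t]
            # Contract the state dimension with C_t to get a (batch, dim) output.
            ys.append((x * C[:, :, t].unsqueeze(1)).sum(-1))
        y = torch.stack(ys, dim=2)
        # Skip connection D * u, broadcast over batch and length.
        return y + u * D.unsqueeze(-1)


if __name__ == "__main__":
    print(try_trace_and_save(UsesAutogradFunction(), torch.ones(3)))
    scripted = torch.jit.script(SelectiveScanRef())
    scripted.save("selective_scan_ref.pt")  # hypothetical output path
```

Using torch.jit.script instead of torch.jit.trace keeps the Python loop as a real loop over the sequence length, whereas tracing would unroll it for the example input's length. A full export would also need tensor-only paths for the rest of the block (for example the causal convolution), and a sequential loop like this is much slower than the fused CUDA kernels. mamba_ssm appears to ship its own pure-PyTorch reference (selective_scan_ref in selective_scan_interface.py); check whether your installed version has it before writing your own.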

tridao commented 6 months ago

Sorry, I have no experience with TorchScript.

FengJungle commented 1 week ago

@tzmzm How is your work on this going?