@tridao Hello!
I am trying to convert a Mamba model to TorchScript and save it as "XXX.pt" using torch.jit.trace(), but I get the error below. It seems that torch.jit.trace() cannot convert Mamba. What should I do to export it?
Traceback (most recent call last):
File "/home/a/anaconda3/envs/mamba/lib/python3.9/site-packages/torch/jit/_script.py", line 714, in save
return self._c.save(str(f), **kwargs)
RuntimeError:
Could not export Python function call 'MambaInnerFn'. Remove calls to Python functions before export. Did you forget to add @script or @script_method annotation? If this is a nn.ModuleList, add it to constants:
/home/a/anaconda3/envs/mamba/lib/python3.9/site-packages/mamba_ssm/ops/selective_scan_interface.py(317): mamba_inner_fn
/home/a/anaconda3/envs/mamba/lib/python3.9/site-packages/mamba_ssm/modules/mamba_simple.py(146): forward
/home/a/anaconda3/envs/mamba/lib/python3.9/site-packages/torch/nn/modules/module.py(1178): _slow_forward
/home/a/anaconda3/envs/mamba/lib/python3.9/site-packages/torch/nn/modules/module.py(1190): _call_impl
/home/a/anaconda3/envs/mamba/lib/python3.9/site-packages/torch/nn/modules/module.py(1178): _slow_forward
/home/a/anaconda3/envs/mamba/lib/python3.9/site-packages/torch/nn/modules/module.py(1190): _call_impl
/home/a/anaconda3/envs/mamba/lib/python3.9/site-packages/torch/jit/_trace.py(976): trace_module
/home/a/anaconda3/envs/mamba/lib/python3.9/site-packages/torch/jit/_trace.py(759): trace
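The traceback shows that tracing itself finished and the failure happens in save: MambaInnerFn (called from mamba_inner_fn in selective_scan_interface.py) is a custom torch.autograd.Function wrapping fused kernels. The tracer records such a call as an opaque Python operation, which TorchScript cannot serialize. The sketch below reproduces the same behaviour without Mamba and assumes only that PyTorch is installed. ScaleFn, WithCustomFn, PlainOps and can_save are hypothetical names chosen for illustration and are not part of mamba_ssm.

```python
import io

import torch
import torch.nn as nn


class ScaleFn(torch.autograd.Function):
    """Hypothetical toy custom Function standing in for a fused op like MambaInnerFn."""

    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2


class WithCustomFn(nn.Module):
    """Calls the custom Function, so the trace contains an unserializable Python op."""

    def forward(self, x):
        return ScaleFn.apply(x)


class PlainOps(nn.Module):
    """Same math written with ordinary tensor ops, which TorchScript can serialize."""

    def forward(self, x):
        return x * 2


def can_save(module, example):
    """Trace `module` with `example`, then try to save it to an in-memory buffer."""
    traced = torch.jit.trace(module, example)
    buffer = io.BytesIO()
    try:
        torch.jit.save(traced, buffer)
        return True
    except RuntimeError:  # e.g. "Could not export Python function call ..."
        return False


if __name__ == "__main__":
    x = torch.randn(4)
    print("plain ops saved:", can_save(PlainOps(), x))
    print("custom Function saved:", can_save(WithCustomFn(), x))
```

If this holds, the same reasoning suggests that a traced Mamba model can only be saved after every custom Function in its forward path is replaced with plain tensor operations, for example a pure-PyTorch selective scan such as the selective_scan_ref reference function, if your installed mamba_ssm version provides it. That is an untested assumption, and such a replacement is likely to be much slower than the fused kernels.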