state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

support for arm64 platform #209

Closed IamShubhamGupto closed 7 months ago

IamShubhamGupto commented 7 months ago

Hello!

I want to congratulate the authors on their amazing work. Going forward, I would like to test this on the Nvidia Jetson platform, which uses the aarch64 architecture, where we could really see a significant improvement in performance over transformers.

Since the repository currently does not support arm64/aarch64, is there a planned timeline to add this support? I am interested in the work and willing to contribute a PR as well but I will need help.

Thanks again!

tridao commented 7 months ago

I'm not familiar with arm64. Hopefully someone can contribute on this front.

IamShubhamGupto commented 7 months ago

Hello Prof. @tridao

Thank you for your response! arm64 is a major CPU architecture used by mobile robots and servers. I will try to enable cross-compilation for this repository in the next couple of days. Anyone else interested in working on this, please feel free to reach out below.

A little bit about the Nvidia Jetson platform - https://www.nvidia.com/en-us/autonomous-machines/embedded-systems/

IamShubhamGupto commented 7 months ago

Update:

I was able to build mamba-ssm from source and it partially worked:

Steps to reproduce on Nvidia Jetson AGX Orin

But importing it in python3 fails:

Python 3.10.13 (main, Sep 11 2023, 13:18:45) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import mamba_ssm
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ultraviolet/miniconda3/envs/vim/lib/python3.10/site-packages/mamba_ssm/__init__.py", line 3, in <module>
    from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
  File "/home/ultraviolet/miniconda3/envs/vim/lib/python3.10/site-packages/mamba_ssm/ops/selective_scan_interface.py", line 11, in <module>
    import selective_scan_cuda
ModuleNotFoundError: No module named 'selective_scan_cuda'

If there's a fix for this, please let me know.

I'm trying to build and install from scratch next; hopefully that fixes the error. The commands I'm running:
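The traceback above shows that the pure-Python mamba_ssm package was found while the compiled selective_scan_cuda extension was not. A minimal sketch for checking which of the two modules Python can locate, without importing them, is given here; the helper name extension_available is hypothetical and not part of mamba_ssm.

```python
import importlib.util


def extension_available(name: str = "selective_scan_cuda") -> bool:
    """Return True if a top-level module with this name can be located.

    find_spec only searches for the module; it does not import it, so a
    broken CUDA extension will not crash this check.
    """
    return importlib.util.find_spec(name) is not None


if __name__ == "__main__":
    # Module names taken from the traceback above.
    for module in ("mamba_ssm", "selective_scan_cuda"):
        status = "found" if extension_available(module) else "missing"
        print(f"{module}: {status}")
```

If mamba_ssm is reported as found and selective_scan_cuda as missing, the Python sources were installed but the compiled extension was not, which matches the error above.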

python3 setup.py build
python3 setup.py install

Will post an update here soon
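Running setup.py directly is deprecated by setuptools, so a pip-based install from the source tree is a possible alternative. The sketch below only builds the command and environment; the helper names are hypothetical, the --no-build-isolation flag is an assumption (it lets the build see the torch already installed in the environment), and MAX_JOBS is the environment variable that limits the number of parallel ninja workers when compiling PyTorch extensions.

```python
import os
import subprocess
import sys
from typing import Dict, List, Optional


def pip_install_command(source_dir: str = ".", build_isolation: bool = False) -> List[str]:
    """Build a `python -m pip install` command for a local source tree."""
    cmd = [sys.executable, "-m", "pip", "install", "-v"]
    if not build_isolation:
        # Assumption: the build needs the torch from the current environment.
        cmd.append("--no-build-isolation")
    cmd.append(source_dir)
    return cmd


def build_env(max_jobs: Optional[int] = None) -> Dict[str, str]:
    """Copy the current environment, optionally capping ninja workers."""
    env = dict(os.environ)
    if max_jobs is not None:
        env["MAX_JOBS"] = str(max_jobs)
    return env


def run_install(source_dir: str = ".", max_jobs: Optional[int] = None) -> None:
    """Run the install; raises CalledProcessError if pip fails."""
    subprocess.run(pip_install_command(source_dir), env=build_env(max_jobs), check=True)


if __name__ == "__main__":
    # Print the command instead of running it, so nothing is installed by accident.
    print(" ".join(pip_install_command(".")))
```

Calling run_install(".", max_jobs=4) from the repository root would start the build; a lower MAX_JOBS value may help on memory-constrained boards.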

IamShubhamGupto commented 7 months ago

Update: running python3 setup.py install worked! Here are the build logs:

(vim) ultraviolet@ubuntu:~/Documents/mamba$ python setup.py install
torch.__version__  = 2.1.0
running install
/home/ultraviolet/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/_distutils/cmd.py:66: SetuptoolsDeprecationWarning: setup.py install is deprecated.
!!

    ********************************************************************************
    Please avoid running ``setup.py`` directly.
    Instead, use pypa/build, pypa/installer or other
    standards-based tools.

    See <https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html> for details.
    ********************************************************************************
!!
  self.initialize_options()
/home/ultraviolet/miniconda3/envs/vim/lib/python3.10/site-packages/setuptools/_distutils/cmd.py:66: EasyInstallDeprecationWarning: easy_install command is deprecated.
!!

    ********************************************************************************
    Please avoid running ``setup.py`` and ``easy_install``.
    Instead, use pypa/build, pypa/installer or other
    standards-based tools.

    See <https://github.com/pypa/setuptools/issues/917> for details.
    ********************************************************************************
!!
  self.initialize_options()
running bdist_egg
running egg_info
writing mamba_ssm.egg-info/PKG-INFO
writing dependency_links to mamba_ssm.egg-info/dependency_links.txt
writing requirements to mamba_ssm.egg-info/requires.txt
writing top-level names to mamba_ssm.egg-info/top_level.txt
reading manifest file 'mamba_ssm.egg-info/SOURCES.txt'
adding license file 'LICENSE'
adding license file 'AUTHORS'
writing manifest file 'mamba_ssm.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-aarch64/egg
running install_lib
running build_py
running build_ext
/home/ultraviolet/miniconda3/envs/vim/lib/python3.10/site-packages/torch/utils/cpp_extension.py:424: UserWarning: There are no g++ version bounds defined for CUDA version 12.2
  warnings.warn(f'There are no {compiler_name} version bounds defined for CUDA version {cuda_str_version}')
building 'selective_scan_cuda' extension
Emitting ninja build file /home/ultraviolet/Documents/mamba/build/temp.linux-aarch64-cpython-310/build.ninja...
Compiling objects...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
g++ -pthread -B /home/ultraviolet/miniconda3/envs/vim/compiler_compat -shared -Wl,-rpath,/home/ultraviolet/miniconda3/envs/vim/lib -Wl,-rpath-link,/home/ultraviolet/miniconda3/envs/vim/lib -L/home/ultraviolet/miniconda3/envs/vim/lib -Wl,-rpath,/home/ultraviolet/miniconda3/envs/vim/lib -Wl,-rpath-link,/home/ultraviolet/miniconda3/envs/vim/lib -L/home/ultraviolet/miniconda3/envs/vim/lib /home/ultraviolet/Documents/mamba/build/temp.linux-aarch64-cpython-310/csrc/selective_scan/selective_scan.o /home/ultraviolet/Documents/mamba/build/temp.linux-aarch64-cpython-310/csrc/selective_scan/selective_scan_bwd_bf16_complex.o /home/ultraviolet/Documents/mamba/build/temp.linux-aarch64-cpython-310/csrc/selective_scan/selective_scan_bwd_bf16_real.o /home/ultraviolet/Documents/mamba/build/temp.linux-aarch64-cpython-310/csrc/selective_scan/selective_scan_bwd_fp16_complex.o /home/ultraviolet/Documents/mamba/build/temp.linux-aarch64-cpython-310/csrc/selective_scan/selective_scan_bwd_fp16_real.o /home/ultraviolet/Documents/mamba/build/temp.linux-aarch64-cpython-310/csrc/selective_scan/selective_scan_bwd_fp32_complex.o /home/ultraviolet/Documents/mamba/build/temp.linux-aarch64-cpython-310/csrc/selective_scan/selective_scan_bwd_fp32_real.o /home/ultraviolet/Documents/mamba/build/temp.linux-aarch64-cpython-310/csrc/selective_scan/selective_scan_fwd_bf16.o /home/ultraviolet/Documents/mamba/build/temp.linux-aarch64-cpython-310/csrc/selective_scan/selective_scan_fwd_fp16.o /home/ultraviolet/Documents/mamba/build/temp.linux-aarch64-cpython-310/csrc/selective_scan/selective_scan_fwd_fp32.o -L/home/ultraviolet/miniconda3/envs/vim/lib/python3.10/site-packages/torch/lib -L/usr/local/cuda-12.2/lib64 -lc10 -ltorch -ltorch_cpu -ltorch_python -lcudart -lc10_cuda -ltorch_cuda -o build/lib.linux-aarch64-cpython-310/selective_scan_cuda.cpython-310-aarch64-linux-gnu.so
creating build/bdist.linux-aarch64/egg
copying build/lib.linux-aarch64-cpython-310/selective_scan_cuda.cpython-310-aarch64-linux-gnu.so -> build/bdist.linux-aarch64/egg
creating build/bdist.linux-aarch64/egg/mamba_ssm
creating build/bdist.linux-aarch64/egg/mamba_ssm/modules
copying build/lib.linux-aarch64-cpython-310/mamba_ssm/modules/mamba_simple.py -> build/bdist.linux-aarch64/egg/mamba_ssm/modules
copying build/lib.linux-aarch64-cpython-310/mamba_ssm/modules/__init__.py -> build/bdist.linux-aarch64/egg/mamba_ssm/modules
creating build/bdist.linux-aarch64/egg/mamba_ssm/utils
copying build/lib.linux-aarch64-cpython-310/mamba_ssm/utils/generation.py -> build/bdist.linux-aarch64/egg/mamba_ssm/utils
copying build/lib.linux-aarch64-cpython-310/mamba_ssm/utils/__init__.py -> build/bdist.linux-aarch64/egg/mamba_ssm/utils
copying build/lib.linux-aarch64-cpython-310/mamba_ssm/utils/hf.py -> build/bdist.linux-aarch64/egg/mamba_ssm/utils
copying build/lib.linux-aarch64-cpython-310/mamba_ssm/__init__.py -> build/bdist.linux-aarch64/egg/mamba_ssm
creating build/bdist.linux-aarch64/egg/mamba_ssm/models
copying build/lib.linux-aarch64-cpython-310/mamba_ssm/models/__init__.py -> build/bdist.linux-aarch64/egg/mamba_ssm/models
copying build/lib.linux-aarch64-cpython-310/mamba_ssm/models/mixer_seq_simple.py -> build/bdist.linux-aarch64/egg/mamba_ssm/models
copying build/lib.linux-aarch64-cpython-310/mamba_ssm/models/config_mamba.py -> build/bdist.linux-aarch64/egg/mamba_ssm/models
creating build/bdist.linux-aarch64/egg/mamba_ssm/ops
copying build/lib.linux-aarch64-cpython-310/mamba_ssm/ops/selective_scan_interface.py -> build/bdist.linux-aarch64/egg/mamba_ssm/ops
copying build/lib.linux-aarch64-cpython-310/mamba_ssm/ops/__init__.py -> build/bdist.linux-aarch64/egg/mamba_ssm/ops
creating build/bdist.linux-aarch64/egg/mamba_ssm/ops/triton
copying build/lib.linux-aarch64-cpython-310/mamba_ssm/ops/triton/layernorm.py -> build/bdist.linux-aarch64/egg/mamba_ssm/ops/triton
copying build/lib.linux-aarch64-cpython-310/mamba_ssm/ops/triton/__init__.py -> build/bdist.linux-aarch64/egg/mamba_ssm/ops/triton
copying build/lib.linux-aarch64-cpython-310/mamba_ssm/ops/triton/selective_state_update.py -> build/bdist.linux-aarch64/egg/mamba_ssm/ops/triton
byte-compiling build/bdist.linux-aarch64/egg/mamba_ssm/modules/mamba_simple.py to mamba_simple.cpython-310.pyc
byte-compiling build/bdist.linux-aarch64/egg/mamba_ssm/modules/__init__.py to __init__.cpython-310.pyc
byte-compiling build/bdist.linux-aarch64/egg/mamba_ssm/utils/generation.py to generation.cpython-310.pyc
byte-compiling build/bdist.linux-aarch64/egg/mamba_ssm/utils/__init__.py to __init__.cpython-310.pyc
byte-compiling build/bdist.linux-aarch64/egg/mamba_ssm/utils/hf.py to hf.cpython-310.pyc
byte-compiling build/bdist.linux-aarch64/egg/mamba_ssm/__init__.py to __init__.cpython-310.pyc
byte-compiling build/bdist.linux-aarch64/egg/mamba_ssm/models/__init__.py to __init__.cpython-310.pyc
byte-compiling build/bdist.linux-aarch64/egg/mamba_ssm/models/mixer_seq_simple.py to mixer_seq_simple.cpython-310.pyc
byte-compiling build/bdist.linux-aarch64/egg/mamba_ssm/models/config_mamba.py to config_mamba.cpython-310.pyc
byte-compiling build/bdist.linux-aarch64/egg/mamba_ssm/ops/selective_scan_interface.py to selective_scan_interface.cpython-310.pyc
byte-compiling build/bdist.linux-aarch64/egg/mamba_ssm/ops/__init__.py to __init__.cpython-310.pyc
byte-compiling build/bdist.linux-aarch64/egg/mamba_ssm/ops/triton/layernorm.py to layernorm.cpython-310.pyc
byte-compiling build/bdist.linux-aarch64/egg/mamba_ssm/ops/triton/__init__.py to __init__.cpython-310.pyc
byte-compiling build/bdist.linux-aarch64/egg/mamba_ssm/ops/triton/selective_state_update.py to selective_state_update.cpython-310.pyc
creating stub loader for selective_scan_cuda.cpython-310-aarch64-linux-gnu.so
byte-compiling build/bdist.linux-aarch64/egg/selective_scan_cuda.py to selective_scan_cuda.cpython-310.pyc
creating build/bdist.linux-aarch64/egg/EGG-INFO
copying mamba_ssm.egg-info/PKG-INFO -> build/bdist.linux-aarch64/egg/EGG-INFO
copying mamba_ssm.egg-info/SOURCES.txt -> build/bdist.linux-aarch64/egg/EGG-INFO
copying mamba_ssm.egg-info/dependency_links.txt -> build/bdist.linux-aarch64/egg/EGG-INFO
copying mamba_ssm.egg-info/requires.txt -> build/bdist.linux-aarch64/egg/EGG-INFO
copying mamba_ssm.egg-info/top_level.txt -> build/bdist.linux-aarch64/egg/EGG-INFO
writing build/bdist.linux-aarch64/egg/EGG-INFO/native_libs.txt
zip_safe flag not set; analyzing archive contents...
__pycache__.selective_scan_cuda.cpython-310: module references __file__
creating 'dist/mamba_ssm-1.1.4-py3.10-linux-aarch64.egg' and adding 'build/bdist.linux-aarch64/egg' to it
removing 'build/bdist.linux-aarch64/egg' (and everything under it)
Processing mamba_ssm-1.1.4-py3.10-linux-aarch64.egg
removing '/home/ultraviolet/miniconda3/envs/vim/lib/python3.10/site-packages/mamba_ssm-1.1.4-py3.10-linux-aarch64.egg' (and everything under it)
creating /home/ultraviolet/miniconda3/envs/vim/lib/python3.10/site-packages/mamba_ssm-1.1.4-py3.10-linux-aarch64.egg
Extracting mamba_ssm-1.1.4-py3.10-linux-aarch64.egg to /home/ultraviolet/miniconda3/envs/vim/lib/python3.10/site-packages
Adding mamba-ssm 1.1.4 to easy-install.pth file
Installed /home/ultraviolet/miniconda3/envs/vim/lib/python3.10/site-packages/mamba_ssm-1.1.4-py3.10-linux-aarch64.egg
Processing dependencies for mamba-ssm==1.1.4
Searching for zope-interface>=5
Reading https://pypi.org/simple/zope-interface/
No local packages or working download links found for zope-interface>=5
error: Could not find suitable distribution for Requirement.parse('zope-interface>=5')

It complains about zope-interface at the end, but that package is already installed at the latest version, so I'm ignoring it. Importing mamba_ssm works now, and that is what I am interested in.
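To confirm that the dependency named in the final error is actually present in the environment, the installed version can be queried from package metadata. This is a minimal sketch; the helper installed_version is hypothetical, and trying both spellings of the distribution name is a precaution rather than a known requirement.

```python
from importlib import metadata
from typing import Iterable, Optional


def installed_version(names: Iterable[str]) -> Optional[str]:
    """Return the version of the first distribution found, or None."""
    for name in names:
        try:
            return metadata.version(name)
        except metadata.PackageNotFoundError:
            continue
    return None


if __name__ == "__main__":
    # The log requires zope-interface>=5; the package is published as zope.interface.
    version = installed_version(["zope.interface", "zope-interface"])
    print("zope.interface:", version or "not installed")
```

If this prints a version of 5 or higher, the requirement from the log is satisfied in the environment even though easy_install could not resolve it.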