state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

ROCm support #65

Open Wintoplay opened 6 months ago

Wintoplay commented 6 months ago

Please consider adding ROCm support for AMD GPUs.

j-dominguez9 commented 5 months ago

What library compatibility would be needed in order to have this work successfully on ROCm?

tridao commented 5 months ago

We do not have experience with ROCm, but of course we'd welcome community contributions on this.

j-dominguez9 commented 5 months ago

Hi @tridao, thanks for answering. I think there is enough demand in the ROCm community to make this work. I went through the files, but just to make sure I'm not missing anything, is it the case that the only dependencies are:

    python_requires=">=3.7",
    install_requires=[
        "torch",
        "packaging",
        "ninja",
        "einops",
        "triton",
        "transformers",
        "causal_conv1d>=1.1.0",
    ],

If that's the case, this should be relatively easy to port.

tridao commented 5 months ago

There's CUDA code in causal_conv1d, but that's optional; we can fall back to torch's conv1d. There's also CUDA code in this repo for the selective_scan operation (in csrc), and it may work with HIP.
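As a rough illustration of the fallback mentioned above, a causal depthwise conv1d can be expressed in plain PyTorch with `F.conv1d` plus left padding. This is a sketch of the idea, not the repo's actual fallback code; the function name and shapes are assumptions.

```python
import torch
import torch.nn.functional as F

def causal_conv1d_torch(x, weight, bias=None):
    """Hypothetical pure-PyTorch causal depthwise conv1d fallback.

    x:      (batch, dim, seqlen) input sequence
    weight: (dim, width) one depthwise filter per channel
    bias:   (dim,) or None
    """
    dim, width = weight.shape
    # Pad on both sides (F.conv1d pads symmetrically), so each output
    # position can only see current and past inputs after trimming.
    out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
    # Trim the extra right-side outputs introduced by the padding.
    return out[..., : x.shape[-1]]

x = torch.randn(2, 16, 32)
w = torch.randn(16, 4)
y = causal_conv1d_torch(x, w)
print(y.shape)  # torch.Size([2, 16, 32])
```

Because the output at step t depends only on inputs up to t, this matches the causal semantics the CUDA kernel provides, just without the fused-kernel speed.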

supersonictw commented 2 months ago

I found a simple version written in pure PyTorch that is compatible with ROCm.

https://github.com/alxndrTL/mamba.py/issues/22
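For reference, the core selective-scan recurrence that such pure-PyTorch ports implement can be sketched as below. The shapes and the diagonal-A parameterization are assumptions for illustration, not taken from that repo.

```python
import torch

def selective_scan_ref(x, delta, A, B, C):
    """Naive sequential selective scan (SSM recurrence) in pure PyTorch.

    x:     (batch, dim, seqlen)   input sequence
    delta: (batch, dim, seqlen)   per-step discretization step sizes
    A:     (dim, state)           diagonal state matrix (one row per channel)
    B:     (batch, state, seqlen) input projection
    C:     (batch, state, seqlen) output projection
    """
    b, d, L = x.shape
    n = A.shape[1]
    h = torch.zeros(b, d, n, dtype=x.dtype, device=x.device)
    ys = []
    for t in range(L):
        # Discretize A with zero-order hold: dA = exp(delta * A)
        dA = torch.exp(delta[:, :, t, None] * A)                      # (b, d, n)
        # Discretize the input term with a simple Euler step.
        dBx = delta[:, :, t, None] * B[:, None, :, t] * x[:, :, t, None]
        h = dA * h + dBx                                              # state update
        ys.append((h * C[:, None, :, t]).sum(-1))                     # (b, d)
    return torch.stack(ys, dim=-1)                                    # (b, d, L)
```

Since it uses only standard PyTorch ops, it runs anywhere PyTorch runs (CUDA, ROCm, CPU), at the cost of a Python-level loop over the sequence length.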

kliuae commented 2 months ago

We have a working version of Mamba on ROCm. We've been able to run generation on AMD's MI210, and the unit tests for the port are passing. It uses PyTorch's C++ extensions to selectively build kernel code depending on whether the system is running CUDA or ROCm, so the port can be built on both. https://github.com/EmbeddedLLM/mamba-rocm
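The backend check described above can be done at build time, since PyTorch exposes which GPU stack it was compiled against. The sketch below shows one way to branch on it; the source file names are purely illustrative, not from the linked repo.

```python
import torch

def is_rocm_build():
    # torch.version.hip is a version string (e.g. "5.6.0") under ROCm
    # builds of PyTorch, and None under CUDA or CPU-only builds.
    return getattr(torch.version, "hip", None) is not None

# Illustrative file names only; a real setup.py would list the actual kernels.
CUDA_SOURCES = ["csrc/selective_scan/selective_scan_cuda.cu"]
ROCM_SOURCES = ["csrc/selective_scan/selective_scan_hip.hip"]

sources = ROCM_SOURCES if is_rocm_build() else CUDA_SOURCES
print("building with:", sources)
```

In a real build script the selected list would be handed to `torch.utils.cpp_extension.CUDAExtension`, which itself compiles through hipcc when PyTorch is a ROCm build.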