Wintoplay opened this issue 6 months ago
What compatibility in libraries would be needed in order to have this work successfully in ROCm?
We do not have experience with ROCm, but of course we'd welcome community contributions on this.
Hi @tridao, thanks for answering. I think there is enough demand in the ROCm community to make this work. I went through the files, but just to make sure I'm not missing anything, is it the case that the only dependencies are:
```python
python_requires=">=3.7",
install_requires=[
    "torch",
    "packaging",
    "ninja",
    "einops",
    "triton",
    "transformers",
    "causal_conv1d>=1.1.0",
],
```
If that's the case, this should be relatively easy to port.
There's CUDA code in causal_conv1d, but that's optional; we can fall back to torch's conv1d. There's also CUDA code in this repo for the selective_scan operation (in `csrc`), and maybe it can work with HIP.
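To illustrate the fallback idea: a "causal" conv1d is just a 1D convolution whose output at time t depends only on inputs at times ≤ t, which torch's conv1d can reproduce by left-padding the sequence with `kernel_size - 1` zeros. Below is a minimal pure-Python sketch of those semantics (illustrative only; the names are not the causal_conv1d library's actual API):

```python
# Sketch of causal (left-padded) depthwise 1D convolution semantics.
# torch.nn.functional.conv1d with groups=channels and left padding of
# (kernel_size - 1) zeros computes the same thing on tensors.

def causal_conv1d_ref(x, weight):
    """Depthwise causal 1D convolution (cross-correlation form).

    x:      list of channels, each a list of T time steps
    weight: list of per-channel kernels, each of length K
    Output at time t depends only on inputs at times <= t.
    """
    out = []
    for channel, kernel in zip(x, weight):
        k = len(kernel)
        padded = [0.0] * (k - 1) + channel  # left-pad: no future leakage
        out.append([
            sum(kernel[j] * padded[t + j] for j in range(k))
            for t in range(len(channel))
        ])
    return out
```

For example, with a length-2 kernel of all ones, each output is the sum of the current and previous input, and the first output sees only the (zero-padded) past.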
I found the simple reference implementation written in pure PyTorch; it is compatible with ROCm.
We have a working version of mamba on ROCm. We've been able to run generation on AMD's MI210, and the unit tests are passing for the port. It uses PyTorch's cpp extensions to selectively build kernel code based on whether the system is running CUDA or ROCm, so the port builds on both systems: https://github.com/EmbeddedLLM/mamba-rocm
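The selection logic can be sketched as follows. This is an illustrative stand-in, not the mamba-rocm repo's actual build script: PyTorch exposes `torch.version.hip` (non-None on ROCm builds) and `torch.version.cuda` (non-None on CUDA builds), and a setup script can branch on those to pick which kernel sources and flags to compile. The helper below takes those values as plain arguments so the dispatch logic is visible on its own:

```python
# Hypothetical sketch of backend selection the way a setup.py might do it,
# given torch.version.hip and torch.version.cuda (either may be None).
# In a real build, torch.utils.cpp_extension.CUDAExtension also hipifies
# CUDA sources automatically when PyTorch itself was built for ROCm.

def select_backend(hip_version, cuda_version):
    """Return which GPU backend to build kernels for."""
    if hip_version is not None:
        return "hip"   # ROCm PyTorch: compile HIP/hipified kernels
    if cuda_version is not None:
        return "cuda"  # CUDA PyTorch: compile the CUDA kernels as-is
    return "cpu"       # neither: fall back to pure-PyTorch paths
```

In an actual `setup.py` one would call this as `select_backend(torch.version.hip, torch.version.cuda)` and choose extension sources accordingly.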
Please consider adding ROCm support for AMD GPUs.