Beinsezii opened 1 month ago
FPx quantization is backed by a custom CUDA kernel, so it is not available on ROCm.
https://github.com/pytorch/ao/tree/main/torchao/csrc/cuda/fp6_llm
It's strange that it runs with bfloat16, though; perhaps it's slow precisely because that path doesn't use the CUDA kernel. I don't know ROCm well enough, but it may not be too hard to port the kernel to ROCm.
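In the meantime, a minimal sketch of how one could guard the FPx path behind a CUDA-only check (this assumes torchao's quantize_ and fpx_weight_only helpers, which may differ by version; the function name here is just illustrative):

```python
import torch
from torchao.quantization import quantize_, fpx_weight_only  # names may vary by torchao version


def maybe_quantize_fpx(model, ebits=3, mbits=2):
    # torch.version.hip is a version string on ROCm builds and None on CUDA builds.
    is_rocm = torch.version.hip is not None
    if torch.cuda.is_available() and not is_rocm:
        # FPx weight-only quantization, e.g. ebits=3, mbits=2 for FP6 (e3m2).
        quantize_(model, fpx_weight_only(ebits, mbits))
    else:
        print("Skipping FPx: the fp6_llm CUDA kernel is not available on this backend")
    return model
```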
It actually compiles something when I install from source; I see 5 threads light up. I thought torch used the hipify script on C extensions to try to auto-convert the code? Usually, I thought, if something isn't supported by ROCm it gets caught when the wheel builds. Additionally, the error is different between the source-compiled build and the pip wheel. I can fetch the pip version later, but it's a lot more boring, essentially just saying that the function doesn't exist.
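For reference, this is the kind of build setup I mean (an illustrative sketch, not torchao's actual setup.py; the package name and source paths are made up). With a ROCm build of PyTorch, CUDAExtension runs hipify over the listed sources and compiles them with hipcc instead of nvcc:

```python
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="my_ext",  # hypothetical package name
    ext_modules=[
        CUDAExtension(
            name="my_ext._C",
            # hypothetical sources; on ROCm these are hipified before compilation
            sources=["csrc/bindings.cpp", "csrc/kernel.cu"],
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)
```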
Interesting. I don't know much about how PyTorch handles building for ROCm.
Can you run this script? https://github.com/pytorch/ao/blob/main/benchmarks/benchmark_fp6.py
It will help to verify if you can run the FPx kernel correctly.
Same exact traceback as my original post.
The one example I know of that works on both ROCm and CUDA is exllama. It uses torch's cpp_extension in ext.py, and the file list is a pretty good chunk of .cpp/.cu sources. Combing through the code, there's almost no HIP/ROCm-specific code, as the hipify script swaps out all references to libraries like cuBLAS for the ROCm equivalents.
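Roughly this pattern (hypothetical module name and paths, not exllama's actual ext.py):

```python
from torch.utils.cpp_extension import load

# JIT-compile the extension; on a ROCm build of PyTorch the same call runs hipify
# over the sources first, so references to cuBLAS etc. get mapped to ROCm libraries.
ext = load(
    name="exllama_like_ext",            # hypothetical module name
    sources=[
        "exllama_ext/exllama_ext.cpp",  # hypothetical source paths
        "exllama_ext/cuda_func.cu",
    ],
    verbose=True,
)
```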
Compiling ao from source using pip install git+https://github.com/pytorch/ao.git results in a very fun throw when running FPx weights using the script below.
Setup is 1x 7900 XTX on torch 2.5 + ROCm 6.2. All other quantizations work just fine, with the exception of float8_dynamic_activation_float8_weight, because gfx11 currently does not implement torch's _scaled_mm() function. Using bfloat16 as the base dtype instead actually does run, but it's wicked slow from the conversions. The floatx README states to use float16, so I assume that's the correct way.
Python traceback: traceback.txt