[Open] ferrybaltimore opened this issue 2 months ago
Hi @ferrybaltimore, FP8 KV cache is currently supported (with optimized kernels on MI300X and software emulation on other platforms), while FP8 computation support is awaiting this PR. Selected GEMMs are performed natively in FP8 on MI300X.
Do note, however, that for performance reasons the way we currently intend to ship FP8 support is somewhat different from how it is done today. The PR authors @HaiShaw @gshtras @charlifu can provide more details.
Feel free to also check out the ROCm fork at https://github.com/ROCm/vllm which already has FP8 support, including the ability to manually tune for the best FP8 GEMM kernels for a given set of shapes. Some instructions can be found at https://github.com/ROCm/vllm/blob/main/ROCm_performance.md#fp8-quantization
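If you want a quick smoke test of the FP8 KV cache path on MI300X, something along these lines should work from the offline API (a minimal sketch only; the accepted kv_cache_dtype strings differ between upstream vLLM and the ROCm fork, and the model name is just an example):

```python
# Minimal sketch (not an official example): exercising the FP8 KV cache via
# vLLM's offline API. Accepted kv_cache_dtype values ("fp8", "fp8_e4m3", ...)
# vary between upstream vLLM and the ROCm fork, so adjust to your build.
from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",      # any supported model directory works
    kv_cache_dtype="fp8_e4m3",      # store the KV cache in FP8
    tensor_parallel_size=1,
)
outputs = llm.generate(["Hello, MI300X!"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```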
Hi @mawong-amd, thanks a lot! We just got two new MI300X servers and we are trying to understand how to take maximum advantage of them.
I will try https://github.com/ROCm/vllm
Hi @mawong-amd,
I tested the repository and quantized with Quark, and everything went fine, but when running the models I got:
ValueError: torch.bfloat16 is not supported for quantization method fp8. Supported dtypes: [torch.float16, torch.uint8, torch.float8_e4m3fnuz]
I launch the model like this:
python -m vllm.entrypoints.openai.api_server --model Hermes-2-Pro-Mistral-7B --tensor-parallel-size 1 --port 8010 --host 0.0.0.0 --quantization fp8 --quantized-weights-path quark
@ferrybaltimore @mawong-amd
That is an error we fixed at: https://github.com/ROCm/vllm/blob/fp8-gemm/vllm/model_executor/layers/quantization/fp8fnuz.py#L53
You can make a similar modification to move ahead.
Note: the fp8-gemm branch is the one we use for upstreaming (review #6006 is still in progress); it has more generally applicable fixes.
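For context, the change is roughly of this shape (a sketch only, not the actual diff from fp8fnuz.py; the class and method names below are stand-ins): the FP8 config's supported activation dtypes need to include torch.bfloat16 so bfloat16 checkpoints are not rejected at load time.

```python
import torch

# Sketch of the kind of fix (hypothetical stand-in for the class in fp8fnuz.py
# on the fp8-gemm branch, not the actual code): report torch.bfloat16 as a
# supported activation dtype in addition to those from the error message above.
class Fp8FnuzConfig:
    def get_supported_act_dtypes(self):
        return [torch.float16, torch.bfloat16, torch.uint8, torch.float8_e4m3fnuz]
```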
@mawong-amd, thanks for the quick responses.
I just tried it now and I get:
Traceback (most recent call last):
File "/opt/conda/envs/py_3.9/lib/python3.9/runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/opt/conda/envs/py_3.9/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/vllm-workspace/vllm/entrypoints/openai/api_server.py", line 216, in
@ferrybaltimore
1. Quantize with Quark:
python3 quantize_quark.py --model_dir <llama2/3 checkpoint folder> --output_dir output_dir --quant_scheme w_fp8_a_fp8_o_fp8 --num_calib_data 128 --model_export vllm_adopted_safetensors --no_weight_matrix_merge
2. Get the ROCm fork at the tested commit:
git clone https://github.com/rocm/vllm; cd vllm; git checkout 93aab3c29490071a0d9a1651bb4b06cc86ca73f1
3. Extract the KV cache scales:
python3 examples/fp8/extract_scales.py --quantized_model <output_dir from before> --tp_size 1 --output_dir <use model dir>
4. Run with FP8 weights and FP8 KV cache:
HIP_VISIBLE_DEVICES=0 python benchmarks/benchmark_throughput.py --quantization fp8 --quantized-weights-path <output_dir you used before>llama.safetensors --kv-cache-dtype fp8_e4m3 --quantization-param-path <model dir>kv_cache_scales.json --model <model dir> -tp 1
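If you want to sanity-check the artifacts before serving, something like the following works (a sketch; replace the angle-bracket placeholders with the same paths used in the commands above):

```python
import json
from safetensors import safe_open

# Placeholders: use the same directories as in the commands above.
weights_path = "<output_dir you used before>/llama.safetensors"   # Quark export
scales_path = "<model dir>/kv_cache_scales.json"                  # from extract_scales.py

# Peek at a few exported tensors (names, dtypes, and shapes depend on the
# Quark export format).
with safe_open(weights_path, framework="pt") as f:
    for name in list(f.keys())[:5]:
        tensor = f.get_tensor(name)
        print(name, tensor.dtype, tuple(tensor.shape))

# The KV cache scaling factors read via --quantization-param-path.
with open(scales_path) as f:
    print(json.dumps(json.load(f), indent=2)[:500])
```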
@mawong-amd, keep rocking it sir!
@mawong-amd
I got this when trying to quantize Llama 3.1 405B Instruct.
Traceback (most recent call last):
File "/home/npawsys/quark-0.1.0+a9827f5/examples/torch/language_modeling/quantize_quark.py", line 304, in
@mawong-amd
Related to the quantization of Llama 3.1 70B: everything went fine, but when launching I got this error:
python -m vllm.entrypoints.openai.api_server --model /work/work2/Meta-Llama-3.1-70B-Instruct --tensor-parallel-size 1 --port 8010 --host 0.0.0.0 --quantization fp8 --quantized-weights-path /work/work2/Meta-Llama-3.1-70B-Instruct-fp8/llama.safetensors --kv-cache-dtype fp8_e4m3 --quantization-param-path /work/work2/Meta-Llama-3.1-70B-Instruct-fp8-scales/kv_cache_scales.json
Traceback (most recent call last):
File "/opt/conda/envs/py_3.9/lib/python3.9/runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/opt/conda/envs/py_3.9/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/vllm/entrypoints/openai/api_server.py", line 186, in
🚀 The feature, motivation and pitch
It was not clear to me whether FP8 support is available for ROCm, but with 5.2 I got:
fp8 quantization is currently not supported in ROCm.
Are there plans to make it available?
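For what it's worth, one way to see which quantization methods a given vLLM build registers is the snippet below (a sketch; the import path is correct around vLLM 0.4/0.5 and may have moved in other versions):

```python
# List the quantization methods this vLLM build knows about. Note that on
# ROCm builds without FP8 support, selecting "fp8" can still fail at engine
# initialization with the error quoted above, even if the name is registered.
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
print(list(QUANTIZATION_METHODS))
```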
Alternatives
No response
Additional context
No response