vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Feature]: ROCm 6.2 support & FP8 Support #7469

Open ferrybaltimore opened 3 months ago

ferrybaltimore commented 3 months ago

🚀 The feature, motivation and pitch

Last week AMD announced ROCm 6.2 (https://rocm.docs.amd.com/en/latest/about/release-notes.html), which also announces expanded support for vLLM & FP8.

I was actually able to run it following the guides (ROCm branch), executing it like this:

python -m vllm.entrypoints.openai.api_server \
    --model /work/work2/Meta-Llama-3.1-70B-Instruct \
    --tensor-parallel-size 1 --port 8010 --host 0.0.0.0 \
    --quantization fp8 \
    --quantized-weights-path /work/work2/Meta-Llama-3.1-70B-Instruct-fp8/llama.safetensors \
    --kv-cache-dtype fp8_e4m3 \
    --quantization-param-path /work/work2/Meta-Llama-3.1-70B-Instruct-fp8-scales/kv_cache_scales.json

But the performance is around 3-4 times slower than running the model without quantization.

I don't know if ROCm 6.2 can solve this issue... in fact, the performance we got on the MI300X (half precision) is similar to running an A100 (FP8) in our tests.

Alternatives

No response

Additional context

No response

hongxiayang commented 3 months ago

Did you run the FP8 GEMM tuning? This is only available in ROCm/vllm at this point:

Step 1: To obtain all the GEMM shapes used during execution of the model, set the environment variable TUNE_FP8=1 and then run the model as usual. This produces a file called /tmp/fp8_shapes.csv.
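For concreteness, with the same server invocation as in the original report (the model path is just the one from that command, and this assumes the ROCm/vllm branch that reads TUNE_FP8), the shape-collection run might look like:

TUNE_FP8=1 python -m vllm.entrypoints.openai.api_server \
    --model /work/work2/Meta-Llama-3.1-70B-Instruct \
    --quantization fp8 --kv-cache-dtype fp8_e4m3

Serve a few representative requests, then check that /tmp/fp8_shapes.csv has been written.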

Step 2: Next, run gradlib to obtain the best solutions for these shapes:

python3 gradlib/gradlib/gemm_tuner.py --input_file /tmp/fp8_shapes.csv --tuned_file /tmp/tuned_fp8_16.csv --indtype fp8 --outdtype f16

The resulting /tmp/tuned_fp8_16.csv will then be used by our FP8 GEMM linear layer.

Step 3: Now, when running inference with FP8, the tuned GEMM is used for best performance.

ferrybaltimore commented 2 months ago

Hi @hongxiayang, but when I launch vLLM using python -m vllm.entrypoints.openai.api_server, how do I specify which CSV to use? Because I have to tune each model separately, right?

nelsonspbr commented 1 week ago

Is this TUNE_FP8 environment variable processed in vLLM's code? I couldn't find a reference to it anywhere. Setting it in my runs doesn't seem to make any difference.

hongxiayang commented 6 days ago

The FP8 GEMM tuning described above is obsolete now. I recommend using PyTorch TunableOp instead (a combined example follows the two steps below):

1. Enable TunableOp with tuning on, and optionally enable verbose mode: PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_VERBOSE=1 your_vllm_script.sh

2. Enable TunableOp with tuning disabled, and measure: PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=0 your_vllm_script.sh
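Putting the two steps together, a minimal sketch of the workflow (your_vllm_script.sh stands in for whatever serving or benchmarking command you normally run; by default PyTorch writes the tuned results to a tunableop_results*.csv file in the working directory):

# Phase 1: tuning pass. Tuning is on by default when TunableOp is enabled;
# verbose mode prints which GEMM solutions are being tried.
PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_VERBOSE=1 your_vllm_script.sh

# Phase 2: measurement pass. TunableOp stays enabled but tuning is off,
# so the solutions recorded in phase 1 are reused.
PYTORCH_TUNABLEOP_ENABLED=1 PYTORCH_TUNABLEOP_TUNING=0 your_vllm_script.sh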