vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

FP8 support for MI300X #6576

Open ferrybaltimore opened 2 months ago

ferrybaltimore commented 2 months ago

🚀 The feature, motivation and pitch

It wasn't clear to me whether FP8 support is available for ROCm, but with 5.2 I got:

fp8 quantization is currently not supported in ROCm.

Are there plans to make it available?

Alternatives

No response

Additional context

No response

mawong-amd commented 2 months ago

Hi @ferrybaltimore, FP8 KV cache is currently supported (with optimized kernels on MI300X and software emulation on other platforms), while FP8 computation support is awaiting this PR. Selected GEMMs are performed natively in FP8 on MI300X.
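
For reference, a minimal sketch of turning on the FP8 KV cache from the offline API (the model name is just a placeholder; kv_cache_dtype="fp8" is the relevant engine argument):

# Minimal sketch: enable the FP8 KV cache (optimized kernels on MI300X,
# software emulation elsewhere). The model name is a placeholder.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8")
out = llm.generate(["Hello, MI300X!"], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)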

Do note, however, that for performance reasons the way we currently intend to ship FP8 support is somewhat different from how it's done today. The PR authors @HaiShaw @gshtras @charlifu can provide more details.

Feel free to also check out the ROCm fork at https://github.com/ROCm/vllm which already has FP8 support, including the ability to manually tune for the best FP8 GEMM kernels for a given set of shapes. Some instructions can be found at https://github.com/ROCm/vllm/blob/main/ROCm_performance.md#fp8-quantization

ferrybaltimore commented 2 months ago

Hi @mawong-amd, thanks a lot! We just got two new MI300X servers and we are trying to figure out how to take maximum advantage of them.

I will try https://github.com/ROCm/vllm.

ferrybaltimore commented 2 months ago

Hi @mawong-amd,

I tested the repository with Quark and everything went fine, but when running the models I got:

ValueError: torch.bfloat16 is not supported for quantization method fp8. Supported dtypes: [torch.float16, torch.uint8, torch.float8_e4m3fnuz]

I launch the model like this:

python -m vllm.entrypoints.openai.api_server --model Hermes-2-Pro-Mistral-7B --tensor-parallel-size 1 --port 8010 --host 0.0.0.0 --quantization fp8 --quantized-weights-path quark

HaiShaw commented 2 months ago

@ferrybaltimore @mawong-amd

That is an error we fixed at: https://github.com/ROCm/vllm/blob/fp8-gemm/vllm/model_executor/layers/quantization/fp8fnuz.py#L53

You can make similar modifications to move ahead.

Note: the fp8-gemm branch is the one we use for upstreaming (review #6006 is still in progress); it also has more generally applicable fixes.
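
The gist of that fix, as a sketch only (get_supported_act_dtypes is the method name from vLLM's QuantizationConfig interface; the real change lives in fp8fnuz.py at the link above), is to report bfloat16 as a supported activation dtype for the fp8 method:

import torch

# Sketch of the fix: include torch.bfloat16 among the dtypes the fp8
# quantization method reports as supported. Before the fix the list lacked
# bfloat16, which produced the "torch.bfloat16 is not supported..." error above.
def get_supported_act_dtypes() -> list[torch.dtype]:
    return [torch.bfloat16, torch.float16, torch.uint8, torch.float8_e4m3fnuz]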

ferrybaltimore commented 2 months ago

@mawong-amd, thanks for the quick responses.

I just tried it now and I get:

Traceback (most recent call last):
  File "/opt/conda/envs/py_3.9/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/envs/py_3.9/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/vllm-workspace/vllm/entrypoints/openai/api_server.py", line 216, in <module>
    engine = AsyncLLMEngine.from_engine_args(
  File "/vllm-workspace/vllm/engine/async_llm_engine.py", line 431, in from_engine_args
    engine = cls(
  File "/vllm-workspace/vllm/engine/async_llm_engine.py", line 360, in __init__
    self.engine = self._init_engine(*args, **kwargs)
  File "/vllm-workspace/vllm/engine/async_llm_engine.py", line 507, in _init_engine
    return engine_class(*args, **kwargs)
  File "/vllm-workspace/vllm/engine/llm_engine.py", line 243, in __init__
    self.model_executor = executor_class(
  File "/vllm-workspace/vllm/executor/executor_base.py", line 128, in __init__
    super().__init__(model_config, cache_config, parallel_config,
  File "/vllm-workspace/vllm/executor/executor_base.py", line 42, in __init__
    self._init_executor()
  File "/vllm-workspace/vllm/executor/gpu_executor.py", line 22, in _init_executor
    self.driver_worker = self._create_worker()
  File "/vllm-workspace/vllm/executor/gpu_executor.py", line 68, in _create_worker
    wrapper.init_worker(self._get_worker_kwargs(local_rank, rank,
  File "/vllm-workspace/vllm/worker/worker_base.py", line 332, in init_worker
    mod = importlib.import_module(self.worker_module_name)
  File "/opt/conda/envs/py_3.9/lib/python3.9/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1030, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1007, in _find_and_load
  File "<frozen importlib._bootstrap>", line 986, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 680, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 850, in exec_module
  File "<frozen importlib._bootstrap>", line 228, in _call_with_frames_removed
  File "/vllm-workspace/vllm/worker/worker.py", line 17, in <module>
    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
  File "/vllm-workspace/vllm/model_executor/model_loader/__init__.py", line 8, in <module>
    from vllm.model_executor.model_loader.loader import (BaseModelLoader,
  File "/vllm-workspace/vllm/model_executor/model_loader/loader.py", line 38, in <module>
    from vllm.utils import get_device_capability_stateless, is_hip, is_tpu
ImportError: cannot import name 'get_device_capability_stateless' from 'vllm.utils' (/vllm-workspace/vllm/utils.py)


Eliovp commented 1 month ago

@ferrybaltimore

[attached image: 8BModelFP8Cat]

@mawong-amd , keep rocking it sir!

ferrybaltimore commented 1 month ago

@mawong-amd

I got this when trying to quantize Llama 3.1 Instruct 405B:

Traceback (most recent call last):
  File "/home/npawsys/quark-0.1.0+a9827f5/examples/torch/language_modeling/quantize_quark.py", line 304, in <module>
    main(args)
  File "/home/npawsys/quark-0.1.0+a9827f5/examples/torch/language_modeling/quantize_quark.py", line 195, in main
    model, model_dtype = get_model(args.model_dir, args.data_type, args.device)
  File "/home/npawsys/quark-0.1.0+a9827f5/examples/torch/language_modeling/quantize_quark.py", line 145, in get_model
    model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map=device, torch_dtype=model_dtype, trust_remote_code=True)
  File "/home/npawsys/miniconda3/envs/quark/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 564, in from_pretrained
    return model_class.from_pretrained(
  File "/home/npawsys/miniconda3/envs/quark/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3903, in from_pretrained
    ) = cls._load_pretrained_model(
  File "/home/npawsys/miniconda3/envs/quark/lib/python3.10/site-packages/transformers/modeling_utils.py", line 4377, in _load_pretrained_model
    new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
  File "/home/npawsys/miniconda3/envs/quark/lib/python3.10/site-packages/transformers/modeling_utils.py", line 933, in _load_state_dict_into_meta_model
    set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
  File "/home/npawsys/miniconda3/envs/quark/lib/python3.10/site-packages/accelerate/utils/modeling.py", line 404, in set_module_tensor_to_device
    new_value = value.to(device)
torch.cuda.OutOfMemoryError: HIP out of memory. Tried to allocate 1.62 GiB. GPU
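
(Not from the Quark docs, just a hedged suggestion: for a 405B checkpoint, letting transformers/accelerate shard the weights across all visible GPUs and spill to CPU RAM, rather than pinning everything to one device, avoids this kind of single-GPU OOM. "ckpt_path" below is a placeholder for the local model directory.)

import torch
from transformers import AutoModelForCausalLM

# Sketch: shard a very large checkpoint across GPUs (with CPU offload) at load
# time instead of placing it on a single device. "ckpt_path" is a placeholder.
model = AutoModelForCausalLM.from_pretrained(
    "ckpt_path",
    device_map="auto",          # let accelerate place shards across devices
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)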

ferrybaltimore commented 1 month ago

@mawong-amd

Related to the quantization of Llama 3.1 70B: everything went fine, but when launching I got this error:

python -m vllm.entrypoints.openai.api_server --model /work/work2/Meta-Llama-3.1-70B-Instruct --tensor-parallel-size 1 --port 8010 --host 0.0.0.0 --quantization fp8 --quantized-weights-path /work/work2/Meta-Llama-3.1-70B-Instruct-fp8/llama.safetensors --kv-cache-dtype fp8_e4m3 --quantization-param-path /work/work2/Meta-Llama-3.1-70B-Instruct-fp8-scales/kv_cache_scales.json

Traceback (most recent call last):
  File "/opt/conda/envs/py_3.9/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/opt/conda/envs/py_3.9/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/vllm/entrypoints/openai/api_server.py", line 186, in <module>
    engine = AsyncLLMEngine.from_engine_args(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/vllm/engine/async_llm_engine.py", line 362, in from_engine_args
    engine_config = engine_args.create_engine_config()
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/vllm/engine/arg_utils.py", line 568, in create_engine_config
    model_config = ModelConfig(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/vllm/config.py", line 136, in __init__
    self.max_model_len = _get_and_verify_max_len(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/vllm/config.py", line 1232, in _get_and_verify_max_len
    if rope_scaling is not None and rope_scaling["type"] != "su":
KeyError: 'type'
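
For what it's worth, the failing check assumes rope_scaling always carries a "type" key, while Llama 3.1 configs ship "rope_type" instead; a defensive version of that lookup (a sketch, not the actual upstream fix) would look like this:

from typing import Optional

# Sketch only: read the RoPE scaling type whether the config uses the older
# "type" key or the "rope_type" key found in Llama 3.1 checkpoints.
def rope_scaling_type(rope_scaling: Optional[dict]) -> Optional[str]:
    if rope_scaling is None:
        return None
    return rope_scaling.get("type", rope_scaling.get("rope_type"))

# The line at config.py:1232 does rope_scaling["type"] != "su", which raises
# KeyError: 'type' when only "rope_type" is present.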