vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Feature]: LoRA support for Mixtral GPTQ and AWQ #5540

Open StrikerRUS opened 1 month ago

StrikerRUS commented 1 month ago

πŸš€ The feature, motivation and pitch

Please consider adding support for GPTQ and AWQ quantized Mixtral models.

I guess that after #4012 it's technically possible.

Alternatives

No response

Additional context

My Docker compose:

```
---
version: "3.8"

services:
  vllm-vllm:
    image: mirror.gcr.io/vllm/vllm-openai:v0.4.2
    container_name: vllm-vllm
    # --model=casperhansen/mixtral-instruct-awq
    command: --model=TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ --download-dir=/root/.cache/huggingface/hub/ --dtype=half --gpu-memory-utilization=0.9 --enforce-eager --device=cuda --disable-log-stats --enable-lora --lora-modules mixtral-finetune-0-1-5=/root/adapters/
    ports:
      - xxxx:8000
    restart: unless-stopped
    healthcheck:
      test: /bin/bash -c "cat < /dev/null > /dev/tcp/vllm-vllm/8000"
      interval: 10s
      start_period: 2m
    logging:
      options:
        max-size: 500mb
        max-file: 4
    environment:
      - HF_HOME=/root/.cache/huggingface/
    volumes:
      - vllm_models:/root/.cache/huggingface/
      - vllm_adapters:/root/adapters/
    ipc: host
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              device_ids: ['all']
              capabilities: [gpu]

volumes:
  vllm_models:
    driver: local
    driver_opts:
      type: 'none'
      o: 'bind'
      device: '/storage/gpt-project/Models_local/hf_local_0_1_0/'
  vllm_adapters:
    driver: local
    driver_opts:
      type: 'none'
      o: 'bind'
      device: '/storage/classifier-project/Models/Mixtral_finetune_0_1_5/checkpoint-7308/'
```
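For reference, once a server like the one above starts successfully, the adapter registered with `--lora-modules` is served under its own model name on vLLM's OpenAI-compatible API. A minimal sketch, assuming the container's port 8000 is mapped to `localhost:8000` and using the `openai` client; the adapter name comes from the compose file above:

```
# Minimal sketch (assumes the container port is mapped to localhost:8000):
# the adapter registered with --lora-modules is addressed by its own model
# name on vLLM's OpenAI-compatible API.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/v1",  # adjust to the mapped host port
    api_key="EMPTY",  # placeholder; the server above is started without auth
)

completion = client.completions.create(
    model="mixtral-finetune-0-1-5",  # the name given to --lora-modules
    prompt="Hello, Mixtral!",
    max_tokens=32,
)
print(completion.choices[0].text)
```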
Error log:

```
/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
WARNING 06-14 12:29:50 config.py:1086] Casting torch.bfloat16 to torch.float16.
INFO 06-14 12:29:50 config.py:177] The model is convertible to Marlin format. Using Marlin kernel.
WARNING 06-14 12:29:50 config.py:976] gptq_marlin quantization is not tested with LoRA yet.
INFO 06-14 12:29:50 llm_engine.py:100] Initializing an LLM engine (v0.4.2) with config: model='TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ', speculative_config=None, tokenizer='TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=32768, download_dir='/root/.cache/huggingface/hub/', load_format=LoadFormat.AUTO, tensor_parallel_size=1, disable_custom_all_reduce=False, quantization=gptq_marlin, enforce_eager=True, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=0, served_model_name=TheBloke/Mixtral-8x7B-Instruct-v0.1-GPTQ)
INFO 06-14 12:29:50 utils.py:660] Found nccl from library /root/.config/vllm/nccl/cu12/libnccl.so.2.18.1
INFO 06-14 12:29:51 selector.py:27] Using FlashAttention-2 backend.
[rank0]: Traceback (most recent call last):
[rank0]:   File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
[rank0]:     return _run_code(code, main_globals, None,
[rank0]:   File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
[rank0]:     exec(code, run_globals)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/openai/api_server.py", line 168, in <module>
[rank0]:     engine = AsyncLLMEngine.from_engine_args(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 366, in from_engine_args
[rank0]:     engine = cls(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 324, in __init__
[rank0]:     self.engine = self._init_engine(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 442, in _init_engine
[rank0]:     return engine_class(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 160, in __init__
[rank0]:     self.model_executor = executor_class(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/executor_base.py", line 41, in __init__
[rank0]:     self._init_executor()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/gpu_executor.py", line 23, in _init_executor
[rank0]:     self._init_non_spec_worker()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/executor/gpu_executor.py", line 69, in _init_non_spec_worker
[rank0]:     self.driver_worker.load_model()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker.py", line 118, in load_model
[rank0]:     self.model_runner.load_model()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner.py", line 164, in load_model
[rank0]:     self.model = get_model(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader/__init__.py", line 19, in get_model
[rank0]:     return loader.load_model(model_config=model_config,
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader/loader.py", line 222, in load_model
[rank0]:     model = _initialize_model(model_config, self.load_config,
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader/loader.py", line 90, in _initialize_model
[rank0]:     **_get_model_initialization_kwargs(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/model_loader/loader.py", line 70, in _get_model_initialization_kwargs
[rank0]:     raise ValueError(
[rank0]: ValueError: Model MixtralForCausalLM does not support LoRA, but LoRA is enabled. Support for this model may be added in the future. If this is important to you, please open an issue on github.
```
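The final ValueError comes from a capability check in the model loader: when `--enable-lora` is set, the loader refuses any model class that does not declare LoRA metadata, which is the case for the quantized Mixtral implementation. Roughly, the gate looks like this (simplified sketch, not the exact vLLM source):

```
# Simplified sketch of the loader-side gate that raises the error above
# (not the exact vLLM source; shown only to illustrate where it fails).
def _get_model_initialization_kwargs(model_class, lora_config):
    extra_kwargs = {}
    if hasattr(model_class, "supported_lora_modules"):
        # The model advertises LoRA support -> pass the config through.
        extra_kwargs["lora_config"] = lora_config
    elif lora_config is not None:
        # LoRA was requested but the model class declares no LoRA metadata.
        raise ValueError(
            f"Model {model_class.__name__} does not support LoRA, "
            "but LoRA is enabled.")
    return extra_kwargs
```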
robertgshaw2-neuralmagic commented 1 month ago

:)

hmellor commented 2 weeks ago

@StrikerRUS has the PR you mentioned handled your use case?

StrikerRUS commented 2 weeks ago

@hmellor Nope. LoRA adapters still cannot be used with quantized Mixtral models. There is no `supported_lora_modules` attribute in the quantized MixtralForCausalLM class; compare with the non-quantized MixtralForCausalLM class: https://github.com/vllm-project/vllm/blob/27902d42beeeb5828ef3243d5455a3b9af3317b3/vllm/model_executor/models/mixtral.py#L294-L300
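For context, the non-quantized class at those lines declares adapter metadata roughly of the following shape (illustrative sketch only; the authoritative lists are at the link above), while the quantized MixtralForCausalLM carries none of these attributes:

```
# Illustrative sketch of the class-level LoRA metadata the non-quantized
# MixtralForCausalLM declares (see the linked lines for the actual values);
# the quantized variant lacks these attributes entirely.
from torch import nn

class MixtralForCausalLM(nn.Module):
    # Fused projections and the original HF module names they pack together.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    }
    # Submodules that LoRA adapters are allowed to target.
    supported_lora_modules = ["qkv_proj", "o_proj", "embed_tokens", "lm_head"]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]
```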

Even after adding that attribute and adjusting the method arguments, vLLM crashes with an error about a tensor shape mismatch. I guess further work is needed before LoRA support actually works for the quantized models.