vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

[Usage]: Correct way to load lora model #8315

Open xyg-coder opened 2 months ago

xyg-coder commented 2 months ago

Your current environment

Collecting environment information...
PyTorch version: 2.2.1+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-2ubuntu1~20.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.30.0
Libc version: glibc-2.31

Python version: 3.8.10 (default, Jul 29 2024, 17:02:10)  [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.4.0-1063-aws-x86_64-with-glibc2.2.5
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A10G
GPU 1: NVIDIA A10G
GPU 2: NVIDIA A10G
GPU 3: NVIDIA A10G

Nvidia driver version: 525.105.17
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Byte Order:                      Little Endian
Address sizes:                   48 bits physical, 48 bits virtual
CPU(s):                          48
On-line CPU(s) list:             0-47
Thread(s) per core:              2
Core(s) per socket:              24
Socket(s):                       1
NUMA node(s):                    1
Vendor ID:                       AuthenticAMD
CPU family:                      23
Model:                           49
Model name:                      AMD EPYC 7R32
Stepping:                        0
CPU MHz:                         3299.816
BogoMIPS:                        5599.99
Hypervisor vendor:               KVM
Virtualization type:             full
L1d cache:                       768 KiB
L1i cache:                       768 KiB
L2 cache:                        12 MiB
L3 cache:                        96 MiB
NUMA node0 CPU(s):               0-47
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Full AMD retpoline, IBPB conditional, IBRS_FW, STIBP conditional, RSB filling
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf tsc_known_freq pni pclmulqdq ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm cmp_legacy cr8_legacy abm sse4a misalignsse 3dnowprefetch topoext ssbd ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 clzero xsaveerptr wbnoinvd arat npt nrip_save rdpid

Versions of relevant libraries:
[pip3] botorch==0.6.2
[pip3] gpytorch==1.8.0
[pip3] mypy-extensions==1.0.0
[pip3] mypy-protobuf==2.4
[pip3] numpy==1.23.4
[pip3] nvidia-cublas-cu12==12.1.3.1
[pip3] nvidia-cuda-cupti-cu12==12.1.105
[pip3] nvidia-cuda-nvrtc-cu12==12.1.105
[pip3] nvidia-cuda-runtime-cu12==12.1.105
[pip3] nvidia-cudnn-cu12==8.9.2.26
[pip3] nvidia-cufft-cu12==11.0.2.54
[pip3] nvidia-curand-cu12==10.3.2.106
[pip3] nvidia-cusolver-cu12==11.4.5.107
[pip3] nvidia-cusparse-cu12==12.1.0.106
[pip3] nvidia-ml-py==12.555.43
[pip3] nvidia-nccl-cu12==2.19.3
[pip3] nvidia-nvjitlink-cu12==12.5.82
[pip3] nvidia-nvtx-cu12==12.1.105
[pip3] onnx==1.8.0
[pip3] onnx-graphsurgeon==0.3.12
[pip3] pynvml==11.5.3
[pip3] pytorch-lamb==1.0.0
[pip3] pytorch-lightning==2.4.0
[pip3] pyzmq==26.1.1
[pip3] torch==2.2.1
[pip3] torchao==0.1
[pip3] torchinfo==1.8.0
[pip3] torchmetrics==1.0.3
[pip3] torchrec==0.4.0a0+6e8cc97
[pip3] torchscript==0.2.37
[pip3] torchsnapshot==0.1.0
[pip3] torchtext==0.12.0a0+d7a34d6
[pip3] torchtune==0.1.1
[pip3] torchvision==0.17.1+cu121
[pip3] transformers==4.43.4
[pip3] triton==2.2.0
[pip3] tritonclient==2.45.0
[pip3] vllm-nccl-cu12==2.18.1.0.4.0
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: N/A
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0    GPU1    GPU2    GPU3    CPU Affinity    NUMA Affinity
GPU0     X      PHB     PHB     PHB     0-47            N/A
GPU1    PHB      X      PHB     PHB     0-47            N/A
GPU2    PHB     PHB      X      PHB     0-47            N/A
GPU3    PHB     PHB     PHB      X      0-47            N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

How would you like to use vllm

I have a model that was fine-tuned with a LoRA adapter and saved:

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained(
    'mistralai/Mistral-7B-Instruct-v0.2',
    config=config,  # model config prepared elsewhere
)
lora_config = LoraConfig(r=128, target_modules=[...])
model = get_peft_model(model, lora_config)

# ... finetune

# saves only the LoRA adapter weights and adapter_config.json into snapshot_dir
model.save_pretrained(snapshot_dir)

I can load the model locally with:

import os

from transformers import AutoConfig, AutoModelForCausalLM
from peft import PeftConfig, PeftModel

config, unused_args = AutoConfig.from_pretrained(
    os.path.join(local_snapshot_dir, 'pretrain_model_config.json'),
    return_unused_kwargs=True,
)
peft_config = PeftConfig.from_pretrained(local_snapshot_dir)
base_model = AutoModelForCausalLM.from_pretrained(
    peft_config.base_model_name_or_path,
    config=config,
)
model = PeftModel.from_pretrained(base_model, local_snapshot_dir)

What is the best way to run inference on this model with vLLM? I get the following exception when I try to call:

model = LLM(**{
    "model": "/path_to_model_snapshot",
    "tokenizer": "mistralai/Mistral-7B-Instruct-v0.2",
    "skip_tokenizer_init": False,
    "tokenizer_mode": "auto",
    "trust_remote_code": False,
    "download_dir": None,
    "load_format": "auto",
    "dtype": "auto",
    "kv_cache_dtype": "auto",
    "quantization_param_path": None,
    "seed": 0,
    "max_model_len": 1024,
    "worker_use_ray": False,
    "pipeline_parallel_size": 1,
    "tensor_parallel_size": 1,
    "max_parallel_loading_workers": None,
    "block_size": 16,
    "enable_prefix_caching": False,
    "use_v2_block_manager": False,
    "swap_space": 4,
    "gpu_memory_utilization": 0.9,
    "max_num_batched_tokens": None,
    "max_num_seqs": 256,
    "max_logprobs": 5,
    "disable_log_stats": False,
    "revision": None,
    "code_revision": None,
    "tokenizer_revision": None,
    "quantization": None,
    "enforce_eager": False,
    "max_context_len_to_capture": 8192,
    "disable_custom_all_reduce": False,
    "tokenizer_pool_size": 0,
    "tokenizer_pool_type": "ray",
    "tokenizer_pool_extra_config": None,
    "enable_lora": True,
    "max_loras": 1,
    "max_lora_rank": 16,
    "lora_extra_vocab_size": 256,
    "max_cpu_loras": None,
    "device": "auto",
    "ray_workers_use_nsight": False,
    "num_gpu_blocks_override": None,
    "num_lookahead_slots": 0,
    "model_loader_extra_config": None,
    "image_input_type": None,
    "image_token_id": None,
    "image_input_shape": None,
    "image_feature_size": None,
    "scheduler_delay_factor": 0.0,
    "enable_chunked_prefill": False,
    "guided_decoding_backend": "outlines",
    "speculative_model": None,
    "num_speculative_tokens": None,
    "speculative_max_model_len": None,
})

Exception:

File "/usr/local/lib/python3.8/site-packages/vllm/entrypoints/llm.py", line 118, in __init__
    self.llm_engine = LLMEngine.from_engine_args(
  File "/usr/local/lib/python3.8/site-packages/vllm/engine/llm_engine.py", line 277, in from_engine_args
    engine = cls(
  File "/usr/local/lib/python3.8/site-packages/vllm/engine/llm_engine.py", line 148, in __init__
    self.model_executor = executor_class(
  File "/usr/local/lib/python3.8/site-packages/vllm/executor/executor_base.py", line 41, in __init__
    self._init_executor()
  File "/usr/local/lib/python3.8/site-packages/vllm/executor/gpu_executor.py", line 22, in _init_executor
    self._init_non_spec_worker()
  File "/usr/local/lib/python3.8/site-packages/vllm/executor/gpu_executor.py", line 51, in _init_non_spec_worker
    self.driver_worker.load_model()
  File "/usr/local/lib/python3.8/site-packages/vllm/worker/worker.py", line 117, in load_model
    self.model_runner.load_model()
  File "/usr/local/lib/python3.8/site-packages/vllm/worker/model_runner.py", line 162, in load_model
    self.model = get_model(
  File "/usr/local/lib/python3.8/site-packages/vllm/model_executor/model_loader/__init__.py", line 19, in get_model
    return loader.load_model(model_config=model_config,
  File "/usr/local/lib/python3.8/site-packages/vllm/model_executor/model_loader/loader.py", line 225, in load_model
    model.load_weights(
  File "/usr/local/lib/python3.8/site-packages/vllm/model_executor/models/llama.py", line 411, in load_weights
    param = params_dict[name]
KeyError: 'base_model.model.model.layers.0.mlp.down_proj.lora_A.weight'

It seems only the base model is loaded and the LoRA adapter is ignored.

If I print the model inside vLLM, I get:

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): VocabParallelEmbedding()
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (qkv_proj): QKVParallelLinear()
          (o_proj): RowParallelLinear()
          (rotary_emb): RotaryEmbedding()
          (attn): Attention()
        )
        (mlp): LlamaMLP(
          (gate_up_proj): MergedColumnParallelLinear()
          (down_proj): RowParallelLinear()
          (act_fn): SiluAndMul()
        )
        (input_layernorm): RMSNorm()
        (post_attention_layernorm): RMSNorm()
      )
    )
    (norm): RMSNorm()
  )
  (lm_head): ParallelLMHead()
  (logits_processor): LogitsProcessor()
  (sampler): Sampler()
)
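
The printout shows vLLM's own Llama implementation with fused projections (qkv_proj, gate_up_proj); its weight loader only knows these base-model parameter names, so PEFT adapter keys such as base_model.model...lora_A.weight have no entry in params_dict, which is exactly the KeyError above. Besides the Multi-LoRA route described in the reply below, a common alternative (not discussed in this thread; paths are placeholders) is to merge the adapter into the base weights with PEFT and point vLLM at the merged checkpoint. A minimal sketch:

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Rebuild the PEFT model from the saved adapter, then fold the LoRA deltas
# into the base weights. The result is a plain Transformers checkpoint that
# vLLM can load directly via model=<merged_dir>.
base_model = AutoModelForCausalLM.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2')
model = PeftModel.from_pretrained(base_model, '/path_to_model_snapshot')  # placeholder path
merged = model.merge_and_unload()

merged_dir = '/path_to_merged_model'  # placeholder output directory
merged.save_pretrained(merged_dir)
AutoTokenizer.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2').save_pretrained(merged_dir)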


jeejeelee commented 2 months ago

vLLM implements Multi-LoRA, where the base model weights and the LoRA adapter weights are kept separate. You can refer to this example: https://github.com/vllm-project/vllm/blob/main/examples/multilora_inference.py
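
Following that pattern, a minimal sketch of what the suggested flow might look like for this case: point model at the base model, enable LoRA, and attach the saved adapter per request via a LoRARequest. The adapter name, prompt, and paths below are placeholders, and max_lora_rank has to cover the adapter's rank (r=128 above), which may require a recent vLLM release:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Load only the base model; the LoRA weights are supplied separately per request.
llm = LLM(
    model='mistralai/Mistral-7B-Instruct-v0.2',
    enable_lora=True,
    max_lora_rank=128,  # must be >= the adapter's rank
    max_model_len=1024,
)

sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

# LoRARequest(name, id, path): the path is the directory written by
# model.save_pretrained(snapshot_dir); '/path_to_model_snapshot' is a placeholder.
outputs = llm.generate(
    ['Hello, how are you?'],
    sampling_params,
    lora_request=LoRARequest('my_adapter', 1, '/path_to_model_snapshot'),
)
print(outputs[0].outputs[0].text)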