vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

vllm load SqueezeLLM quantization model failed #3226

Open zuosong-peng opened 4 months ago

zuosong-peng commented 4 months ago

My environment versions:

torch: 2.2.1
transformers: 4.39.0.dev0
vllm: custom build from master@24aecf421a4ad5989697010963074904fead9a1b

I used SqueezeLLM to quantize my fine-tuned llama-7B model and want to load it with vLLM. Below are my code and the traceback.

# Set up the SqueezeLLM gradient environment
git clone https://github.com/SqueezeAILab/SqueezeLLM.git
git clone https://github.com/kssteven418/SqueezeLLM-gradients.git
conda create -n sqllm-grad python=3.9 -y
conda activate sqllm-grad
cd SqueezeLLM-gradients
pip install -e .
pip install -r requirements.txt  # modified to torch>=2.2.1

### Compute gradients
CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=16 python run.py --output_dir [gradients_path] --model_name_or_path [model_path]

# Build SqueezeLLM and its CUDA kernel
cd SqueezeLLM/
pip install -e .
cd squeezellm
python setup_cuda.py install
cd ../quantization

### Chunk model weights and gradients
python chunk_models.py --model [model_path] --output [model_chunk_path] --model_type llama
python chunk_models.py --model [gradients_path] --output [gradients_chunk_path] --model_type llama

### (Optional, for D+S quantization) Outlier configuration generation
python generate_outlier_config.py --model [model_chunk_path] --range 1.8 --output [outlier_config]

### K-means clustering
python nuq.py --bit 4 --model_type llama --model [model_chunk_path] --gradient [gradient_chunk_path] --output [lut_path] --outlier_config [outlier_config]/outlier_config_o0.45.json --sensitivity 0.05

### Packing
python pack.py --model [model_path] --wbits 4 --folder [lut_path] --save [pack_path] --include_sparse --balance

AutoModelForCausalLM can load the SqueezeLLM-packed model successfully:

# load_quant adapted from https://github.com/SqueezeAILab/SqueezeLLM/blob/main/llama.py#L136
import os

import torch
from transformers import AutoConfig, AutoModelForCausalLM

from squeezellm.modelutils import *
from squeezellm.quant import *

def load_quant(model, checkpoint, wbits, include_sparse, topX):
    """
    topX is num_dense_channels:
    the number of dense channels used by the hybrid (dense + sparse) kernel.
    """
    config = AutoConfig.from_pretrained(model)
    model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.bfloat16)

    model = model.eval()
    layers = find_layers(model)

    state_dict = torch.load(os.path.join(checkpoint, "pack_model.pt"))

    # load sparse thresholds from checkpoint
    if include_sparse:
        num_vals = {}
        for k, v in state_dict.items():
            if "sparse_threshold." in k:
                key = k.replace("sparse_threshold.", "")
                num_vals[key] = v
        for k, v in num_vals.items():
            del state_dict["sparse_threshold." + k]
    else:
        num_vals = None

    # replace Linear layers with quantized LUT layers (skip lm_head)
    for name in ["lm_head"]:
        if name in layers:
            del layers[name]
    make_quant_lut(
        model, layers, wbits, include_sparse=include_sparse, numvals=num_vals, topX=topX
    )
    del layers

    print("Loading model ...")
    state_dict = torch.load(os.path.join(checkpoint, "pack_model.pt"))
    model.load_state_dict(state_dict, strict=False)
    model.seqlen = 2048
    print("Done.")

    return model

DEV = torch.device("cuda:0")
# adapter_path is the directory containing the packed checkpoint (pack_model.pt)
model = load_quant("llama-2", adapter_path, 4, include_sparse=True, topX=10)
model = model.to(DEV)
model.eval()
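
Before handing the checkpoint to vLLM, it can help to list the tensor names that pack.py actually wrote. The snippet below is a small diagnostic sketch, not part of the original report; it reuses adapter_path and the pack_model.pt filename from the code above, and the expectation (an assumption on my side) is that with --include_sparse the checkpoint carries sparse-related names alongside the dense qweight/lookup_table tensors.

import os

import torch

# Diagnostic sketch: list what pack.py stored in the packed checkpoint.
# Reuses adapter_path from the snippet above; map_location keeps it on CPU.
packed = torch.load(os.path.join(adapter_path, "pack_model.pt"), map_location="cpu")
for key, value in sorted(packed.items()):
    shape = tuple(value.shape) if hasattr(value, "shape") else type(value).__name__
    print(key, shape)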

But vLLM fails to load it with the following error.

from vllm import LLM, SamplingParams
import torch

model_path = '/root/ckpt161_quantization_w4_s0.45'

if __name__ == '__main__':
    llm = LLM(model=model_path, quantization="squeezellm", dtype=torch.float16)
    prompts = [
        "Hello, my name is",
    ]
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

Stacktrace

Traceback (most recent call last):
  File "/root/python/dictionary/train/testbatchvllm.py", line 58, in <module>
    llm = LLM(model=model_path, quantization="squeezellm", dtype=torch.float16)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/python/github.com/vllm/vllm/entrypoints/llm.py", line 109, in __init__
    self.llm_engine = LLMEngine.from_engine_args(engine_args)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/python/github.com/vllm/vllm/engine/llm_engine.py", line 412, in from_engine_args
    engine = cls(*engine_configs,
             ^^^^^^^^^^^^^^^^^^^^
  File "/root/python/github.com/vllm/vllm/engine/llm_engine.py", line 142, in __init__
    self._init_workers()
  File "/root/python/github.com/vllm/vllm/engine/llm_engine.py", line 200, in _init_workers
    self._run_workers("load_model")
  File "/root/python/github.com/vllm/vllm/engine/llm_engine.py", line 1086, in _run_workers
    driver_worker_output = getattr(self.driver_worker,
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/python/github.com/vllm/vllm/worker/worker.py", line 99, in load_model
    self.model_runner.load_model()
  File "/root/python/github.com/vllm/vllm/worker/model_runner.py", line 88, in load_model
    self.model = get_model(self.model_config,
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/python/github.com/vllm/vllm/model_executor/utils.py", line 52, in get_model
    return get_model_fn(model_config, device_config, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/python/github.com/vllm/vllm/model_executor/model_loader.py", line 86, in get_model
    model.load_weights(model_config.model, model_config.download_dir,
  File "/root/python/github.com/vllm/vllm/model_executor/models/llama.py", line 388, in load_weights
    param = params_dict[name]
            ~~~~~~~~~~~^^^^^^
KeyError: 'model.layers.0.self_attn.qkv_proj.rows'
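
For context, the KeyError comes from the name-mapping loop in vLLM's llama.py load_weights: checkpoint tensor names are rewritten onto the fused qkv_proj/gate_up_proj parameters and then looked up in params_dict. The sketch below is a simplified reconstruction, not the exact vLLM source; the assumption is that the extra per-layer tensors produced by the sparse (D+S) packing, such as the .rows seen in the KeyError, have no corresponding parameter registered by vLLM's SqueezeLLM support, so the lookup fails.

# Simplified sketch of vLLM's llama.py load_weights name mapping (not the
# exact source). A SqueezeLLM-specific tensor such as
# "model.layers.0.self_attn.q_proj.rows" gets renamed to
# "model.layers.0.self_attn.qkv_proj.rows", which is not in params_dict -> KeyError.
stacked_params_mapping = [
    # (fused vLLM param, checkpoint shard name, shard id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def load_weights_sketch(model, weights):
    """weights: iterable of (checkpoint_name, tensor) pairs."""
    params_dict = dict(model.named_parameters())
    for name, loaded_weight in weights:
        for param_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name not in name:
                continue
            name = name.replace(shard_name, param_name)
            param = params_dict[name]  # KeyError for "...qkv_proj.rows"
            param.weight_loader(param, loaded_weight, shard_id)
            break
        else:
            param = params_dict[name]
            param.weight_loader(param, loaded_weight)
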
Qubitium commented 4 months ago

@chooper1 Can you check this? Thanks. @SoleMY shows that the SqueezeLLM-modified transformers code can load and run the SqueezeLLM-quantized model, but vLLM fails at the load stage.

catid commented 2 months ago

I see the same error in another model: KeyError: 'model.layers.0.self_attn.qkv_proj.rows'

catid commented 2 months ago

It's actually located here: 'model.layers.0.self_attn.qkv_proj.qweight'

Qubitium commented 2 months ago

@catid Did you get any model to run with SqueezeLLM/VLLM? I believe this feature was never tested post-merge with VLLM and should be removed.

RyanWMHI commented 1 month ago

It's actually located here: 'model.layers.0.self_attn.qkv_proj.qweight'

If you use the key model.layers.0.self_attn.qkv_proj.qweight, it reports this error instead:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/ryan/vllm/benchmarks/benchmark_latency.py", line 195, in <module>
[rank0]:   File "/home/ryan/vllm/benchmarks/benchmark_latency.py", line 20, in main
[rank0]:     llm = LLM(model=args.model,
[rank0]:   File "/home/ryan/vllm/vllm/entrypoints/llm.py", line 123, in __init__
[rank0]:     self.llm_engine = LLMEngine.from_engine_args(
[rank0]:   File "/home/ryan/vllm/vllm/engine/llm_engine.py", line 292, in from_engine_args
[rank0]:     engine = cls(
[rank0]:   File "/home/ryan/vllm/vllm/engine/llm_engine.py", line 160, in __init__
[rank0]:     self.model_executor = executor_class(
[rank0]:   File "/home/ryan/vllm/vllm/executor/executor_base.py", line 41, in __init__
[rank0]:   File "/home/ryan/vllm/vllm/executor/gpu_executor.py", line 23, in _init_executor
[rank0]:   File "/home/ryan/vllm/vllm/executor/gpu_executor.py", line 69, in _init_non_spec_worker
[rank0]:   File "/home/ryan/vllm/vllm/worker/worker.py", line 118, in load_model
[rank0]:   File "/home/ryan/vllm/vllm/worker/model_runner.py", line 164, in load_model
[rank0]:     self.model = get_model(
[rank0]:   File "/home/ryan/vllm/vllm/model_executor/model_loader/__init__.py", line 19, in get_model
[rank0]:     return loader.load_model(model_config=model_config,
[rank0]:   File "/home/ryan/vllm/vllm/model_executor/model_loader/loader.py", line 224, in load_model
[rank0]:   File "/home/ryan/vllm/vllm/model_executor/models/llama.py", line 412, in load_weights
[rank0]:     weight_loader(param, loaded_weight, shard_id)
[rank0]:   File "/home/ryan/vllm/vllm/model_executor/layers/linear.py", line 561, in weight_loader
[rank0]:     loaded_weight = loaded_weight.narrow(output_dim, start_idx,
[rank0]: IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

How do you fix it?
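
For what it's worth, that IndexError is torch.Tensor.narrow being asked to slice along dim 1 of a tensor that only has one dimension. Below is a minimal, hypothetical reproduction; the shape and the guess that a 1-D packed SqueezeLLM tensor reaches the sharded-linear weight_loader are assumptions, not taken from the checkpoint.

import torch

# Hypothetical reproduction: vLLM's sharded-linear weight_loader narrows the
# incoming tensor along its output dimension. A 1-D tensor has no dim 1, so
# narrow() raises exactly the IndexError seen above.
loaded_weight = torch.zeros(4096)          # made-up 1-D tensor
output_dim, start_idx, shard_size = 1, 0, 2048
loaded_weight.narrow(output_dim, start_idx, shard_size)
# IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)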