vllm-project / vllm

A high-throughput and memory-efficient inference and serving engine for LLMs
https://docs.vllm.ai
Apache License 2.0

AWQ + Marlin Error #3392

Closed hllj closed 1 month ago

hllj commented 6 months ago

I converted the model with the AutoAWQ library using the following script.

1. Quantize with Marlin

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = 'mistralai/Mistral-7B-Instruct-v0.2'
quant_path = 'mistral-instruct-v0.2-awq-marlin'
quant_config = { "zero_point": False, "q_group_size": 128, "w_bit": 4, "version": "Marlin" }

# Load model
model = AutoAWQForCausalLM.from_pretrained(
    model_path, **{"low_cpu_mem_usage": True, "use_cache": False}
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Quantize
model.quantize(tokenizer, quant_config=quant_config)

# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)

print(f'Model is quantized and saved at "{quant_path}"')
```


2. Generate

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer

quant_path = "./mistral-instruct-v0.2-awq-marlin"

# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# Convert prompt to tokens
prompt_template = "[INST] {prompt} [/INST]"

prompt = "You're standing on the surface of the Earth. "\
        "You walk one mile south, one mile west and one mile north. "\
        "You end up exactly where you started. Where are you?"

tokens = tokenizer(
    prompt_template.format(prompt=prompt), 
    return_tensors='pt'
).input_ids.cuda()

# Generate output
generation_output = model.generate(
    tokens, 
    streamer=streamer,
    max_new_tokens=512
)

```

The two steps above work perfectly.

But when I use that quantized model folder to start serving in vLLM with this script:

```bash
python services/vllm/api_server.py \
    --host 0.0.0.0 \
    --port 8000 \
    --disable-log-requests \
    --model /root/AutoAWQ/mistral-instruct-v0.2-awq-marlin \
    --quantization awq \
    --trust-remote-code \
    --dtype float16
```

The error shows:

  File "/root/vllm/vllm/engine/llm_engine.py", line 101, in __init__
    self.model_executor = executor_class(model_config, cache_config,
  File "/root/vllm/vllm/executor/gpu_executor.py", line 42, in __init__
    self._init_worker()
  File "/root/vllm/vllm/executor/gpu_executor.py", line 77, in _init_worker
    self.driver_worker.load_model()
  File "/root/vllm/vllm/worker/worker.py", line 99, in load_model
    self.model_runner.load_model()
  File "/root/vllm/vllm/worker/model_runner.py", line 89, in load_model
    self.model = get_model(self.model_config,
  File "/root/vllm/vllm/model_executor/utils.py", line 52, in get_model
    return get_model_fn(model_config, device_config, **kwargs)
  File "/root/vllm/vllm/model_executor/model_loader.py", line 86, in get_model
    model.load_weights(model_config.model, model_config.download_dir,
  File "/root/vllm/vllm/model_executor/models/llama.py", line 391, in load_weights
    weight_loader(param, loaded_weight)
  File "/root/vllm/vllm/model_executor/layers/linear.py", line 556, in weight_loader
    loaded_weight = loaded_weight.narrow(input_dim, start_idx,
RuntimeError: start (0) + length (14336) exceeds dimension size (896).

It looks like the weights cannot be loaded and vLLM does not properly support AWQ with the Marlin kernel. How can we fix this error? Thank you.
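
For what it's worth, the shape mismatch can be reduced to a single tensor operation. A minimal, illustrative sketch (the 896 and 14336 come from the traceback above; the second dimension is just a placeholder):

```python
import torch

# vLLM's weight_loader narrows the checkpoint tensor along the input dimension,
# but the Marlin-packed tensor is much smaller than the expected unpacked size.
packed = torch.empty(896, 1)   # 896 from the error message; 1 is a placeholder
packed.narrow(0, 0, 14336)     # RuntimeError: start (0) + length (14336) exceeds dimension size (896)
```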

hllj commented 6 months ago

I also tried setting enforce-eager mode to True, but the error is still the same.

robertgshaw2-neuralmagic commented 6 months ago

@hllj

Loading AWQ models saved in Marlin is not currently supported. I was not aware that AWQ models could be converted to Marlin format. I will look into adding support for this

Can you post that model to Hugging Face hub? This would make it easier for me

cc @alexm-nm @simon-mo @mgoin

robertgshaw2-neuralmagic commented 6 months ago

@hllj can you post it on the hub?

hllj commented 6 months ago

> @hllj can you post it on the hub?

Yes, I am uploading it to the Hugging Face hub; the process is the one described in step 1 above.

hllj commented 6 months ago

> @hllj
>
> Loading AWQ models saved in Marlin is not currently supported. I was not aware that AWQ models could be converted to Marlin format. I will look into adding support for this
>
> Can you post that model to Hugging Face hub? This would make it easier for me
>
> cc @alexm-nm @simon-mo @mgoin

Here are my weights on the hub: https://huggingface.co/hllj/mistral-instruct-v0.2-awq-marlin

Quantization script:

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = 'mistralai/Mistral-7B-Instruct-v0.2'
quant_path = 'mistral-instruct-v0.2-awq-marlin'
quant_config = { "zero_point": False, "q_group_size": 128, "w_bit": 4, "version": "Marlin" }

# Load model
model = AutoAWQForCausalLM.from_pretrained(
    model_path, **{"low_cpu_mem_usage": True, "use_cache": False}
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Quantize
model.quantize(tokenizer, quant_config=quant_config)

# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)

print(f'Model is quantized and saved at "{quant_path}"')
```

You have to install the Marlin kernel in this repo.

Generation script:

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer

quant_path = "hllj/mistral-instruct-v0.2-awq-marlin"

# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# Convert prompt to tokens
prompt_template = "[INST] {prompt} [/INST]"

prompt = "You're standing on the surface of the Earth. "\
        "You walk one mile south, one mile west and one mile north. "\
        "You end up exactly where you started. Where are you?"

tokens = tokenizer(
    prompt_template.format(prompt=prompt),
    return_tensors='pt'
).input_ids.cuda()

# Generate output
generation_output = model.generate(
    tokens,
    streamer=streamer,
    max_new_tokens=512
)
```

robertgshaw2-neuralmagic commented 6 months ago

Thanks, will take a look at this over the weekend

casper-hansen commented 6 months ago

AWQ has many versions at this point. I would recommend using gemv_fast from AutoAWQ as it's faster than Marlin and exllamav2 in decoding. However, it's not implemented in vLLM yet, although I laid out some groundwork in a PR.

robertgshaw2-neuralmagic commented 6 months ago

@casper-hansen Is gemv_fast just for batch 1 decoding?

casper-hansen commented 6 months ago

> @casper-hansen Is gemv_fast just for batch 1 decoding?

No, it's just a name. See more details: https://github.com/vllm-project/vllm/pull/3289

robertgshaw2-neuralmagic commented 6 months ago

@casper-hansen can you point me to the place in AutoAWQ where these are implemented and loaded? I resolved the packing issues related to integrating the Marlin kernels, so I can try to do the same for AWQ in vLLM.

casper-hansen commented 6 months ago

@robertgshaw2-neuralmagic It is implemented as seen below. We create buffers and then `accelerate` loads the weights for us. If you can help implement the weight loading, it would help me move forward and get the PR ready.

https://github.com/casper-hansen/AutoAWQ/blob/main/awq/modules/linear/gemv_fast.py#L72
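
For readers unfamiliar with that pattern, here is a rough, hypothetical sketch of "register empty buffers, let the checkpoint loader fill them" (names, shapes, and layout are made up for illustration; this is not the actual AutoAWQ GEMVFast module):

```python
import torch
import torch.nn as nn


class PackedLinearSketch(nn.Module):
    """Hypothetical stand-in for a packed 4-bit linear layer, not AutoAWQ's GEMVFast module."""

    def __init__(self, in_features: int, out_features: int, w_bit: int = 4, group_size: int = 128):
        super().__init__()
        pack_factor = 32 // w_bit
        # Empty buffers are registered at construction time; accelerate /
        # load_state_dict later copies the quantized checkpoint into them.
        self.register_buffer(
            "qweight",
            torch.zeros(out_features, in_features // pack_factor, dtype=torch.int32),
        )
        self.register_buffer(
            "scales",
            torch.zeros(in_features // group_size, out_features, dtype=torch.float16),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The real module dispatches to a fused dequant + GEMV kernel here.
        raise NotImplementedError("illustrative sketch only")
```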

robertgshaw2-neuralmagic commented 6 months ago

Got it - so you do the repacking on the fly (during from_linear which is called when the model is being created) as opposed to serializing the model in the GEMVFast format?

mgoin commented 1 month ago

Now we support converting AWQ models to Marlin inside vLLM, closing!

psych0v0yager commented 1 month ago

How do you convert AWQ models to Marlin inside vLLM? The script in this thread was done outside vLLM and needs a full-precision model to convert. Can the model conversion be completed on the fly, or does it need to be done beforehand? Also, how much additional VRAM is necessary to perform the conversion?

EDIT: Never mind, vLLM can do it on the fly. Just pass in -q marlin, or leave it blank if you want to use the Marlin kernels. More info is here #6612
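
A minimal serving command along those lines (hedged: the model path is a placeholder, and automatic selection of the Marlin kernels depends on your vLLM version and GPU):

```bash
# Leaving --quantization unset lets vLLM choose the Marlin kernels for a
# compatible AWQ checkpoint; -q marlin requests them explicitly.
python -m vllm.entrypoints.openai.api_server \
    --model <your-awq-quantized-model>
```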