casper-hansen / AutoAWQ

AutoAWQ implements the AWQ algorithm for 4-bit quantization with a 2x speedup during inference. Documentation:
https://casper-hansen.github.io/AutoAWQ/
MIT License

AssertionError: Marlin kernels are not installed. Please install AWQ compatible Marlin kernels from AutoAWQ_kernels. #370

Closed: DonliFly closed this issue 6 months ago

DonliFly commented 7 months ago

Quantization and save_quantized run successfully, but loading the model to generate raises: AssertionError: Marlin kernels are not installed. Please install AWQ compatible Marlin kernels from AutoAWQ_kernels. The relevant code is below.

Code to quantize and save_quantized:

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quantization_config = {"zero_point": False, "q_group_size": 128, "w_bit": 4, "version": "Marlin"}

quant_model = AutoAWQForCausalLM.from_pretrained(ori_model_id, trust_remote_code=False)
tokenizer = AutoTokenizer.from_pretrained(ori_model_id, use_fast=True)
quant_model.quantize(tokenizer, quant_config=quantization_config, calib_data=in_dataset)

quant_model.save_quantized(model_save_path)
tokenizer.save_pretrained(model_save_path)

Code to load the model and generate:

quant_model = AutoAWQForCausalLM.from_quantized(model_id, device_map="auto", fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=False)
inputs = tokenizer("who are you", return_tensors="pt").to(0)
result = quant_model.generate(**inputs, max_new_tokens=512, do_sample=False, num_beams=1, eos_token_id=0)
print("generate_result = ", tokenizer.decode(result[0], skip_special_tokens=True))

Does AutoAWQ not include marlin_cuda? It cannot import marlin_cuda. I have installed:

autoawq=0.2.2 autoawq_kernels=0.0.6 transformers=4.37.2

How can I solve this error?

The full traceback is below:

Traceback (most recent call last):
  File "test_llama_auto_awq.py", line 173, in <module>
    llama_awq_predict(awq_model_save_path, test_list)
  File "test_llama_auto_awq.py", line 116, in llama_awq_predict
    result = quant_model.generate(**inputs, max_new_tokens=512, do_sample=False, num_beams=1, eos_token_id=0)
  File "/opt/conda/lib/python3.8/site-packages/awq/models/base.py", line 104, in generate
    return self.model.generate(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/generation/utils.py", line 1479, in generate
    return self.greedy_search(
  File "/opt/conda/lib/python3.8/site-packages/transformers/generation/utils.py", line 2340, in greedy_search
    outputs = self(
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 1183, in forward
    outputs = self.model(
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 1070, in forward
    layer_outputs = decoder_layer(
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 798, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 693, in forward
    query_states = self.q_proj(hidden_states)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/awq/modules/linear/marlin.py", line 182, in forward
    assert MARLIN_INSTALLED, (
AssertionError: Marlin kernels are not installed. Please install AWQ compatible Marlin kernels from AutoAWQ_kernels.

casper-hansen commented 7 months ago

Did you try to install the Marlin kernels? https://github.com/IST-DASLab/marlin
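If it helps, here is a quick way to check whether a compatible kernel extension is importable in your environment (a rough sketch; the module names probed, marlin_cuda and marlin, are guesses taken from this thread rather than an official AutoAWQ list):

import importlib.util

# Probe the candidate extension module names mentioned in this thread.
for name in ("marlin_cuda", "marlin"):
    spec = importlib.util.find_spec(name)
    print(name, "importable" if spec is not None else "not found")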

yyfcc17 commented 6 months ago

@casper-hansen After installing from the marlin repo, the installed package is named marlin, not marlin_cuda, so the code may need a fix.
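One way the import guard could tolerate both package names (a sketch of the suggested fix, not the actual AutoAWQ code):

# Sketch: set the Marlin flag if either candidate extension module imports.
MARLIN_INSTALLED = False
for _name in ("marlin_cuda", "marlin"):
    try:
        __import__(_name)
        MARLIN_INSTALLED = True
        break
    except ImportError:
        continue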

yyfcc17 commented 6 months ago

Also, wouldn't repacking the weights at load time to be Marlin-compatible be a better choice? That way we would only need one model file and, with different configs, could run on different backends.
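A rough illustration of that idea, using hypothetical helper names rather than the actual AutoAWQ API:

def load_quantized_for_backend(state_dict, backend, repack_fns):
    # repack_fns maps a backend name to a weight-repacking function, e.g.
    # {"marlin": repack_gemm_to_marlin}; the repacking functions themselves are
    # hypothetical. One saved checkpoint, repacked per backend at load time.
    repack = repack_fns.get(backend)
    return repack(state_dict) if repack is not None else state_dict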

DonliFly commented 6 months ago

Did you try to install the Marlin kernels? https://github.com/IST-DASLab/marlin

Thanks very much. I installed the Marlin kernels and the problem is fixed.