intel/intel-extension-for-pytorch

A Python package for extending the official PyTorch that can easily obtain performance gains on Intel platforms
Apache License 2.0

WOQ ValueError: too many values to unpack (expected 2) with "Intel/neural-chat-7b-v3-3" #555

Closed · eduand-alvarez · closed 4 weeks ago

eduand-alvarez commented 6 months ago

Describe the bug

When running this:

import torch
import intel_extension_for_pytorch as ipex
from transformers import AutoTokenizer, AutoModelForCausalLM

# PART 1: Model and tokenizer loading
tokenizer = AutoTokenizer.from_pretrained("Intel/neural-chat-7b-v3-3")
model = AutoModelForCausalLM.from_pretrained("Intel/neural-chat-7b-v3-3")

# PART 2: Preparation of quantization config
qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(
  weight_dtype=torch.qint8, # or torch.quint4x2
  lowp_mode=ipex.quantization.WoqLowpMode.NONE, # or FP16, BF16, INT8
)

checkpoint = None  # optionally load an int4 or int8 checkpoint; must be defined before the optimize call below

# PART 3: Model optimization and quantization
model = ipex.llm.optimize(model, quantization_config=qconfig, low_precision_checkpoint=checkpoint)

inputs = tokenizer("I love learning to code...", return_tensors="pt").to(model.device)

# PART 4: Generation inference loop
with torch.inference_mode():
    tokens = model.generate(
        **inputs,
        max_new_tokens=64,
        temperature=0.70,
        top_p=0.95,
        do_sample=True,
    )

print(tokenizer.decode(tokens[0], skip_special_tokens=True))

I get this error:

Loading checkpoint shards: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:04<00:00, 2.23s/it]
ipex.llm.optimize is doing the weight only quantization

ValueError                                Traceback (most recent call last)
Cell In[19], line 18
     17 # PART 3: Model optimization and quantization
---> 18 model = ipex.llm.optimize(model, quantization_config=qconfig, low_precision_checkpoint=checkpoint)

File ~/miniconda3/envs/presidio/lib/python3.9/site-packages/intel_extension_for_pytorch/transformers/optimize.py:931, in optimize(model, optimizer, dtype, inplace, device, quantization_config, qconfig_summary_file, low_precision_checkpoint, sample_inputs, deployment_mode)
    930 # model lowering conversion
--> 931 _model = model_convert_lowering(
    932     _model,
    933     device,
    934     dtype,
    935     sample_inputs,
    936     deployment_mode,
    937     is_quantization,
    938     is_woq,
    939 )

File ~/miniconda3/envs/presidio/lib/python3.9/site-packages/intel_extension_for_pytorch/transformers/optimize.py:666, in model_convert_lowering(_model, device, dtype, sample_inputs, deployment_mode, is_quantization, woq)
    663 with torch.no_grad(), torch.cpu.amp.autocast(
    664     enabled=True if dtype is torch.bfloat16 else False
    665 ):
--> 666     trace_model = torch.jit.trace(
    667         _model,
    668         example_kwarg_inputs=sample_inputs,
    669         strict=False,
    670         check_trace=False,
    671     )
    672     trace_model = torch.jit.freeze(trace_model)

File ~/miniconda3/envs/presidio/lib/python3.9/site-packages/intel_extension_for_pytorch/jit/_trace.py:69, in jit_trace_wrapper.<locals>.wrapper(*args, **kwargs)
     67     kwargs["check_trace"] = False
---> 69 traced = f(*args, **kwargs)
     70 torch.set_autocast_cache_enabled(prev)
     71 return traced

File ~/miniconda3/envs/presidio/lib/python3.9/site-packages/torch/jit/_trace.py:806, in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_kwarg_inputs, _store_inputs)
--> 806 return trace_module(
    807     func,
    808     {"forward": example_inputs},
    809     None,
    810     check_trace,
    811     wrap_check_inputs(check_inputs),
    812     check_tolerance,
    813     strict,
    814     _force_outplace,
    815     _module_class,
    816     example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
    817     _store_inputs=_store_inputs,
    818 )

File ~/miniconda3/envs/presidio/lib/python3.9/site-packages/torch/jit/_trace.py:1062, in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_inputs_is_kwarg, _store_inputs)
-> 1062 module._c._create_method_from_trace_with_dict(
   1063     method_name,
   1064     func,
   1065     example_inputs,
   1066     var_lookup_fn,
   1067     strict,
   1068     _force_outplace,
   1069     argument_names,
   1070     _store_inputs,
   1071 )

File ~/miniconda3/envs/presidio/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
-> 1511 return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/presidio/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
-> 1520 return forward_call(*args, **kwargs)

File ~/miniconda3/envs/presidio/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._slow_forward(self, *input, **kwargs)
-> 1501 result = self.forward(*input, **kwargs)

File ~/miniconda3/envs/presidio/lib/python3.9/site-packages/intel_extension_for_pytorch/transformers/models/reference/models.py:1333, in MistralForCausalLM_forward(self, input_ids, attention_mask, past_key_values, position_ids, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
-> 1333 outputs = self.model(
   1334     input_ids=input_ids,
   1335     attention_mask=attention_mask,
   1336     position_ids=position_ids,
   1337     past_key_values=past_key_values,
   1338     inputs_embeds=inputs_embeds,
   1339     use_cache=use_cache,
   1340     output_attentions=output_attentions,
   1341     output_hidden_states=output_hidden_states,
   1342     return_dict=False,
   1343 )
   1345 hidden_states = outputs[0]

File ~/miniconda3/envs/presidio/lib/python3.9/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
-> 1511 return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/presidio/lib/python3.9/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
-> 1520 return forward_call(*args, **kwargs)

File ~/miniconda3/envs/presidio/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._slow_forward(self, *input, **kwargs)
-> 1501 result = self.forward(*input, **kwargs)

File ~/miniconda3/envs/presidio/lib/python3.9/site-packages/transformers/models/mistral/modeling_mistral.py:974, in MistralModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
    972 use_legacy_cache = not isinstance(past_key_values, Cache)
    973 if use_legacy_cache:
--> 974     past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    975 past_key_values_length = past_key_values.get_usable_length(seq_length)

File ~/miniconda3/envs/presidio/lib/python3.9/site-packages/transformers/cache_utils.py:167, in DynamicCache.from_legacy_cache(cls, past_key_values)
    165 if past_key_values is not None:
    166     for layer_idx in range(len(past_key_values)):
--> 167         key_states, value_states = past_key_values[layer_idx]
    168         cache.update(key_states, value_states, layer_idx)
    169 return cache

ValueError: too many values to unpack (expected 2)
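For context on the error itself: transformers 4.38 converts a legacy (tuple-based) KV cache by unpacking each per-layer entry as exactly a (key, value) pair, so this ValueError just means the traced IPEX model supplied per-layer cache entries with a different number of elements. A minimal sketch of that failure mode under transformers 4.38 (the tensor shapes and the four-element entry are illustrative assumptions, not the actual IPEX cache layout):

import torch
from transformers.cache_utils import DynamicCache

# from_legacy_cache unpacks each layer entry as: key_states, value_states = entry
ok_cache = ((torch.zeros(1, 8, 4, 64), torch.zeros(1, 8, 4, 64)),)  # one layer, a (key, value) pair
DynamicCache.from_legacy_cache(ok_cache)  # works: exactly two elements per layer

# A hypothetical per-layer entry with extra elements reproduces the reported error
bad_cache = ((torch.zeros(1), torch.zeros(1), torch.zeros(1), torch.zeros(1)),)
try:
    DynamicCache.from_legacy_cache(bad_cache)
except ValueError as err:
    print(err)  # too many values to unpack (expected 2)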

Versions

# packages in environment at /home/sdp/miniconda3/envs/presidio:
#
# Name                    Version                   Build  Channel

_libgcc_mutex 0.1 main
_openmp_mutex 5.1 1_gnu
anyio 4.3.0 pypi_0 pypi
argon2-cffi 23.1.0 pypi_0 pypi
argon2-cffi-bindings 21.2.0 pypi_0 pypi
arrow 1.3.0 pypi_0 pypi
asttokens 2.4.1 pypi_0 pypi
async-lru 2.0.4 pypi_0 pypi
attrs 23.2.0 pypi_0 pypi
babel 2.14.0 pypi_0 pypi
beautifulsoup4 4.12.3 pypi_0 pypi
bleach 6.1.0 pypi_0 pypi
boto3 1.34.58 pypi_0 pypi
botocore 1.34.58 pypi_0 pypi
ca-certificates 2023.12.12 h06a4308_0
certifi 2024.2.2 pypi_0 pypi
cffi 1.16.0 pypi_0 pypi
charset-normalizer 3.3.2 pypi_0 pypi
cloudpickle 2.2.1 pypi_0 pypi
comm 0.2.1 pypi_0 pypi
contextlib2 21.6.0 pypi_0 pypi
debugpy 1.8.1 pypi_0 pypi
decorator 5.1.1 pypi_0 pypi
defusedxml 0.7.1 pypi_0 pypi
dill 0.3.8 pypi_0 pypi
docker 7.0.0 pypi_0 pypi
exceptiongroup 1.2.0 pypi_0 pypi
executing 2.0.1 pypi_0 pypi
fastjsonschema 2.19.1 pypi_0 pypi
filelock 3.13.1 pypi_0 pypi
fqdn 1.5.1 pypi_0 pypi
fsspec 2024.2.0 pypi_0 pypi
google-pasta 0.2.0 pypi_0 pypi
h11 0.14.0 pypi_0 pypi
httpcore 1.0.4 pypi_0 pypi
httpx 0.27.0 pypi_0 pypi
huggingface-hub 0.21.4 pypi_0 pypi
idna 3.6 pypi_0 pypi
importlib-metadata 6.11.0 pypi_0 pypi
intel-extension-for-pytorch 2.2.0 pypi_0 pypi
ipykernel 6.29.3 pypi_0 pypi
ipython 8.18.1 pypi_0 pypi
isoduration 20.11.0 pypi_0 pypi
jedi 0.19.1 pypi_0 pypi
jinja2 3.1.3 pypi_0 pypi
jmespath 1.0.1 pypi_0 pypi
json5 0.9.22 pypi_0 pypi
jsonpointer 2.4 pypi_0 pypi
jsonschema 4.21.1 pypi_0 pypi
jsonschema-specifications 2023.12.1 pypi_0 pypi
jupyter-client 8.6.0 pypi_0 pypi
jupyter-core 5.7.1 pypi_0 pypi
jupyter-events 0.9.0 pypi_0 pypi
jupyter-lsp 2.2.4 pypi_0 pypi
jupyter-server 2.13.0 pypi_0 pypi
jupyter-server-terminals 0.5.2 pypi_0 pypi
jupyterlab 4.1.4 pypi_0 pypi
jupyterlab-pygments 0.3.0 pypi_0 pypi
jupyterlab-server 2.25.3 pypi_0 pypi
ld_impl_linux-64 2.38 h1181459_1
libffi 3.3 he6710b0_2
libgcc-ng 11.2.0 h1234567_1
libgomp 11.2.0 h1234567_1
libstdcxx-ng 11.2.0 h1234567_1
markupsafe 2.1.5 pypi_0 pypi
matplotlib-inline 0.1.6 pypi_0 pypi
mistune 3.0.2 pypi_0 pypi
mpmath 1.3.0 pypi_0 pypi
multiprocess 0.70.16 pypi_0 pypi
nbclient 0.9.0 pypi_0 pypi
nbconvert 7.16.2 pypi_0 pypi
nbformat 5.9.2 pypi_0 pypi
ncurses 6.4 h6a678d5_0
nest-asyncio 1.6.0 pypi_0 pypi
networkx 3.2.1 pypi_0 pypi
notebook-shim 0.2.4 pypi_0 pypi
numpy 1.26.4 pypi_0 pypi
nvidia-cublas-cu12 12.1.3.1 pypi_0 pypi
nvidia-cuda-cupti-cu12 12.1.105 pypi_0 pypi
nvidia-cuda-nvrtc-cu12 12.1.105 pypi_0 pypi
nvidia-cuda-runtime-cu12 12.1.105 pypi_0 pypi
nvidia-cudnn-cu12 8.9.2.26 pypi_0 pypi
nvidia-cufft-cu12 11.0.2.54 pypi_0 pypi
nvidia-curand-cu12 10.3.2.106 pypi_0 pypi
nvidia-cusolver-cu12 11.4.5.107 pypi_0 pypi
nvidia-cusparse-cu12 12.1.0.106 pypi_0 pypi
nvidia-nccl-cu12 2.19.3 pypi_0 pypi
nvidia-nvjitlink-cu12 12.4.99 pypi_0 pypi
nvidia-nvtx-cu12 12.1.105 pypi_0 pypi
openssl 1.1.1w h7f8727e_0
overrides 7.7.0 pypi_0 pypi
packaging 23.2 pypi_0 pypi
pandas 2.2.1 pypi_0 pypi
pandocfilters 1.5.1 pypi_0 pypi
parso 0.8.3 pypi_0 pypi
pathos 0.3.2 pypi_0 pypi
pexpect 4.9.0 pypi_0 pypi
pip 23.3.1 py39h06a4308_0
platformdirs 4.2.0 pypi_0 pypi
pox 0.3.4 pypi_0 pypi
ppft 1.7.6.8 pypi_0 pypi
prometheus-client 0.20.0 pypi_0 pypi
prompt-toolkit 3.0.43 pypi_0 pypi
protobuf 4.25.3 pypi_0 pypi
psutil 5.9.8 pypi_0 pypi
ptyprocess 0.7.0 pypi_0 pypi
pure-eval 0.2.2 pypi_0 pypi
pycparser 2.21 pypi_0 pypi
pygments 2.17.2 pypi_0 pypi
python 3.9.0 hdb3f193_2
python-dateutil 2.9.0.post0 pypi_0 pypi
python-json-logger 2.0.7 pypi_0 pypi
pytz 2024.1 pypi_0 pypi
pyyaml 6.0.1 pypi_0 pypi
pyzmq 25.1.2 pypi_0 pypi
readline 8.2 h5eee18b_0
referencing 0.33.0 pypi_0 pypi
regex 2023.12.25 pypi_0 pypi
requests 2.31.0 pypi_0 pypi
rfc3339-validator 0.1.4 pypi_0 pypi
rfc3986-validator 0.1.1 pypi_0 pypi
rpds-py 0.18.0 pypi_0 pypi
s3transfer 0.10.0 pypi_0 pypi
safetensors 0.4.2 pypi_0 pypi
sagemaker 2.212.0 pypi_0 pypi
schema 0.7.5 pypi_0 pypi
send2trash 1.8.2 pypi_0 pypi
setuptools 68.2.2 py39h06a4308_0
six 1.16.0 pypi_0 pypi
smdebug-rulesconfig 1.0.1 pypi_0 pypi
sniffio 1.3.1 pypi_0 pypi
soupsieve 2.5 pypi_0 pypi
sqlite 3.41.2 h5eee18b_0
stack-data 0.6.3 pypi_0 pypi
sympy 1.12 pypi_0 pypi
tblib 2.0.0 pypi_0 pypi
terminado 0.18.0 pypi_0 pypi
tinycss2 1.2.1 pypi_0 pypi
tk 8.6.12 h1ccaba5_0
tokenizers 0.15.2 pypi_0 pypi
tomli 2.0.1 pypi_0 pypi
torch 2.2.0 pypi_0 pypi
tornado 6.4 pypi_0 pypi
tqdm 4.66.2 pypi_0 pypi
traitlets 5.14.1 pypi_0 pypi
transformers 4.38.2 pypi_0 pypi
triton 2.2.0 pypi_0 pypi
types-python-dateutil 2.8.19.20240106 pypi_0 pypi
typing-extensions 4.10.0 pypi_0 pypi
tzdata 2024.1 pypi_0 pypi
uri-template 1.3.0 pypi_0 pypi
urllib3 1.26.18 pypi_0 pypi
wcwidth 0.2.13 pypi_0 pypi
webcolors 1.13 pypi_0 pypi
webencodings 0.5.1 pypi_0 pypi
websocket-client 1.7.0 pypi_0 pypi
wheel 0.41.2 py39h06a4308_0
xz 5.4.6 h5eee18b_0
zipp 3.17.0 pypi_0 pypi
zlib 1.2.13 h5eee18b_0

jgong5 commented 6 months ago

@Xia-Weiwen Can you take a look?

Xia-Weiwen commented 6 months ago

Hi @eduand-alvarez. Looks like you were using IPEX 2.2 and transformers 4.38.2. However, IPEX 2.2 requires transformers==4.35.2. Could you please downgrade transformers and retry? Please find dependencies here: https://github.com/intel/intel-extension-for-pytorch/blob/release/2.2/dependency_version.yml
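For anyone following along, checking the installed versions against that dependency file is straightforward; a minimal sketch (the transformers==4.35.2 pin comes from the comment above, torch/IPEX 2.2.0 are the versions already reported in this environment, so confirm the exact pins in dependency_version.yml for your release):

import torch
import transformers
import intel_extension_for_pytorch as ipex

# Print the currently installed versions to compare against dependency_version.yml
print("torch:", torch.__version__)
print("intel-extension-for-pytorch:", ipex.__version__)
print("transformers:", transformers.__version__)

# If transformers does not match the pinned version, downgrade it, e.g.:
#   pip install transformers==4.35.2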

srinarayan-srikanthan commented 4 weeks ago

Hi @eduand-alvarez, did changing to the suggested versions resolve the issue? Can we close this ticket?

srinarayan-srikanthan commented 4 weeks ago

Closing, as the issue was resolved by updating the version.