yuanzhoulvpi2017 / zero_nlp

Chinese NLP solutions (large models, data, models, training, inference)
MIT License
2.85k stars 355 forks

After quantize(8), the quantized base model and the LoRA checkpoints don't match and raise RuntimeError: self and mat2 must have the same dtype. How can I make them consistent? #101

Open: alexhmyang opened this issue 1 year ago

alexhmyang commented 1 year ago

https://github.com/yuanzhoulvpi2017/zero_nlp/blob/main/simple_thu_chatglm6b/infer.ipynb


 E:\proj2022\chatGLM\v2\ChatGLM-6B-main\ChatGLM-6B-main\lora_infer.py:27 in <module>              │
│                                                                                                  │
│   24 text ="为什么冰红茶和柠檬茶的味道一样?"                                                    │
│   25                                                                                             │
│   26 with torch.autocast("cuda"):                                                                │
│ ❱ 27 │   res, history = model.chat(tokenizer=tokenizer, query=text,max_length=300)               │
│   28 │   print(res)                                                                              │
│   29                                                                                             │
│   30                                                                                             │
│                                                                                                  │
│ C:\Users\86187\.conda\envs\torch\lib\site-packages\torch\utils\_contextlib.py:115 in             │
│ decorate_context                                                                                 │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ C:\Users\86187/.cache\huggingface\modules\transformers_modules\model\modeling_chatglm.py:1255 in │
│ chat                                                                                             │
│                                                                                                  │
│   1252 │   │   │   prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)              │
│   1253 │   │   inputs = tokenizer([prompt], return_tensors="pt")                                 │
│   1254 │   │   inputs = inputs.to(self.device)                                                   │
│ ❱ 1255 │   │   outputs = self.generate(**inputs, **gen_kwargs)                                   │
│   1256 │   │   outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]                       │
│   1257 │   │   response = tokenizer.decode(outputs)                                              │
│   1258 │   │   response = self.process_response(response)                                        │
│                                                                                                  │
│ C:\Users\86187\.conda\envs\torch\lib\site-packages\torch\utils\_contextlib.py:115 in             │
│ decorate_context                                                                                 │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ C:\Users\86187\AppData\Roaming\Python\Python39\site-packages\transformers\generation\utils.py:1452 in generate │
│                                                                                                  │
│   1449 │   │   │   )                                                                             │
│   1450 │   │   │                                                                                 │
│   1451 │   │   │   # 13. run sample                                                              │
│ ❱ 1452 │   │   │   return self.sample(                                                           │
│   1453 │   │   │   │   input_ids,                                                                │
│   1454 │   │   │   │   logits_processor=logits_processor,                                        │
│   1455 │   │   │   │   logits_warper=logits_warper,                                              │
│                                                                                                  │
│ C:\Users\86187\AppData\Roaming\Python\Python39\site-packages\transformers\generation\utils.py:2468 in sample │
│                                                                                                  │
│   2465 │   │   │   model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)  │
│   2466 │   │   │                                                                                 │
│   2467 │   │   │   # forward pass to get next token                                              │
│ ❱ 2468 │   │   │   outputs = self(                                                               │
│   2469 │   │   │   │   **model_inputs,                                                           │
│   2470 │   │   │   │   return_dict=True,                                                         │
│   2471 │   │   │   │   output_attentions=output_attentions,                                      │
│                                                                                                  │
│ C:\Users\86187\.conda\envs\torch\lib\site-packages\torch\nn\modules\module.py:1501 in _call_impl │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ C:\Users\86187/.cache\huggingface\modules\transformers_modules\model\modeling_chatglm.py:1160 in │
│ forward                                                                                          │
│                                                                                                  │
│   1157 │   │   use_cache = use_cache if use_cache is not None else self.config.use_cache         │
│   1158 │   │   return_dict = return_dict if return_dict is not None else self.config.use_return  │
│   1159 │   │                                                                                     │
│ ❱ 1160 │   │   transformer_outputs = self.transformer(                                           │
│   1161 │   │   │   input_ids=input_ids,                                                          │
│   1162 │   │   │   position_ids=position_ids,                                                    │
│   1163 │   │   │   attention_mask=attention_mask,                                                │
│                                                                                                  │
│ C:\Users\86187\.conda\envs\torch\lib\site-packages\torch\nn\modules\module.py:1501 in _call_impl │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ C:\Users\86187/.cache\huggingface\modules\transformers_modules\model\modeling_chatglm.py:973 in  │
│ forward                                                                                          │
│                                                                                                  │
│    970 │   │   │   │   │   output_attentions                                                     │
│    971 │   │   │   │   )                                                                         │
│    972 │   │   │   else:                                                                         │
│ ❱  973 │   │   │   │   layer_ret = layer(                                                        │
│    974 │   │   │   │   │   hidden_states,                                                        │
│    975 │   │   │   │   │   position_ids=position_ids,                                            │
│    976 │   │   │   │   │   attention_mask=attention_mask,                                        │
│                                                                                                  │
│ C:\Users\86187\.conda\envs\torch\lib\site-packages\torch\nn\modules\module.py:1501 in _call_impl │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ C:\Users\86187/.cache\huggingface\modules\transformers_modules\model\modeling_chatglm.py:614 in  │
│ forward                                                                                          │
│                                                                                                  │
│    611 │   │   attention_input = self.input_layernorm(hidden_states)                             │
│    612 │   │                                                                                     │
│    613 │   │   # Self attention.                                                                 │
│ ❱  614 │   │   attention_outputs = self.attention(                                               │
│    615 │   │   │   attention_input,                                                              │
│    616 │   │   │   position_ids,                                                                 │
│    617 │   │   │   attention_mask=attention_mask,                                                │
│                                                                                                  │
│ C:\Users\86187\.conda\envs\torch\lib\site-packages\torch\nn\modules\module.py:1501 in _call_impl │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ C:\Users\86187/.cache\huggingface\modules\transformers_modules\model\modeling_chatglm.py:439 in  │
│ forward                                                                                          │
│                                                                                                  │
│    436 │   │   """                                                                               │
│    437 │   │                                                                                     │
│    438 │   │   # [seq_len, batch, 3 * hidden_size]                                               │
│ ❱  439 │   │   mixed_raw_layer = self.query_key_value(hidden_states)                             │
│    440 │   │                                                                                     │
│    441 │   │   # [seq_len, batch, 3 * hidden_size] --> [seq_len, batch, num_attention_heads, 3   │
│    442 │   │   new_tensor_shape = mixed_raw_layer.size()[:-1] + (                                │
│                                                                                                  │
│ C:\Users\86187\.conda\envs\torch\lib\site-packages\torch\nn\modules\module.py:1501 in _call_impl │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ C:\Users\86187\.conda\envs\torch\lib\site-packages\peft\tuners\lora.py:454 in forward            │
│                                                                                                  │
│   451 │   │   elif self.merged:                                                                  │
│   452 │   │   │   return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bi   │
│   453 │   │   else:                                                                              │
│ ❱ 454 │   │   │   result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.   │
│   455 │   │   │   if self.r > 0:                                                                 │
│   456 │   │   │   │   after_A = self.lora_A(self.lora_dropout(x))                                │
│   457 │   │   │   │   after_B = self.lora_B(after_A.transpose(-2, -1)).transpose(-2, -1)         │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: self and mat2 must have the same dtype
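The failing frame is peft's LoRA `forward`, where `F.linear(x, transpose(self.weight, ...))` receives an activation `x` and a weight whose dtypes differ. The snippet below is a minimal sketch of that error class and of one common remedy, casting the activation to the weight's dtype before the matmul. It is not a patch for ChatGLM or peft: float32 and float64 are assumed stand-ins for the real FP16/INT8 mix so that it runs on CPU with any recent PyTorch, the shapes are arbitrary, and the exact error message text varies between PyTorch versions.

```python
import torch
import torch.nn.functional as F

# Stand-ins for the mismatch in the traceback: the activation and the layer
# weight end up with different dtypes. float32/float64 are used instead of
# float16/int8 so the example runs on CPU in any PyTorch build (assumption).
x = torch.randn(2, 4, dtype=torch.float32)       # activation
weight = torch.randn(3, 4, dtype=torch.float64)  # weight (out_features, in_features)

# F.linear does not promote dtypes, so mixing them raises RuntimeError.
try:
    F.linear(x, weight)
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print("RuntimeError:", err)

# Casting the activation to the weight's dtype before the matmul avoids it.
out = F.linear(x.to(weight.dtype), weight)
print(out.dtype, tuple(out.shape))
```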

lingxide commented 1 year ago

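I ran into the same problem. A LoRA trained at FP16 precision can't be loaded onto the INT8 model. Is there a way to load it, or how can I train the LoRA with the model at INT8 precision?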
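One way to think about loading an FP16-trained LoRA onto a lower-precision base is that the frozen base weight and the trainable adapters need not share a dtype, as long as activations are cast at each boundary. The sketch below illustrates that idea with a hypothetical `DtypeSafeLoRALinear` module computing `W x + (alpha / r) * B(A(x))`; it is not peft's or ChatGLM's API. As an assumption made so it runs on CPU, float64 stands in for the quantized base weight and float32 for the adapter dtype; a real INT8 base layer would instead go through its own (de)quantizing kernel.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DtypeSafeLoRALinear(nn.Module):
    """Hypothetical LoRA layer: the frozen base weight and the trainable
    adapters may live in different dtypes; activations are cast at each
    boundary so every matmul sees matching dtypes."""

    def __init__(self, base_weight: torch.Tensor, r: int = 4, alpha: int = 8,
                 adapter_dtype: torch.dtype = torch.float32):
        super().__init__()
        out_features, in_features = base_weight.shape
        # Frozen base weight kept in its own dtype (stand-in for the quantized model).
        self.register_buffer("base_weight", base_weight)
        # Low-rank adapters A (r x in) and B (out x r), stored in adapter_dtype.
        self.lora_A = nn.Linear(in_features, r, bias=False, dtype=adapter_dtype)
        self.lora_B = nn.Linear(r, out_features, bias=False, dtype=adapter_dtype)
        nn.init.zeros_(self.lora_B.weight)  # standard LoRA init: update starts at zero
        self.scaling = alpha / r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Base path: cast the activation to the base weight's dtype.
        result = F.linear(x.to(self.base_weight.dtype), self.base_weight)
        # Adapter path: cast the activation to the adapter dtype.
        update = self.lora_B(self.lora_A(x.to(self.lora_A.weight.dtype))) * self.scaling
        # Bring the LoRA update back to the base path's dtype before adding.
        return result + update.to(result.dtype)


torch.manual_seed(0)
base = torch.randn(3, 4, dtype=torch.float64)  # stand-in frozen base weight
layer = DtypeSafeLoRALinear(base, r=2, alpha=4)
x = torch.randn(5, 4)                          # float32 activation
y = layer(x)                                   # no dtype error
print(y.dtype, tuple(y.shape))
```

Because `lora_B` starts at zero, the layer initially reproduces the base projection exactly; the adapters only change the output once they are trained or loaded.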