THUDM / ChatGLM-6B

ChatGLM-6B: An Open Bilingual Dialogue Language Model
Apache License 2.0

[BUG/Help] Can't run inference from "inputs_embeds" #1428

Open ywugwu opened 7 months ago

ywugwu commented 7 months ago

Is there an existing issue for this?

Current Behavior

Running a forward pass from embeddings:

input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors='pt').cuda()
embeddings = model.get_input_embeddings()(input_ids)
model(inputs_embeds=embeddings)

fails with this error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-15-8dbf228dd452> in <cell line: 5>()
      3 # tokenized_text = tokenizer.convert_ids_to_tokens(input_ids[0])
      4 embeddings = model.get_input_embeddings()(input_ids)
----> 5 model(inputs_embeds=embeddings)
      6 # model(input_ids=input_ids)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1525                 or _global_backward_pre_hooks or _global_backward_hooks
   1526                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527             return forward_call(*args, **kwargs)
   1528 
   1529         try:

/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py in new_forward(module, *args, **kwargs)
    162                 output = module._old_forward(*args, **kwargs)
    163         else:
--> 164             output = module._old_forward(*args, **kwargs)
    165         return module._hf_hook.post_forward(module, output)
    166 

~/.cache/huggingface/modules/transformers_modules/THUDM/chatglm2-6b/7fabe56db91e085c9c027f56f1c654d137bdba40/modeling_chatglm.py in forward(self, input_ids, position_ids, attention_mask, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, return_last_logit)
    935         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    936 
--> 937         transformer_outputs = self.transformer(
    938             input_ids=input_ids,
    939             position_ids=position_ids,

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1525                 or _global_backward_pre_hooks or _global_backward_hooks
   1526                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527             return forward_call(*args, **kwargs)
   1528 
   1529         try:

/usr/local/lib/python3.10/dist-packages/accelerate/hooks.py in new_forward(module, *args, **kwargs)
    162                 output = module._old_forward(*args, **kwargs)
    163         else:
--> 164             output = module._old_forward(*args, **kwargs)
    165         return module._hf_hook.post_forward(module, output)
    166 

~/.cache/huggingface/modules/transformers_modules/THUDM/chatglm2-6b/7fabe56db91e085c9c027f56f1c654d137bdba40/modeling_chatglm.py in forward(self, input_ids, position_ids, attention_mask, full_attention_mask, past_key_values, inputs_embeds, use_cache, output_hidden_states, return_dict)
    802         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    803 
--> 804         batch_size, seq_length = input_ids.shape
    805 
    806         if inputs_embeds is None:

AttributeError: 'NoneType' object has no attribute 'shape'
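The last frame shows the cause: `forward` reads `input_ids.shape` at line 804, before the `if inputs_embeds is None:` check at line 806, so a call that passes only `inputs_embeds` dereferences `None`. The sketch below reproduces that control flow in plain Python and contrasts it with a guarded version that takes the shape from whichever input is present. The helper names (`unguarded_shape`, `guarded_shape`) and the `seq_first` flag are hypothetical, and the stand-in objects only carry a `.shape` attribute; this is not the model's actual code.

```python
from types import SimpleNamespace
from typing import Any, Optional, Tuple


def unguarded_shape(input_ids: Optional[Any] = None,
                    inputs_embeds: Optional[Any] = None) -> Tuple[int, int]:
    # Mirrors the failing pattern: the shape is always read from input_ids,
    # even when only inputs_embeds was passed.
    batch_size, seq_length = input_ids.shape
    return batch_size, seq_length


def guarded_shape(input_ids: Optional[Any] = None,
                  inputs_embeds: Optional[Any] = None,
                  seq_first: bool = False) -> Tuple[int, int]:
    # Hypothetical fix: prefer input_ids, otherwise fall back to inputs_embeds.
    if input_ids is not None:
        batch_size, seq_length = input_ids.shape  # [batch, seq]
    elif inputs_embeds is not None:
        if seq_first:
            seq_length, batch_size, _ = inputs_embeds.shape  # [seq, batch, hidden]
        else:
            batch_size, seq_length, _ = inputs_embeds.shape  # [batch, seq, hidden]
    else:
        raise ValueError("Specify either input_ids or inputs_embeds")
    return batch_size, seq_length


# Stand-ins: only the shapes matter here; the hidden size 8 is arbitrary.
ids = SimpleNamespace(shape=(1, 7))
embeds = SimpleNamespace(shape=(1, 7, 8))

try:
    unguarded_shape(inputs_embeds=embeds)
except AttributeError as err:
    print(err)  # 'NoneType' object has no attribute 'shape'

print(guarded_shape(input_ids=ids))         # (1, 7)
print(guarded_shape(inputs_embeds=embeds))  # (1, 7)
```

The `seq_first` flag exists because a model may store embeddings sequence-first internally; which layout applies has to be checked against the `modeling_chatglm.py` revision in use.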

Forwarding with input IDs works fine, for example:

input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors='pt').cuda()
embeddings = model.get_input_embeddings()(input_ids)
# model(inputs_embeds=embeddings)
model(input_ids=input_ids)

Expected Behavior

Forwarding with input IDs or with the corresponding embeddings should behave the same.
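The expectation rests on the fact that an embedding layer is a table lookup: row `i` of the weight matrix is the vector for token ID `i`, so precomputed embeddings carry the same information as the IDs. A toy, pure-Python sketch of that lookup and of a forward function that accepts either input (the table values and `toy_forward` are made up and unrelated to ChatGLM2's weights or code):

```python
from typing import List, Optional

# Toy embedding table: vocabulary of 4 tokens, hidden size 3 (made-up values).
EMBEDDING_TABLE: List[List[float]] = [
    [0.0, 0.1, 0.2],
    [1.0, 1.1, 1.2],
    [2.0, 2.1, 2.2],
    [3.0, 3.1, 3.2],
]


def embed(input_ids: List[List[int]]) -> List[List[List[float]]]:
    """Look up one table row per token ID, giving a [batch, seq, hidden] nested list."""
    return [[EMBEDDING_TABLE[token_id] for token_id in row] for row in input_ids]


def toy_forward(input_ids: Optional[List[List[int]]] = None,
                inputs_embeds: Optional[List[List[List[float]]]] = None) -> List[List[float]]:
    # Embed only when the caller did not supply embeddings.
    if inputs_embeds is None:
        inputs_embeds = embed(input_ids)
    # Toy "model": sum the hidden features of each token.
    return [[sum(vec) for vec in seq] for seq in inputs_embeds]


ids = [[2, 0, 3]]
print(embed(ids))  # [[[2.0, 2.1, 2.2], [0.0, 0.1, 0.2], [3.0, 3.1, 3.2]]]
print(toy_forward(input_ids=ids) == toy_forward(inputs_embeds=embed(ids)))  # True
```

Both paths run the same arithmetic on the same vectors, which is why they agree; a real model should agree as well, provided the embeddings are passed in the layout its layers expect.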

Steps To Reproduce

Run the code:

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True, device_map="auto", load_in_8bit=True)
model = model.eval()

text = "1+1=?"
input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors='pt').cuda()
embeddings = model.get_input_embeddings()(input_ids)
model(inputs_embeds=embeddings)  # raises AttributeError: 'NoneType' object has no attribute 'shape'
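One possible caller-side workaround, untested here and assumed only from the traceback above, is to pass `input_ids` alongside `inputs_embeds`: line 804 only needs `input_ids.shape`, and line 806 only computes embeddings itself when `inputs_embeds is None`. The plain-Python sketch below models just that control flow; `fake_forward` and `call_with_embeds` are hypothetical stand-ins, not ChatGLM2 code. With the real model, also check which layout `modeling_chatglm.py` expects for `inputs_embeds`, since it may differ from the `[batch, seq, hidden]` output of `get_input_embeddings()`.

```python
from types import SimpleNamespace
from typing import Any, Callable, Optional


def fake_forward(input_ids: Optional[Any] = None,
                 inputs_embeds: Optional[Any] = None) -> str:
    # Stand-in for the control flow in the traceback (lines 804-806).
    batch_size, seq_length = input_ids.shape
    if inputs_embeds is None:
        return f"embedded ids internally: batch={batch_size}, seq={seq_length}"
    return f"used caller embeddings: batch={batch_size}, seq={seq_length}"


def call_with_embeds(forward: Callable[..., str], input_ids: Any, inputs_embeds: Any) -> str:
    # Hypothetical workaround: pass the ids so the shape read succeeds;
    # in this stand-in, the caller-supplied embeddings are what gets used.
    return forward(input_ids=input_ids, inputs_embeds=inputs_embeds)


ids = SimpleNamespace(shape=(1, 7))
embeds = SimpleNamespace(shape=(1, 7, 8))
print(call_with_embeds(fake_forward, ids, embeds))
# used caller embeddings: batch=1, seq=7
```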

Environment

Google Colab

Anything else?

No response

zhi-xuan-chen commented 6 months ago

I have the same issue. Did you solve it?