huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
Apache License 2.0

GPTNeoX position_ids not defined #22758

Closed murthyrudra closed 1 year ago

murthyrudra commented 1 year ago

System Info

Who can help?

@ArthurZucker @stas00

Hi, I am running inference with the GPT-NeoX 20B model using greedy search. Without DeepSpeed, text generation works fine. However, when I use DeepSpeed for inference, I get the following error:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮                                                                                                                
│ ~/examplesinference/ in                                                     │                                                                                                                
│ <module>                                                                                         │                                                                                                                
│                                                                                                  │                                                                                                                
│   294                                                                                            │                                                                                                                
│   295                                                                                            │                                                                                                                
│   296 if __name__ == "__main__":                                                                 │                                                                                                                
│ ❱ 297 │   main()                                                                                 │                                                                                                                
│   298                                                                                            │                                                                                                                
│                                                                                                  │                                                                                                                
│ ~/examplesinference/ in main                                                │                                                                                                                
│                                                                                                  │                                                                                                                
│   268 │   │   │   │   + "\nQ: "                                                                  │                                                                                                                
│   269 │   │   │   )                                                                              │                                                                                                                
│   270 │   │   new_prompt = prompt + d["question"] + "\nA:"                                       │                                                                                                                
│ ❱ 271 │   │   output = predict_text_greedy(                                                      │                                                                                                                
│   272 │   │   │   model,                                                                         │                                                                                                                
│   273 │   │   │   tokenizer,                                                                     │                                                                                                                
│   274 │   │   │   new_prompt,                                                                    │                                                                                                                
│                                                                                                  │                                                                                                                
│ ~/examplesinference/ in                                                      │                                                                                                                
│ predict_text_greedy                                                                              │                                                                                                                
│                                                                                                  │                                                                                                                
│    95 │                                                                                          │                                                                                                                
│    96 │   model.eval()                                                                           │                                                                                                                
│    97 │   with torch.no_grad():                                                                  │                                                                                                                
│ ❱  98 │   │   generated_ids = model.generate(                                                    │
│    99 │   │   │   input_ids,                                                                     │                                                                                                                
│   100 │   │   │   max_new_tokens=50,                                                             │                                                                                                                
│   101 │   │   │   use_cache=use_cache,                                                           │                                                                                                                
│                                                                                                  │                                                                                                                
│ ~/my_envlib/python3.9/site-packages/deepspeed/inference/ in                         │                                                                                                                
│ _generate                                                                                        │                                                                                                                
│                                                                                                  │                                                                                                                
│   585 │   │   │   │   "add your request to:   │                                                                                                                
│   586 │   │   │   )                                                                              │                                                                                                                
│   587 │   │                                                                                      │                                                                                                                
│ ❱ 588 │   │   return self.module.generate(*inputs, **kwargs)                                     │                                                                                                                
│   589                                                                                            │                                                                                                                
│                                                                                                  │                                                                                                                
│ ~/my_envlib/python3.9/site-packages/torch/utils/ in                            │                                                                                                                
│ decorate_context                                                                                 │                                                                                                                
│                                                                                                  │                                                                                                                
│   112 │   @functools.wraps(func)                                                                 │                                                                                                                
│   113 │   def decorate_context(*args, **kwargs):                                                 │                                                                                                                
│   114 │   │   with ctx_factory():                                                                │                                                                                                                
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │                                                                                                                
│   116 │                                                                                          │                                                                                                                
│   117 │   return decorate_context                                                                │                                                                                                                
│   118                                                                                            │                                                                                                                
│                                                                                                  │                                                                                                                
│ ~/my_envlib/python3.9/site-packages/transformers/generation/ in                     │                                                                                                                
│ generate                                                                                         │                                                                                                                
│                                                                                                  │                                                                                                                
│   1434 │   │   │   │   )                                                                         │                                                                                                                
│   1435 │   │   │                                                                                 │                                                                                                                
│   1436 │   │   │   # 11. run greedy search                                                       │                                                                                                                
│ ❱ 1437 │   │   │   return self.greedy_search(                                                    │                                                                                                                
│   1438 │   │   │   │   input_ids,                                                                │                                                                                                                
│   1439 │   │   │   │   logits_processor=logits_processor,                                        │                                                                                                                
│   1440 │   │   │   │   stopping_criteria=stopping_criteria,                                      │                                                                                                                
│                                                                                                  │                                                                                                                
│ ~/my_envlib/python3.9/site-packages/transformers/generation/ in                     │                                                                                                                
│ greedy_search                                                                                    │                                                                                                                
│                                                                                                  │                                                                                                                
│   2245 │   │   │   model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)  │                                                                                                                
│   2246 │   │   │                                                                                 │                                                                                                                
│   2247 │   │   │   # forward pass to get next token                                              │                                                                                                                
│ ❱ 2248 │   │   │   outputs = self(                                                               │                                                                                                                
│   2249 │   │   │   │   **model_inputs,                                                           │                                                                                                                
│   2250 │   │   │   │   return_dict=True,                                                         │                                                                                                                
│   2251 │   │   │   │   output_attentions=output_attentions,                                      │                                                                                                                
│                                                                                                  │                                                                                                                
│ ~/my_envlib/python3.9/site-packages/torch/nn/modules/ in                           │                                                                                                                
│ _call_impl                                                                                       │                                                                                                                
│                                                                                                  │                                                                                                                
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │                                                                                                                
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │                                                                                                                
│   1502 │   │   # Do not call functions when jit is used                                          │                                                                                                                
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │                                                                                                                
│   1504 │   │   backward_pre_hooks = []                                                           │                                                                                                                
│                                                                                                  │                                                                                                                
│ ~/my_envlib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gp                     │                                                                                                                
│ in forward                                                                         │                                                                                                                
│                                                                                                  │                                                                                                                
│   659 │   │   ```"""                                                                             │                                                                                                                
│   660 │   │   return_dict = return_dict if return_dict is not None else self.config.use_return   │                                                                                                                
│   661 │   │                                                                                      │                                                                                                                
│ ❱ 662 │   │   outputs = self.gpt_neox(                                                           │
│   663 │   │   │   input_ids,                                                                     │
│   664 │   │   │   attention_mask=attention_mask,                                                 │
│   665 │   │   │   position_ids=position_ids,                                                     │
│                                                                                                  │
│ ~/my_envlib/python3.9/site-packages/torch/nn/modules/ in                           │
│ _call_impl                                                                                       │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ ~/my_envlib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gp                     │
│ in forward                                                                         │
│                                                                                                  │
│   550 │   │   │   │   │   head_mask[i],                                                          │
│   551 │   │   │   │   )                                                                          │
│   552 │   │   │   else:                                                                          │
│ ❱ 553 │   │   │   │   outputs = layer(                                                           │
│   554 │   │   │   │   │   hidden_states,                                                         │
│   555 │   │   │   │   │   attention_mask=attention_mask,                                         │
│   556 │   │   │   │   │   position_ids=position_ids,                                             │
│                                                                                                  │
│ ~/my_envlib/python3.9/site-packages/torch/nn/modules/ in                           │
│ _call_impl                                                                                       │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
TypeError: forward() got an unexpected keyword argument 'position_ids'
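
The last frames of the trace show each layer being called with `position_ids=position_ids`, followed by a `TypeError` from a `forward()` that does not declare that keyword. The following is a minimal sketch of that failure pattern in plain Python, not DeepSpeed's or transformers' actual code; `OriginalLayer`, `InjectedLayer`, `run_layers`, and `try_run` are hypothetical names used only for illustration.

```python
class OriginalLayer:
    """Stand-in for a layer whose forward accepts position_ids."""

    def forward(self, hidden_states, attention_mask=None, position_ids=None):
        return hidden_states


class InjectedLayer:
    """Hypothetical replacement layer with a narrower forward signature."""

    def forward(self, hidden_states, attention_mask=None):
        return hidden_states


def run_layers(layers, hidden_states):
    # Mirrors the call pattern visible in the traceback: every layer is
    # called with position_ids passed as a keyword argument.
    for layer in layers:
        hidden_states = layer.forward(
            hidden_states, attention_mask=None, position_ids=[0, 1, 2]
        )
    return hidden_states


def try_run(layers):
    # Returns "ok" on success, or the TypeError message on failure.
    try:
        run_layers(layers, [1.0, 2.0, 3.0])
        return "ok"
    except TypeError as err:
        return str(err)


print(try_run([OriginalLayer()]))
print(try_run([InjectedLayer()]))
```

The first call succeeds, while the second fails with an "unexpected keyword argument 'position_ids'" message, which matches the shape of the error above. Whether the replaced layers in this report actually behave this way is an assumption drawn from the traceback.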

This is how I am wrapping DeepSpeed around the model:

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

tokenizer.padding_side = "left"

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

reduced_model_name = model_name.split("/")[-1]

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = deepspeed.init_inference(
    model, mp_size=world_size, dtype=torch.float32, replace_with_kernel_inject=True
)

Reproduction

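import os

import deepspeed
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "EleutherAI/gpt-neox-20b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

tokenizer.padding_side = "left"

# Fall back to the EOS token when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

reduced_model_name = model_name.split("/")[-1]

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# world_size is not defined in the original snippet; reading it from the
# WORLD_SIZE environment variable is an assumption.
world_size = int(os.getenv("WORLD_SIZE", "1"))

# Wrap the model in DeepSpeed's inference engine with kernel injection enabled.
model = deepspeed.init_inference(
    model, mp_size=world_size, dtype=torch.float32, replace_with_kernel_inject=True
)

# The lines below were truncated in the original post and are reconstructed;
# the generate() arguments follow the call shown in the traceback.
input_ids = tokenizer(
    "The quick brown fox jumped over the lazy dog", return_tensors="pt"
).input_ids.to(dtype=torch.long, device=device)

with torch.no_grad():
    generated_ids = model.generate(input_ids, max_new_tokens=50)
    preds = [
        tokenizer.decode(
            g, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        for g in generated_ids
    ]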

Expected behavior

Generation should behave the same whether or not the model is wrapped with DeepSpeed.
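
One way to narrow down where the wrapped and unwrapped models differ is to check which callables accept a given keyword before calling them. The sketch below uses only the standard library's `inspect.signature`; `accepts_keyword` is a hypothetical helper, not part of transformers or DeepSpeed.

```python
import inspect


def accepts_keyword(func, name):
    """Return True if func can be called with `name` as a keyword argument."""
    params = inspect.signature(func).parameters
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    # An explicitly declared parameter that may be passed by keyword.
    if name in params and params[name].kind in keyword_kinds:
        return True
    # A **kwargs catch-all accepts any keyword.
    return any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )


def with_position_ids(hidden_states, attention_mask=None, position_ids=None):
    return hidden_states


def without_position_ids(hidden_states, attention_mask=None):
    return hidden_states


print(accepts_keyword(with_position_ids, "position_ids"))
print(accepts_keyword(without_position_ids, "position_ids"))
```

Applied to the `forward` method of each transformer layer before and after wrapping, a check like this would show whether the wrapped layers still declare `position_ids`. That usage is an assumption about how one might debug this report, not a documented procedure.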

stas00 commented 1 year ago

transformers isn't involved with deepspeed's inference engine, other than being used by it indirectly, so please refile this issue with the DeepSpeed project. Thank you.