huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Does .generate() support contrastive search on multiple devices? #24634

Closed: pfldy2850 closed this issue 1 year ago

pfldy2850 commented 1 year ago

System Info

Script

import torch
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM

checkpoint = "EleutherAI/polyglot-ko-12.8b"
tokenizer = AutoTokenizer.from_pretrained(
    checkpoint,
    padding_side="left",
    pad_token_id=0,
)
model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    torch_dtype=torch.float16,
    device_map="auto", 
    load_in_8bit=False,
    pad_token_id=tokenizer.pad_token_id,
)
model.eval()

tokenized = tokenizer("hi there?", return_tensors='pt')

input_ids = tokenized.input_ids
attention_mask = tokenized.attention_mask

generated = model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=512)
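For context, passing penalty_alpha together with top_k is what switches generate() into contrastive search. At each step, each of the top_k candidate tokens v is scored as (1 - penalty_alpha) * p(v) - penalty_alpha * max_j cos(h_v, h_j), where h_v is the candidate's hidden state and h_j are the hidden states of the tokens already in the context; the highest-scoring candidate is kept. Below is a minimal, framework-free sketch of that re-ranking step, assuming plain Python lists for probabilities and hidden states. The function names are hypothetical and are not transformers APIs.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def rerank_candidates(candidate_probs, candidate_hidden, context_hidden, penalty_alpha):
    """Return the index of the candidate with the best contrastive score.

    candidate_probs:  model probability of each of the top_k candidates
    candidate_hidden: hidden state of each candidate (one vector per candidate)
    context_hidden:   hidden states of the tokens generated so far
    """
    best_idx, best_score = None, -math.inf
    for idx, (prob, h_v) in enumerate(zip(candidate_probs, candidate_hidden)):
        # Degeneration penalty: similarity to the closest token already in the context.
        penalty = max(cosine_similarity(h_v, h_j) for h_j in context_hidden)
        score = (1.0 - penalty_alpha) * prob - penalty_alpha * penalty
        if score > best_score:
            best_idx, best_score = idx, score
    return best_idx


# Toy example with top_k=2 and penalty_alpha=0.6 (values made up for illustration).
context = [[1.0, 0.0], [0.9, 0.1]]
probs = [0.55, 0.45]
hidden = [[1.0, 0.05], [0.0, 1.0]]  # candidate 0 nearly repeats the context
print(rerank_candidates(probs, hidden, context, penalty_alpha=0.6))  # prints 1
```

With penalty_alpha=0.0 the penalty term vanishes and the sketch reduces to greedy selection of the most probable candidate (index 0 here).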

Error messages

When I ran the script above, I got the following error:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <module>:6                                                                                    │
│                                                                                                  │
│   3 input_ids = tokenized.input_ids                                                              │
│   4 attention_mask = tokenized.attention_mask                                                    │
│   5                                                                                              │
│ ❱ 6 generated = model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=512)            │
│   7                                                                                              │
│                                                                                                  │
│ /opt/conda/envs/py3.10/lib/python3.10/site-packages/torch/utils/_contextlib.py:115 in            │
│ decorate_context                                                                                 │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ /opt/conda/envs/py3.10/lib/python3.10/site-packages/transformers/generation/utils.py:1544 in     │
│ generate                                                                                         │
│                                                                                                  │
│   1541 │   │   │   if not model_kwargs["use_cache"]:                                             │
│   1542 │   │   │   │   raise ValueError("Contrastive search requires `use_cache=True`")          │
│   1543 │   │   │                                                                                 │
│ ❱ 1544 │   │   │   return self.contrastive_search(                                               │
│   1545 │   │   │   │   input_ids,                                                                │
│   1546 │   │   │   │   top_k=generation_config.top_k,                                            │
│   1547 │   │   │   │   penalty_alpha=generation_config.penalty_alpha,                            │
│                                                                                                  │
│ /opt/conda/envs/py3.10/lib/python3.10/site-packages/torch/utils/_contextlib.py:115 in            │
│ decorate_context                                                                                 │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ /opt/conda/envs/py3.10/lib/python3.10/site-packages/transformers/generation/utils.py:2004 in     │
│ contrastive_search                                                                               │
│                                                                                                  │
│   2001 │   │   │                                                                                 │
│   2002 │   │   │   logit_for_next_step = logits_processor(input_ids, logit_for_next_step)        │
│   2003 │   │   │   logit_for_next_step = logits_warper(input_ids, logit_for_next_step)           │
│ ❱ 2004 │   │   │   next_probs = nn.functional.softmax(logit_for_next_step, dim=-1)               │
│   2005 │   │   │   top_k_probs, top_k_ids = torch.topk(next_probs, dim=-1, k=top_k)              │
│   2006 │   │   │                                                                                 │
│   2007 │   │   │   # Store scores, attentions and hidden_states when required                    │
│                                                                                                  │
│ /opt/conda/envs/py3.10/lib/python3.10/site-packages/torch/nn/functional.py:1843 in softmax       │
│                                                                                                  │
│   1840 │   if dim is None:                                                                       │
│   1841 │   │   dim = _get_softmax_dim("softmax", input.dim(), _stacklevel)                       │
│   1842 │   if dtype is None:                                                                     │
│ ❱ 1843 │   │   ret = input.softmax(dim)                                                          │
│   1844 │   else:                                                                                 │
│   1845 │   │   ret = input.softmax(dim, dtype=dtype)                                             │
│   1846 │   return ret                                                                            │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: "softmax_lastdim_kernel_impl" not implemented for 'Half'
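The failing line turns the processed logits into probabilities and keeps the top_k most likely tokens. The kernel name in the message appears to belong to PyTorch's CPU softmax, which suggests the fp16 ('Half') logits ended up on the CPU, where that PyTorch build had no half-precision softmax; computing on the GPU, or upcasting the logits to float32 first, avoids it. A pure-Python sketch of the same step follows (the function name is hypothetical; Python floats are 64-bit, so no half-precision issue arises here). It subtracts the maximum logit before exponentiating, the usual trick for numerical stability.

```python
import math


def softmax_top_k(logits, top_k):
    """Softmax over `logits`, then return the top_k (prob, token_id) pairs, highest first."""
    # Subtract the max logit so math.exp never overflows.
    max_logit = max(logits)
    exps = [math.exp(x - max_logit) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Plays the role of torch.topk(next_probs, dim=-1, k=top_k) for a single row.
    ranked = sorted(((p, i) for i, p in enumerate(probs)), reverse=True)
    return ranked[:top_k]


print(softmax_top_k([2.0, 1.0, 0.1, 3.0, -1.0], top_k=4))
```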

I then modified the script to move input_ids to cuda:0:

input_ids = tokenized.input_ids.to("cuda:0")
generated = model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=512)
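Moving the prompt to cuda:0 works here because, with device_map="auto", the first layers are typically placed on the first GPU. Instead of hard-coding cuda:0, the placement chosen by from_pretrained is recorded in model.hf_device_map, a dict from module name to device. The sketch below picks the device of the input-embedding module from such a map. The map itself and the module name gpt_neox.embed_in are illustrative assumptions for a GPT-NeoX-style checkpoint, and input_device is a hypothetical helper; inspect your own model.hf_device_map for the real names.

```python
def input_device(hf_device_map, embedding_module="gpt_neox.embed_in"):
    """Return a torch-style device string for the module that consumes input_ids."""
    device = hf_device_map[embedding_module]
    # GPUs are recorded as integer indices; other placements as strings such as "cpu".
    return f"cuda:{device}" if isinstance(device, int) else device


# Illustrative, truncated map for a model split across two GPUs (not real output).
example_map = {
    "gpt_neox.embed_in": 0,
    "gpt_neox.layers.0": 0,
    "gpt_neox.layers.1": 1,
    "gpt_neox.final_layer_norm": 1,
    "embed_out": 1,
}
print(input_device(example_map))  # cuda:0
```

The result can then be passed to .to() on the tokenized tensors, e.g. tokenized.input_ids.to(input_device(model.hf_device_map)), once the embedding module name has been checked against the actual map.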

This time I got a different error:

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ in <module>:6                                                                                    │
│                                                                                                  │
│   3 input_ids = tokenized.input_ids.to("cuda:0")                                                 │
│   4 attention_mask = tokenized.attention_mask.to("cuda:0")                                       │
│   5                                                                                              │
│ ❱ 6 generated = model.generate(input_ids, penalty_alpha=0.6, top_k=4, max_length=512)            │
│   7                                                                                              │
│                                                                                                  │
│ /opt/conda/envs/py3.10/lib/python3.10/site-packages/torch/utils/_contextlib.py:115 in            │
│ decorate_context                                                                                 │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ /opt/conda/envs/py3.10/lib/python3.10/site-packages/transformers/generation/utils.py:1544 in     │
│ generate                                                                                         │
│                                                                                                  │
│   1541 │   │   │   if not model_kwargs["use_cache"]:                                             │
│   1542 │   │   │   │   raise ValueError("Contrastive search requires `use_cache=True`")          │
│   1543 │   │   │                                                                                 │
│ ❱ 1544 │   │   │   return self.contrastive_search(                                               │
│   1545 │   │   │   │   input_ids,                                                                │
│   1546 │   │   │   │   top_k=generation_config.top_k,                                            │
│   1547 │   │   │   │   penalty_alpha=generation_config.penalty_alpha,                            │
│                                                                                                  │
│ /opt/conda/envs/py3.10/lib/python3.10/site-packages/torch/utils/_contextlib.py:115 in            │
│ decorate_context                                                                                 │
│                                                                                                  │
│   112 │   @functools.wraps(func)                                                                 │
│   113 │   def decorate_context(*args, **kwargs):                                                 │
│   114 │   │   with ctx_factory():                                                                │
│ ❱ 115 │   │   │   return func(*args, **kwargs)                                                   │
│   116 │                                                                                          │
│   117 │   return decorate_context                                                                │
│   118                                                                                            │
│                                                                                                  │
│ /opt/conda/envs/py3.10/lib/python3.10/site-packages/transformers/generation/utils.py:2076 in     │
│ contrastive_search                                                                               │
│                                                                                                  │
│   2073 │   │   │   │   # item is either the key or the value matrix                              │
│   2074 │   │   │   │   for item in layer:                                                        │
│   2075 │   │   │   │   │   item = torch.stack(torch.split(item, top_k, dim=0))  # [B, K, num_he  │
│ ❱ 2076 │   │   │   │   │   item = item[range(batch_size), selected_idx, ...]  # [B, num_head, s  │
│   2077 │   │   │   │   │   items += (item,)                                                      │
│   2078 │   │   │   │   new_key_values += (items,)                                                │
│   2079 │   │   │   next_past_key_values = new_key_values                                         │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:1)
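The second traceback is in the step that keeps only the chosen candidate's key/value cache. For each layer, the cache holds batch_size * top_k rows; it is split into [batch_size, top_k, ...] and indexed with selected_idx (one candidate per batch row). With device_map="auto" the layers, and therefore their caches, live on different GPUs, while selected_idx lives on a single device, so indexing a layer on cuda:1 fails. The error message states the rule: indices must be on the CPU or on the indexed tensor's device. The sketch below models that rule in plain Python, with device names as strings, lists standing in for tensors, and each layer reduced to a single cache list; all names are hypothetical. It shows that keeping the index on the CPU works for every layer. This is one possible fix, not necessarily the change made upstream.

```python
def reorder_layer_cache(layer_cache, cache_device, selected_idx, idx_device, top_k):
    """Keep, for each batch row, the cache entry of the selected candidate.

    layer_cache has batch_size * top_k entries, grouped per batch row
    (mirroring torch.split(item, top_k, dim=0) followed by torch.stack).
    """
    # Same rule as the RuntimeError: indices on the CPU or on the cache's own device.
    if idx_device != "cpu" and idx_device != cache_device:
        raise RuntimeError(
            "indices should be either on cpu or on the same device as the "
            f"indexed tensor ({cache_device})"
        )
    grouped = [layer_cache[i:i + top_k] for i in range(0, len(layer_cache), top_k)]  # [B][K]
    return [grouped[b][k] for b, k in enumerate(selected_idx)]


def reorder_cache(past_key_values, selected_idx, idx_device, top_k):
    """Apply the selection to every (cache, device) layer."""
    return [
        reorder_layer_cache(cache, device, selected_idx, idx_device, top_k)
        for cache, device in past_key_values
    ]


top_k = 2
selected_idx = [1, 0]  # batch row 0 keeps candidate 1, row 1 keeps candidate 0
# Two layers split across GPUs; each entry is labeled "b<row>k<candidate>".
past = [
    (["b0k0", "b0k1", "b1k0", "b1k1"], "cuda:0"),
    (["b0k0", "b0k1", "b1k0", "b1k1"], "cuda:1"),
]

try:
    reorder_cache(past, selected_idx, idx_device="cuda:0", top_k=top_k)
except RuntimeError as err:
    print("fails:", err)

print(reorder_cache(past, selected_idx, idx_device="cpu", top_k=top_k))
# [['b0k1', 'b1k0'], ['b0k1', 'b1k0']]
```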

transformers-cli env

Who can help?

@gante

Reproduction

  1. Run the script shown above (under System Info) in an environment where the model is split across multiple devices.

Expected behavior

I expect generate() to return the generated text, without errors, when the model is split across multiple devices.

gante commented 1 year ago

Hey @pfldy2850 👋

I believe I know the solution to your issues, but I don't have a multi-GPU setup. I'm going to open a PR, and then ask you to double-check whether it works :)

gante commented 1 year ago

@pfldy2850 would you be able to test using this PR?

pfldy2850 commented 1 year ago

@gante

Wow! Your outstanding work resolved the issue. 👍 I got the expected output.

I would like to use these changes in production. Could you please tell me about the release cycle of this repository?

gante commented 1 year ago

@pfldy2850 awesome! The PR should be merged within 24 hours :)

After the PR gets merged, you have two options:

  1. Wait for the next release, which will probably happen in two or three weeks
  2. Install from main with pip install --upgrade git+https://github.com/huggingface/transformers.git or replace the requirement in your setup.py/requirements.txt with transformers @ git+https://github.com/huggingface/transformers.git