huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

MPS support broken for T5 models #31737

Closed · abdulfatir closed this issue 1 month ago

abdulfatir commented 2 months ago

System Info

Who can help?

@zucchini-nlp @ArthurZucker

Information

Tasks

Reproduction

The following snippet fails with an error on Apple Silicon Macs (the MPS backend). It works on older versions of transformers (e.g., transformers~=4.39.0).

import torch
from transformers import GenerationConfig
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained(
    "google/t5-efficient-tiny", device_map="mps"
)
input_ids = torch.tensor([[4, 5, 6, 6, 7]], device="mps")
model.generate(
    input_ids=input_ids,
    generation_config=GenerationConfig(do_sample=True),
)

Error:

NotImplementedError                       Traceback (most recent call last)
Cell In[1], line 9
      5 model = T5ForConditionalGeneration.from_pretrained(
      6     \"google/t5-efficient-tiny\", device_map=\"mps\"
      7 )
      8 input_ids = torch.tensor([[4, 5, 6, 6, 7]], device=\"mps\")
----> 9 model.generate(
     10     input_ids=input_ids,
     11     generation_config=GenerationConfig(do_sample=True),
     12 )

File ~/miniconda3/envs/chronos-forecasting/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/miniconda3/envs/chronos-forecasting/lib/python3.10/site-packages/transformers/generation/utils.py:1664, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   1661 batch_size = inputs_tensor.shape[0]
   1663 device = inputs_tensor.device
-> 1664 self._prepare_special_tokens(generation_config, kwargs_has_attention_mask, device=device)
   1666 # decoder-only models must use left-padding for batched generation.
   1667 if not self.config.is_encoder_decoder and not is_torchdynamo_compiling():
   1668     # If `input_ids` was given, check if the last id in any sequence is `pad_token_id`
   1669     # Note: If using, `inputs_embeds` this check does not work, because we want to be more hands-off.

File ~/miniconda3/envs/chronos-forecasting/lib/python3.10/site-packages/transformers/generation/utils.py:1513, in GenerationMixin._prepare_special_tokens(self, generation_config, kwargs_has_attention_mask, device)
   1510     logger.warning(f\"Setting `pad_token_id` to `eos_token_id`:{pad_token_id} for open-end generation.\")
   1512 # we can't infer attn mask if pad token is set to be eos token in model's generation config
-> 1513 if eos_token_id is not None and torch.isin(elements=eos_token_id, test_elements=pad_token_id).any():
   1514     if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
   1515         logger.warning_once(
   1516             \"The attention mask is not set and cannot be inferred from input because pad token is same as eos token.\"
   1517             \"As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` \"
   1518             \"to obtain reliable results.\"
   1519         )

NotImplementedError: The operator 'aten::isin.Tensor_Tensor_out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
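As the error message itself suggests, a temporary workaround is to enable the CPU fallback for ops not yet implemented on MPS. A minimal sketch of the repro with that flag applied (the variable is read when torch initializes, so it must be set before torch is imported):

import os

# Per the error message: route ops not implemented on MPS to the CPU.
# Set this before importing torch, otherwise the flag has no effect.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch
from transformers import GenerationConfig, T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained(
    "google/t5-efficient-tiny", device_map="mps"
)
input_ids = torch.tensor([[4, 5, 6, 6, 7]], device="mps")

# With the fallback enabled, aten::isin runs on the CPU and generation
# proceeds, at the cost of a device round-trip for that op.
model.generate(
    input_ids=input_ids,
    generation_config=GenerationConfig(do_sample=True),
)

As the warning notes, this is slower than running the op natively on MPS, but it unblocks generation.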

Expected behavior

Expected the script to run without error, as it does on older versions (e.g., transformers~=4.39.0).

amyeroberts commented 2 months ago

Hi @abdulfatir, thanks for reporting!

There's a PR open to fix this here: #31695
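For context, the failure comes from the `torch.isin` call in `_prepare_special_tokens` (see the traceback above), and a device-safe equivalent can be built from ops that MPS already supports. A rough sketch of that idea, assuming both arguments are tensors; the helper name and exact approach are illustrative, not necessarily what #31695 does:

import torch

def isin_friendly(elements: torch.Tensor, test_elements: torch.Tensor) -> torch.Tensor:
    # Illustrative workaround: emulate torch.isin with a broadcast equality
    # check, which only uses ops implemented on the MPS backend.
    if elements.device.type == "mps":
        return (elements.unsqueeze(-1) == test_elements.flatten()).any(dim=-1)
    # On other devices, defer to the native (and faster) implementation.
    return torch.isin(elements=elements, test_elements=test_elements)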

github-actions[bot] commented 1 month ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

amyeroberts commented 1 month ago

Closing, as #31695 has been merged.