huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

`attention_mask` must be in the same device as model? #31975

Closed: fahadh4ilyas closed this issue 1 month ago

fahadh4ilyas commented 2 months ago

System Info

Transformers version: 4.42.4
Platform: Ubuntu
Python version: 3.10.14

Who can help?

@ArthurZucker @gante

Reproduction

Here is an example code:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    'meta-llama/Meta-Llama-3-8B',
    device_map='auto',
    attn_implementation='flash_attention_2',
    token='your-hf-token',
)

with torch.inference_mode():
    # input_ids is placed on cuda:0, but attention_mask is created without a
    # device argument, so it defaults to the CPU
    model.generate(input_ids=torch.tensor([[1, 2]], dtype=torch.long, device='cuda:0'), attention_mask=torch.tensor([[1, 1]], dtype=torch.long))

Expected behavior

The model should generate without issue. I never had this problem with older versions of transformers, but now this error shows up:

    114 device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
    115 with torch.autocast(device_type=device_type, enabled=False):
--> 116     freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
    117     emb = torch.cat((freqs, freqs), dim=-1)
    118     cos = emb.cos()

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat2 in method wrapper_CUDA_bmm)
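The same device-mismatch error can be reproduced without the model; a minimal sketch (assuming a CUDA-capable machine), multiplying a CUDA tensor by a CPU tensor:

import torch

a = torch.ones(1, 2, 2, device='cuda:0')  # lives on the GPU
b = torch.ones(1, 2, 2)                   # no device argument, so lives on the CPU
a @ b  # raises the same RuntimeError from wrapper_CUDA_bmm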
gante commented 2 months ago

Hi @fahadh4ilyas 👋

PyTorch doesn't automatically move tensors across devices -- the model and all of its inputs must be moved to the desired device manually. In your example, attention_mask is created without a device argument, so it defaults to the CPU while input_ids lives on cuda:0.
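A minimal sketch of the fix, assuming the mask should attend to every position in input_ids (torch.ones_like builds a tensor with the same shape, dtype, and device as its argument):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    'meta-llama/Meta-Llama-3-8B',
    device_map='auto',
    attn_implementation='flash_attention_2',
    token='your-hf-token',
)

input_ids = torch.tensor([[1, 2]], dtype=torch.long, device='cuda:0')
# the mask inherits cuda:0 from input_ids, so every tensor reaching the
# model is on the same device
attention_mask = torch.ones_like(input_ids)

with torch.inference_mode():
    model.generate(input_ids=input_ids, attention_mask=attention_mask)

An existing CPU mask can equivalently be moved with attention_mask.to(input_ids.device).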

See the basic generation example here 🤗

github-actions[bot] commented 1 month ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.