huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Multi-GPU setup: indices should be either on cpu or on the same device as the indexed tensor (cuda:1) #33147

Open justnoxx opened 2 months ago

justnoxx commented 2 months ago

System Info

python version: 3.11.9
transformers version: 4.44.2
accelerate version: 0.33.0
torch version: 2.4.0+cu121

Who can help?

@gante

Information

Tasks

Reproduction

Hello!

I have a setup with 8x H100 GPUs and need to run really large models. To get started, I went through your official example: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference

First of all, there is a typo in there:

from mingpt.bpe import BPETokenizer
tokenizer = BPETokenizer()
inputs = tokenizer("Hello, my name is").to(0)

outputs = model.generate(x1, max_new_tokens=10, do_sample=False)[0]
tokenizer.decode(outputs.cpu().squeeze())

There is no x1 variable; apart from that, the example from this guide works well.
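
Presumably x1 was meant to be the tokenized inputs from the line above, i.e. something like (a sketch of the likely intent, not the guide's exact wording):

outputs = model.generate(inputs, max_new_tokens=10, do_sample=False)[0]
tokenizer.decode(outputs.cpu().squeeze())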

However, I tried to do the same with Mistral models, so I adapted the code from that example to run a Mistral model:

from huggingface_hub import snapshot_download
from accelerate import init_empty_weights
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import load_checkpoint_and_dispatch

checkpoint = 'mistralai/Mixtral-8x22B-Instruct-v0.1'

weights_location = snapshot_download(repo_id=checkpoint, cache_dir='./cache')
model = AutoModelForCausalLM.from_pretrained(checkpoint)

model = load_checkpoint_and_dispatch(
    model, checkpoint=weights_location, device_map="auto", no_split_module_classes=['Block']
)

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
inputs = tokenizer("Hello, my name is", return_tensors="pt").to(0)
outputs = model.generate(inputs['input_ids'], max_new_tokens=10, do_sample=False)[0]

This code fails on the last line with the following exception:

RuntimeError                              Traceback (most recent call last)
Cell In[37], line 1
----> 1 outputs = model.generate(inputs['input_ids'], max_new_tokens=10, do_sample=False)[0]
      2 tokenizer.decode(outputs.cpu().squeeze())

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/transformers/generation/utils.py:2024, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   2016     input_ids, model_kwargs = self._expand_inputs_for_generation(
   2017         input_ids=input_ids,
   2018         expand_size=generation_config.num_return_sequences,
   2019         is_encoder_decoder=self.config.is_encoder_decoder,
   2020         **model_kwargs,
   2021     )
   2023     # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 2024     result = self._sample(
   2025         input_ids,
   2026         logits_processor=prepared_logits_processor,
   2027         logits_warper=prepared_logits_warper,
   2028         stopping_criteria=prepared_stopping_criteria,
   2029         generation_config=generation_config,
   2030         synced_gpus=synced_gpus,
   2031         streamer=streamer,
   2032         **model_kwargs,
   2033     )
   2035 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
   2036     # 11. prepare logits warper
   2037     prepared_logits_warper = (
   2038         self._get_logits_warper(generation_config, device=input_ids.device)
   2039         if generation_config.do_sample
   2040         else None
   2041     )

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/transformers/generation/utils.py:2982, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)
   2979 model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
   2981 # forward pass to get next token
-> 2982 outputs = self(**model_inputs, return_dict=True)
   2984 if synced_gpus and this_peer_finished:
   2985     continue  # don't waste resources running the code we don't need

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/accelerate/hooks.py:169, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    167         output = module._old_forward(*args, **kwargs)
    168 else:
--> 169     output = module._old_forward(*args, **kwargs)
    170 return module._hf_hook.post_forward(module, output)

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/transformers/models/mixtral/modeling_mixtral.py:1274, in MixtralForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, output_router_logits, return_dict, cache_position)
   1271 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1273 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1274 outputs = self.model(
   1275     input_ids=input_ids,
   1276     attention_mask=attention_mask,
   1277     position_ids=position_ids,
   1278     past_key_values=past_key_values,
   1279     inputs_embeds=inputs_embeds,
   1280     use_cache=use_cache,
   1281     output_attentions=output_attentions,
   1282     output_hidden_states=output_hidden_states,
   1283     output_router_logits=output_router_logits,
   1284     return_dict=return_dict,
   1285     cache_position=cache_position,
   1286 )
   1288 hidden_states = outputs[0]
   1289 logits = self.lm_head(hidden_states)

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/transformers/models/mixtral/modeling_mixtral.py:1068, in MixtralModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, output_router_logits, return_dict, cache_position)
   1056     layer_outputs = self._gradient_checkpointing_func(
   1057         decoder_layer.__call__,
   1058         hidden_states,
   (...)
   1065         cache_position,
   1066     )
   1067 else:
-> 1068     layer_outputs = decoder_layer(
   1069         hidden_states,
   1070         attention_mask=causal_mask,
   1071         position_ids=position_ids,
   1072         past_key_value=past_key_values,
   1073         output_attentions=output_attentions,
   1074         output_router_logits=output_router_logits,
   1075         use_cache=use_cache,
   1076         cache_position=cache_position,
   1077     )
   1079 hidden_states = layer_outputs[0]
   1081 if use_cache:

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/transformers/models/mixtral/modeling_mixtral.py:812, in MixtralDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, output_router_logits, use_cache, cache_position, **kwargs)
    810 residual = hidden_states
    811 hidden_states = self.post_attention_layernorm(hidden_states)
--> 812 hidden_states, router_logits = self.block_sparse_moe(hidden_states)
    813 hidden_states = residual + hidden_states
    815 outputs = (hidden_states,)

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/transformers/models/mixtral/modeling_mixtral.py:738, in MixtralSparseMoeBlock.forward(self, hidden_states)
    733 idx, top_x = torch.where(expert_mask[expert_idx])
    735 # Index the correct hidden states and compute the expert hidden state for
    736 # the current expert. We need to make sure to multiply the output hidden
    737 # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
--> 738 current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
    739 current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
    741 # However `index_add_` only support torch tensors for indexing so we'll use
    742 # the `top_x` tensor here.

RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:1)

I think the most important part is:

RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:1)
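
For context, this error can be triggered outside transformers as well; a minimal standalone PyTorch sketch (assuming at least two visible GPUs) is:

import torch

x = torch.randn(4, 8, device="cuda:1")       # tensor living on cuda:1
idx = torch.tensor([0, 2], device="cuda:0")  # index tensor on a different GPU
y = x[idx]  # RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:1)

So the failure seems to mean that the expert-routing indices (top_x) computed in the MoE block end up on a different GPU than the hidden states they index.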

A few additional notes:

  1. I am running this code in a Jupyter notebook.
  2. The model seems to be properly distributed across all GPUs:
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.28.03              Driver Version: 560.28.03      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 80GB HBM3          Off |   00000000:0F:00.0 Off |                    0 |
| N/A   32C    P0            111W /  700W |   67576MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          Off |   00000000:2D:00.0 Off |                    0 |
| N/A   37C    P0            115W /  700W |   68344MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          Off |   00000000:44:00.0 Off |                    0 |
| N/A   31C    P0            110W /  700W |   68344MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          Off |   00000000:5B:00.0 Off |                    0 |
| N/A   36C    P0            116W /  700W |   68552MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   4  NVIDIA H100 80GB HBM3          Off |   00000000:89:00.0 Off |                    0 |
| N/A   32C    P0            111W /  700W |   68192MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   5  NVIDIA H100 80GB HBM3          Off |   00000000:A8:00.0 Off |                    0 |
| N/A   35C    P0            119W /  700W |   68344MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   6  NVIDIA H100 80GB HBM3          Off |   00000000:C0:00.0 Off |                    0 |
| N/A   40C    P0            116W /  700W |   68360MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   7  NVIDIA H100 80GB HBM3          Off |   00000000:D8:00.0 Off |                    0 |
| N/A   32C    P0            110W /  700W |   67568MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A    701575      C   ...yenv/versions/3.11.9/bin/python3.11      67566MiB |
|    1   N/A  N/A    701575      C   ...yenv/versions/3.11.9/bin/python3.11      68334MiB |
|    2   N/A  N/A    701575      C   ...yenv/versions/3.11.9/bin/python3.11      68334MiB |
|    3   N/A  N/A    701575      C   ...yenv/versions/3.11.9/bin/python3.11      68542MiB |
|    4   N/A  N/A    701575      C   ...yenv/versions/3.11.9/bin/python3.11      68182MiB |
|    5   N/A  N/A    701575      C   ...yenv/versions/3.11.9/bin/python3.11      68334MiB |
|    6   N/A  N/A    701575      C   ...yenv/versions/3.11.9/bin/python3.11      68350MiB |
|    7   N/A  N/A    701575      C   ...yenv/versions/3.11.9/bin/python3.11      67558MiB |
+-----------------------------------------------------------------------------------------+

However, I found an easier reproduction with a smaller model; it uses a slightly different approach, but the result is the same:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from accelerate import infer_auto_device_map, dispatch_model

# tried with a smaller Mistral model, the result is the same
model_name = 'meta-llama/Meta-Llama-Guard-2-8B'

token = 'my-hf-token'

tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
model = AutoModelForCausalLM.from_pretrained(model_name, token=token)

device_map = infer_auto_device_map(
    model,
    max_memory={i: "8GiB" for i in range(torch.cuda.device_count())}
)

model2 = dispatch_model(model, device_map=device_map)

# Prepare input messages
messages = [
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "I'm doing well, thank you! How can I assist you today?"},
    {"role": "user", "content": "Can you tell me about the weather today?"}
]

# Concatenate the messages into a single string
conversation = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])

inputs = tokenizer(conversation, return_tensors="pt")
# I also tried
# inputs = tokenizer(conversation, return_tensors="pt").to('cuda:0')
# and
# inputs = tokenizer(conversation, return_tensors="pt").to('cuda:1')
# with the same result,
# and even:
# first_device = list(device_map.values())[0]
# inputs = inputs.to(f'cuda:{first_device}')
# still the same

output = model.generate(
    inputs["input_ids"],
    max_length=150,  # Adjust according to your needs
    num_return_sequences=1,
    no_repeat_ngram_size=2,
    do_sample=True,
    top_k=50,
    top_p=0.95,
)

The result is the same:

RuntimeError                              Traceback (most recent call last)
Cell In[21], line 2
      1 # Generate text with the model
----> 2 output = model.generate(
      3     inputs["input_ids"],
      4     max_length=150,  # Adjust according to your needs
      5     num_return_sequences=1,
      6     no_repeat_ngram_size=2,
      7     do_sample=True,
      8     top_k=50,
      9     top_p=0.95,
     10 )
     12 # Decode the output to readable text
     13 generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/transformers/generation/utils.py:2024, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   2016     input_ids, model_kwargs = self._expand_inputs_for_generation(
   2017         input_ids=input_ids,
   2018         expand_size=generation_config.num_return_sequences,
   2019         is_encoder_decoder=self.config.is_encoder_decoder,
   2020         **model_kwargs,
   2021     )
   2023     # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 2024     result = self._sample(
   2025         input_ids,
   2026         logits_processor=prepared_logits_processor,
   2027         logits_warper=prepared_logits_warper,
   2028         stopping_criteria=prepared_stopping_criteria,
   2029         generation_config=generation_config,
   2030         synced_gpus=synced_gpus,
   2031         streamer=streamer,
   2032         **model_kwargs,
   2033     )
   2035 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
   2036     # 11. prepare logits warper
   2037     prepared_logits_warper = (
   2038         self._get_logits_warper(generation_config, device=input_ids.device)
   2039         if generation_config.do_sample
   2040         else None
   2041     )

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/transformers/generation/utils.py:2982, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)
   2979 model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
   2981 # forward pass to get next token
-> 2982 outputs = self(**model_inputs, return_dict=True)
   2984 if synced_gpus and this_peer_finished:
   2985     continue  # don't waste resources running the code we don't need

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/accelerate/hooks.py:169, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    167         output = module._old_forward(*args, **kwargs)
    168 else:
--> 169     output = module._old_forward(*args, **kwargs)
    170 return module._hf_hook.post_forward(module, output)

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:1189, in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
   1186 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1188 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1189 outputs = self.model(
   1190     input_ids=input_ids,
   1191     attention_mask=attention_mask,
   1192     position_ids=position_ids,
   1193     past_key_values=past_key_values,
   1194     inputs_embeds=inputs_embeds,
   1195     use_cache=use_cache,
   1196     output_attentions=output_attentions,
   1197     output_hidden_states=output_hidden_states,
   1198     return_dict=return_dict,
   1199     cache_position=cache_position,
   1200 )
   1202 hidden_states = outputs[0]
   1203 if self.config.pretraining_tp > 1:

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:1001, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    989     layer_outputs = self._gradient_checkpointing_func(
    990         decoder_layer.__call__,
    991         hidden_states,
   (...)
    998         position_embeddings,
    999     )
   1000 else:
-> 1001     layer_outputs = decoder_layer(
   1002         hidden_states,
   1003         attention_mask=causal_mask,
   1004         position_ids=position_ids,
   1005         past_key_value=past_key_values,
   1006         output_attentions=output_attentions,
   1007         use_cache=use_cache,
   1008         cache_position=cache_position,
   1009         position_embeddings=position_embeddings,
   1010     )
   1012 hidden_states = layer_outputs[0]
   1014 if use_cache:

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:750, in LlamaDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings, **kwargs)
    748 residual = hidden_states
    749 hidden_states = self.post_attention_layernorm(hidden_states)
--> 750 hidden_states = self.mlp(hidden_states)
    751 hidden_states = residual + hidden_states
    753 outputs = (hidden_states,)

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/.pyenv/versions/3.11.9/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:309, in LlamaMLP.forward(self, x)
    307     down_proj = sum(down_proj)
    308 else:
--> 309     down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
    311 return down_proj

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:2 and cuda:1!

Since this happens not only with Mistral models, I believe either something is wrong in my code or there is a bug in the library. Please help me figure out which. Thanks.

Expected behavior

Inference should work on multiple GPU devices.

I can provide any additional information and try some adjustments.

ArthurZucker commented 2 months ago

cc @muellerzr and @SunMarc

justnoxx commented 2 months ago

FYI, I just tested with the code from the main branch; the result is the same.

Thanks.

justnoxx commented 2 months ago

@muellerzr , @SunMarc can you please take a look?

SunMarc commented 2 months ago

Hey @justnoxx, you can directly use the accelerate integration in transformers to load the model across multiple GPUs. You just need to pass device_map="auto" to from_pretrained:

tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
model = AutoModelForCausalLM.from_pretrained(model_name, token=token, device_map="auto")
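
You can then move the inputs to the model's first device before generating, for example (a minimal sketch continuing from the snippet above):

inputs = tokenizer("Hello, how are you?", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=20, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))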

Let me know if this works!

justnoxx commented 2 months ago

Thanks, @SunMarc, it worked, but it seems to be using TensorFlow now.

Here is my code:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from accelerate import infer_auto_device_map, dispatch_model

# model_name = "mistralai/Mistral-Nemo-Base-2407"
model_name = '''mistralai/Mixtral-8x22B-Instruct-v0.1'''
# model_name = '''mistralai/Mistral-Large-Instruct-2407'''
# model_name = '''tiiuae/falcon-180B'''
# model_name = 'meta-llama/Meta-Llama-Guard-2-8B'

token = 'my_token'

# Prepare input messages
messages = [
    {"role": "user", "content": "Hello, how are you?"},
    {"role": "assistant", "content": "I'm doing well, thank you! How can I assist you today?"},
    {"role": "user", "content": "Can you tell me about the weather today?"}
]

# Concatenate the messages into a single string
conversation = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])

# Tokenize the input string
inputs = tokenizer(conversation, return_tensors="pt")

# Generate text with the model
output = model.generate(
    inputs["input_ids"],
    max_length=1000,  # Adjust according to your needs
    num_return_sequences=1,
    no_repeat_ngram_size=2,
    do_sample=True,
    top_k=50,
    top_p=0.95,
)

# Decode the output to readable text
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

print(generated_text)

It gives me the generated text, which is fine, but now I see some TensorFlow-specific warnings:

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
/home/ubuntu/.pyenv/versions/3.11.9/lib/python3.11/site-packages/transformers/generation/utils.py:1900: UserWarning: You are calling .generate() with the `input_ids` being on a device type different than your model's device. `input_ids` is on cpu, whereas the model is on cuda. You may experience unexpected behaviors or slower generation. Please make sure that you have put `input_ids` to the correct device by calling for example input_ids = input_ids.to('cuda') before running `.generate()`.
  warnings.warn(
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
2024-09-02 19:33:11.723506: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-09-02 19:33:11.735417: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-02 19:33:11.749112: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-02 19:33:11.752901: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-02 19:33:11.763740: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 AVX512_FP16 AVX_VNNI AMX_TILE AMX_INT8 AMX_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-09-02 19:33:12.373910: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT

Maybe this issue is PyTorch specific? Is there a way to specify which framework to use when device_map='auto' is set during from_pretrained?

Thanks.

github-actions[bot] commented 1 month ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

SunMarc commented 1 month ago

Maybe this issue is PyTorch specific? Is there a way to specify which framework to use when device_map='auto' is set during from_pretrained?

Thanks.

When using AutoModelForCausalLM to load models, it will use PyTorch. If you want to use TensorFlow instead, you need to use TFAutoModelForCausalLM. So the answer is that device_map='auto' has no influence on the framework you use. Just uninstall TensorFlow if you don't want to see these warnings. Is the issue fixed, so that we can close it?
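
You can also confirm which framework a loaded model uses, for example (a small sketch, using the model loaded above):

import torch

print(type(model))                         # e.g. <class 'transformers...MixtralForCausalLM'>
print(isinstance(model, torch.nn.Module))  # True, i.e. a PyTorch model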

AishPattabi commented 1 month ago

I hit this issue this week. I am training Gemma-2-2b-it on a 4x A100 setup with device_map='auto'. Training runs fine, but during the evaluation stage I get the error:

RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:0)

Here is a snippet of my code:

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b-it",
    trust_remote_code=True,
    device_map="auto",
    attn_implementation='eager',
    torch_dtype=torch.bfloat16,
    token=access_token
)
model.enable_input_require_grads()

trainer = CustomTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    formatting_func=formatting_func,
    packing=False,
    args=SFTConfig(
        max_seq_length=MAX_SEQ_LENGTH,
        output_dir=f"{output_dir_name}-training",
        log_level="debug",
        num_train_epochs=epoch,
        learning_rate=learning_rate,
        save_strategy=SAVE_STRATEGY,
        evaluation_strategy=EVAL_STRATEGY,
        logging_strategy=LOGGING_STRATEGY,
        logging_steps=LOGGING_STEPS,
        bf16=True,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=2,
        gradient_checkpointing=True,
        optim="paged_adamw_8bit"
    ),
)

trainer.train()

AishPattabi commented 1 month ago

Full stack trace:

RuntimeError                              Traceback (most recent call last)
Cell In[21], line 1
----> 1 train("google/gemma-2-2b-it", 2e-4, 3, peft=True)

Cell In[20], line 59, in train(base_model, learning_rate, epoch, peft)
     27     gradient_checkpointing = False
     29 trainer = CustomTrainer(
     30     model=model,
     31     tokenizer=tokenizer,
   (...)
     56     ),
     57 )
---> 59 trainer.train()
     61 trainer.save_model(output_dir_name)
     62 log_history = pd.DataFrame.from_dict(trainer.state.log_history)

File /opt/conda/lib/python3.10/site-packages/trl/trainer/sft_trainer.py:434, in SFTTrainer.train(self, *args, **kwargs)
    431 if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune:
    432     self.model = self._trl_activate_neftune(self.model)
--> 434 output = super().train(*args, **kwargs)
    436 # After training we make sure to retrieve back the original forward pass method
    437 # for the embedding layer by removing the forward post hook.
    438 if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune:

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:2052, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   2050         hf_hub_utils.enable_progress_bars()
   2051 else:
-> 2052     return inner_training_loop(
   2053         args=args,
   2054         resume_from_checkpoint=resume_from_checkpoint,
   2055         trial=trial,
   2056         ignore_keys_for_eval=ignore_keys_for_eval,
   2057     )

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:2487, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2484     self.control.should_training_stop = True
   2486 self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
-> 2487 self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
   2489 if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
   2490     if is_torch_xla_available():
   2491         # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:2915, in Trainer._maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
   2913 metrics = None
   2914 if self.control.should_evaluate:
-> 2915     metrics = self._evaluate(trial, ignore_keys_for_eval)
   2917 if self.control.should_save:
   2918     self._save_checkpoint(model, trial, metrics=metrics)

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:2872, in Trainer._evaluate(self, trial, ignore_keys_for_eval, skip_scheduler)
   2871 def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False):
-> 2872     metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
   2873     self._report_to_hp_search(trial, self.state.global_step, metrics)
   2875     # Run delayed LR scheduler now that metrics are populated

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:3868, in Trainer.evaluate(self, eval_dataset, ignore_keys, metric_key_prefix)
   3865 start_time = time.time()
   3867 eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
-> 3868 output = eval_loop(
   3869     eval_dataloader,
   3870     description="Evaluation",
   3871     # No point gathering the predictions if there are no metrics, otherwise we defer to
   3872     # self.args.prediction_loss_only
   3873     prediction_loss_only=True if self.compute_metrics is None else None,
   3874     ignore_keys=ignore_keys,
   3875     metric_key_prefix=metric_key_prefix,
   3876 )
   3878 total_batch_size = self.args.eval_batch_size * self.args.world_size
   3879 if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:4061, in Trainer.evaluation_loop(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
   4058         batch_size = observed_batch_size
   4060 # Prediction step
-> 4061 losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
   4062 main_input_name = getattr(self.model, "main_input_name", "input_ids")
   4063 inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:4279, in Trainer.prediction_step(self, model, inputs, prediction_loss_only, ignore_keys)
   4277 if has_labels or loss_without_labels:
   4278     with self.compute_loss_context_manager():
-> 4279         loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
   4280     loss = loss.mean().detach()
   4282     if isinstance(outputs, dict):

File /opt/conda/lib/python3.10/site-packages/transformers/trainer.py:3532, in Trainer.compute_loss(self, model, inputs, return_outputs)
   3530 else:
   3531     labels = None
-> 3532 outputs = model(**inputs)
   3533 # Save past state if it exists
   3534 # TODO: this needs to be fixed and made cleaner later.
   3535 if self.args.past_index >= 0:

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /opt/conda/lib/python3.10/site-packages/accelerate/utils/operations.py:820, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
    819 def forward(*args, **kwargs):
--> 820     return model_forward(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/accelerate/utils/operations.py:808, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
    807 def __call__(self, *args, **kwargs):
--> 808     return convert_to_fp32(self.model_forward(*args, **kwargs))

File /opt/conda/lib/python3.10/site-packages/torch/amp/autocast_mode.py:43, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
     40 @functools.wraps(func)
     41 def decorate_autocast(*args, **kwargs):
     42     with autocast_instance:
---> 43         return func(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/peft/peft_model.py:1577, in PeftModelForCausalLM.forward(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
   1575     with self._enable_peft_forward_hooks(**kwargs):
   1576         kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1577         return self.base_model(
   1578             input_ids=input_ids,
   1579             attention_mask=attention_mask,
   1580             inputs_embeds=inputs_embeds,
   1581             labels=labels,
   1582             output_attentions=output_attentions,
   1583             output_hidden_states=output_hidden_states,
   1584             return_dict=return_dict,
   1585             **kwargs,
   1586         )
   1588 batch_size = _get_batch_size(input_ids, inputs_embeds)
   1589 if attention_mask is not None:
   1590     # concat prompt attention mask

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /opt/conda/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:188, in BaseTuner.forward(self, *args, **kwargs)
    187 def forward(self, *args: Any, **kwargs: Any):
--> 188     return self.model.forward(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File /opt/conda/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py:1047, in Gemma2ForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep)
   1045 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1046 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1047 outputs = self.model(
   1048     input_ids=input_ids,
   1049     attention_mask=attention_mask,
   1050     position_ids=position_ids,
   1051     past_key_values=past_key_values,
   1052     inputs_embeds=inputs_embeds,
   1053     use_cache=use_cache,
   1054     output_attentions=output_attentions,
   1055     output_hidden_states=output_hidden_states,
   1056     return_dict=return_dict,
   1057     cache_position=cache_position,
   1058 )
   1060 hidden_states = outputs[0]
   1061 if labels is None and not is_torchdynamo_compiling():

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /opt/conda/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py:890, in Gemma2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
    879     layer_outputs = self._gradient_checkpointing_func(
    880         decoder_layer.__call__,
    881         hidden_states,
   (...)
    887         cache_position,
    888     )
    889 else:
--> 890     layer_outputs = decoder_layer(
    891         hidden_states,
    892         attention_mask=causal_mask,
    893         position_ids=position_ids,
    894         past_key_value=past_key_values,
    895         output_attentions=output_attentions,
    896         use_cache=use_cache,
    897         cache_position=cache_position,
    898     )
    900 hidden_states = layer_outputs[0]
    902 if output_attentions:

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File /opt/conda/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py:604, in Gemma2DecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
    601 hidden_states = self.input_layernorm(hidden_states)
    603 # Self Attention
--> 604 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    605     hidden_states=hidden_states,
    606     attention_mask=attention_mask,
    607     position_ids=position_ids,
    608     past_key_value=past_key_value,
    609     output_attentions=output_attentions,
    610     use_cache=use_cache,
    611     cache_position=cache_position,
    612 )
    613 hidden_states = self.post_attention_layernorm(hidden_states)
    614 hidden_states = residual + hidden_states

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File /opt/conda/lib/python3.10/site-packages/transformers/models/gemma2/modeling_gemma2.py:237, in Gemma2Attention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
    229 if past_key_value is not None:
    230     # sin and cos are specific to RoPE models; cache_position needed for the static cache
    231     cache_kwargs = {
    232         "sin": sin,
    233         "cos": cos,
    234         "sliding_window": self.sliding_window,
    235         "cache_position": cache_position,
    236     }
--> 237     key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
    239 key_states = repeat_kv(key_states, self.num_key_value_groups)
    240 value_states = repeat_kv(value_states, self.num_key_value_groups)

File /opt/conda/lib/python3.10/site-packages/transformers/cache_utils.py:1672, in HybridCache.update(self, key_states, value_states, layer_idx, cache_kwargs)
   1669 else:
   1670     update_fn = self._static_update
-> 1672 return update_fn(
   1673     cache_position,
   1674     layer_idx,
   1675     key_states,
   1676     value_states,
   1677     k_out,
   1678     v_out,
   1679     k_out.shape[2],
   1680 )

File /opt/conda/lib/python3.10/site-packages/transformers/cache_utils.py:1635, in HybridCache._sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len)
   1633 to_shift = cache_position >= max_cache_len - 1
   1634 indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
-> 1635 k_out = k_out[:, :, indices]
   1636 v_out = v_out[:, :, indices]
   1638 k_out[:, :, cache_position] = key_states

RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:0)
SunMarc commented 1 month ago

Hi @AishPattabi, please open a new issue, as your issue is not exactly the same as the original post. Before you open a new issue, try to install transformers from source and run your code again. We fixed a couple of issues related to the kv-cache and multi-device setups recently.
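
(For reference, installing from source is typically done with pip install git+https://github.com/huggingface/transformers; adjust to your own environment.)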

ryusaeba commented 1 month ago

@SunMarc, will there be a patch release for the v4.45 series?

SunMarc commented 1 month ago

I'm pretty sure the fix is in the latest transformers, try it!
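
A quick way to double-check which transformers version is actually installed:

import transformers
print(transformers.__version__)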

ryusaeba commented 3 weeks ago

Thank you @SunMarc. I tried with v4.45.2 and the issue still persists. I will give it a try with the latest transformers.

SunMarc commented 3 weeks ago

Hey @ryusaeba, please share a reproducer with its traceback if the issue still persists!

rangehow commented 1 week ago

Any update on this? I hit this on transformers 4.46.2. I use a self-specified device_map; the same code with Qwen works fine.

from niuload import balanced_load
import torch
from transformers import AutoTokenizer

from config import *

# Initialize the tokenizer and prepare the inputs
tokenizer = AutoTokenizer.from_pretrained(model_dir.get("qwen2.5-7b"))
texts = [
    "Today is a beautiful day!",
    "How are you doing?"
]

# Create the batched inputs
batch_encoding = tokenizer(
    texts,
    padding='max_length',
    max_length=128,
    truncation=True,
    return_tensors='pt'
)

input_ids = batch_encoding['input_ids']
attention_mask = batch_encoding['attention_mask']

# Load the model and run inference
teacher_model = balanced_load(
    model_dir=model_dir.get("qwen2.5-7b"),
    num_devices=2,
    ratio=[0.5,1],
    devices_idx=[0,1],
)

with torch.inference_mode():
    teacher_outputs = teacher_model(
        input_ids=input_ids.to(teacher_model.device),
        attention_mask=attention_mask.to(teacher_model.device),
    )
ryusaeba commented 3 days ago

The issue happened with Gemma-2. I will see whether I can prepare a script to reproduce it.

ryusaeba commented 2 days ago

@SunMarc

The issue still persists. Please see the following code and help with this issue.

CODE

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)
import pdb

MODEL_PATH='/llm_data2/huggingface/models/google/git_version/gemma-2-2b-it'
texts = "Today is a beautiful day!"
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype = 'bfloat16',
    low_cpu_mem_usage = True,
    device_map = 'auto',
)
tokenized_txt = tokenizer(
    texts,
    return_tensors='pt'
)
model(tokenized_txt['input_ids'])

ERROR w/ v4.45.2

  File "<llm>/transformers/src/transformers/models/gemma2/modeling_gemma2.py", line 248, in forward
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
  File "<llm>/transformers/src/transformers/cache_utils.py", line 1672, in update
    return update_fn(
  File "<llm>/transformers/src/transformers/cache_utils.py", line 1649, in _static_update
    k_out[:, :, cache_position] = key_states
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!

ERROR w/ v4.46.2

Same as above.