huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Shape mismatch when generating with multiple processes #32603

Closed: ojh31 closed this issue 1 month ago

ojh31 commented 3 months ago

System Info

Who can help?

@gante @SunMarc @ArthurZucker

Information

Tasks

Reproduction

Run accelerate launch --config_file=accelerate_config.yaml foo.py

foo.py:

from accelerate import Accelerator
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoModelForCausalLM, LlamaTokenizer

NAME = "meta-llama/Llama-2-13b-chat-hf"
tokenizer = LlamaTokenizer.from_pretrained(
    NAME, 
    padding_side="left", 
)
accelerator = Accelerator()
model = AutoModelForCausalLM.from_pretrained(NAME)
model = accelerator.prepare(model)

paired_input_ids = torch.tensor([
    [    2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             1, 29961, 25580, 29962,   887,   526,   385,  7395,  6993,  1788,
         29889, 13866,   366,   674,   367,  5429,   278,  1788,  4800, 29889,
           450,  1404,   674,   769,  3896,   263,  4800, 29889,   960,   278,
          1404,  4800,  7087,   278,  1788,  4800, 29892,   736, 18016, 13566,
          3352, 29889,   960,   278,  1404,  4800,   947,   451,  1993,   278,
          1788,  4800, 29892,   736,   360,  1430, 29902,  3352, 29889, 18076,
           487,  3099,  1156,   278,  1404,  4800, 29889,    13,    13,  3924,
          4800, 29901, 15040,    13,  2659,  4800, 29901, 15040,    13,  5634,
            13,    13,    13, 22550, 29901,   518, 29914, 25580, 29962],
            [    2,     2,     2,     2,     2,     2,     1, 29961, 25580, 29962,
           887,   526,   385,  7395,  6993,  1788, 29889, 13866,   366,   674,
           367,  5429,   278,  1788,  4800, 29889,   450,  1404,   674,   769,
          3896,   263,  4800, 29889,   960,   278,  1404,  4800,  7087,   278,
          1788,  4800, 29892,   736, 18016, 13566,  3352, 29889,   960,   278,
          1404,  4800,   947,   451,  1993,   278,  1788,  4800, 29892,   736,
           360,  1430, 29902,  3352, 29889, 18076,   487,  3099,  1156,   278,
          1404,  4800, 29889,    13,    13,  3924,  4800, 29901,  1757, 10582,
           284,    13,  2659,  4800, 29901,  1757, 10582,   284,    13,  5634,
            13,    13,    13, 22550, 29901,   518, 29914, 25580, 29962]
        ]
)
paired_attention_mask = torch.tensor([
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1],
    [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1]
])

paired_dataset = TensorDataset(paired_input_ids, paired_attention_mask)

dataloader = DataLoader(
    dataset=paired_dataset,
    batch_size=1,  # Process one pair at a time
    shuffle=False,
)
dataloader = accelerator.prepare(dataloader)

for batch_input_ids, batch_attention_mask in dataloader:
    with torch.no_grad():
        model.forward(input_ids=batch_input_ids, attention_mask=batch_attention_mask)
    with FSDP.summon_full_params(model, recurse=False):
        outputs = model.generate(
            input_ids=batch_input_ids,
            attention_mask=batch_attention_mask, 
            tokenizer=tokenizer,
            synced_gpus=True,
        )

accelerate_config.yaml:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: "no"
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: "no"
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Expected behavior

The script should generate text output, but instead it raises the following error:

The expanded size of the tensor (105) must match the existing size (104) at non-singleton dimension 3.  Target sizes: [1, 40, 1, 105].  Tensor sizes: [1, 1, 1, 104]
  File "/usr/local/venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 648, in forward
    attn_output = torch.nn.functional.scaled_dot_product_attention(
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 718, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 849, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 978, in forward
    layer_outputs = decoder_layer(
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1174, in forward
    outputs = self.model(
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/transformers/generation/utils.py", line 2651, in _sample
    outputs = self(
  File "/usr/local/venv/lib/python3.10/site-packages/transformers/generation/utils.py", line 1914, in generate
    result = self._sample(
  File "/usr/local/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/root/robust-llm/pairs.py", line 66, in <module>
    outputs = model.generate(
RuntimeError: The expanded size of the tensor (105) must match the existing size (104) at non-singleton dimension 3.  Target sizes: [1, 40, 1, 105].  Tensor sizes: [1, 1, 1, 104]

Hypothesis: in transformers/generation/utils.py::GenerationMixin._sample(), inside the while self._has_unfinished_sequences() loop, we hit continue whenever synced_gpus and this_peer_finished. That continue skips the concatenation of next_tokens onto input_ids, but the dummy forward pass still runs, so the past_key_value cache in transformers/models/llama/modeling_llama.py::LlamaSdpaAttention.forward() keeps being updated. As a result, when one process finishes generation before the other, the finished process keeps expanding its key-value cache while its input tensors stop growing, which produces the shape mismatch above. Maybe a simple fix would be to forcibly set past_key_value to None once this_peer_finished is set to True?
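
To make this concrete, here is a small toy simulation of the suspected divergence (plain Python, not the actual transformers source; the loop structure and names are paraphrased for illustration only):

def simulate_peer(prompt_len, finishes_after, total_steps, synced_gpus=True):
    # input_len stands in for input_ids.shape[-1]; cache_len stands in for the
    # past_key_value sequence length seen by the attention layers.
    input_len, cache_len = prompt_len, 0
    this_peer_finished = False
    for step in range(total_steps):
        # The forward pass runs on every iteration (it must, to keep the ranks'
        # collectives in sync) and always grows the KV cache by one step.
        cache_len = prompt_len if step == 0 else cache_len + 1
        if synced_gpus and this_peer_finished:
            # Mirrors the continue in _sample: the input_ids update below is
            # skipped, so cache_len keeps growing while input_len stands still.
            # The fix suggested above would reset the cache here instead.
            continue
        input_len += 1  # i.e. torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if step + 1 == finishes_after:
            this_peer_finished = True
    return input_len, cache_len

# A peer that finishes after 3 new tokens but is kept looping for 6 steps by
# the other rank ends up with a cache longer than its inputs:
print(simulate_peer(prompt_len=99, finishes_after=3, total_steps=6))  # (102, 104)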

ojh31 commented 3 months ago

I can confirm the same error after upgrading to accelerate 0.33.0 and transformers 4.44.0.

ojh31 commented 3 months ago

I wrote a slightly inefficient fix here.
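
In the meantime, one possible user-side workaround (just a sketch, not necessarily the same as the fix linked above) is to give every rank an identical fixed generation budget so that no peer can finish early. Inside the loop of foo.py above, that would look like:

outputs = model.generate(
    input_ids=batch_input_ids,
    attention_mask=batch_attention_mask,
    tokenizer=tokenizer,
    synced_gpus=True,
    min_new_tokens=64,  # suppress EOS until 64 new tokens have been generated...
    max_new_tokens=64,  # ...and stop exactly there, so all ranks finish together
)

This trades wasted tokens on ranks that would have stopped early for keeping input_ids and the key-value cache growing in lockstep on every peer.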

gante commented 2 months ago

👋 Hi @ojh31, thank you for opening this issue!

I believe this issue is the same as in #32885. I'd like the fix to be slightly different from the one you proposed, mostly due to an ongoing refactor on our end. Have a look at my comment here.

(Redirecting to the other thread to avoid multiple parallel discussions; I know this issue is older, but I take issues in a LIFO queue 🤗)

github-actions[bot] commented 1 month ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

gante commented 1 month ago

#34095 fixes this issue on the vast majority of models 🤗