huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Shape mismatch when generating with multiple processes #32603

Open ojh31 opened 1 month ago

ojh31 commented 1 month ago

System Info

Who can help?

@gante @SunMarc @ArthurZucker

Information

Tasks

Reproduction

Run accelerate launch --config_file=accelerate_config.yaml foo.py

foo.py:

from accelerate import Accelerator
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import AutoModelForCausalLM, LlamaTokenizer

NAME = "meta-llama/Llama-2-13b-chat-hf"
tokenizer = LlamaTokenizer.from_pretrained(
    NAME, 
    padding_side="left", 
)
accelerator = Accelerator()
model = AutoModelForCausalLM.from_pretrained(NAME)
model = accelerator.prepare(model)

paired_input_ids = torch.tensor([
    [    2,     2,     2,     2,     2,     2,     2,     2,     2,     2,
             1, 29961, 25580, 29962,   887,   526,   385,  7395,  6993,  1788,
         29889, 13866,   366,   674,   367,  5429,   278,  1788,  4800, 29889,
           450,  1404,   674,   769,  3896,   263,  4800, 29889,   960,   278,
          1404,  4800,  7087,   278,  1788,  4800, 29892,   736, 18016, 13566,
          3352, 29889,   960,   278,  1404,  4800,   947,   451,  1993,   278,
          1788,  4800, 29892,   736,   360,  1430, 29902,  3352, 29889, 18076,
           487,  3099,  1156,   278,  1404,  4800, 29889,    13,    13,  3924,
          4800, 29901, 15040,    13,  2659,  4800, 29901, 15040,    13,  5634,
            13,    13,    13, 22550, 29901,   518, 29914, 25580, 29962],
            [    2,     2,     2,     2,     2,     2,     1, 29961, 25580, 29962,
           887,   526,   385,  7395,  6993,  1788, 29889, 13866,   366,   674,
           367,  5429,   278,  1788,  4800, 29889,   450,  1404,   674,   769,
          3896,   263,  4800, 29889,   960,   278,  1404,  4800,  7087,   278,
          1788,  4800, 29892,   736, 18016, 13566,  3352, 29889,   960,   278,
          1404,  4800,   947,   451,  1993,   278,  1788,  4800, 29892,   736,
           360,  1430, 29902,  3352, 29889, 18076,   487,  3099,  1156,   278,
          1404,  4800, 29889,    13,    13,  3924,  4800, 29901,  1757, 10582,
           284,    13,  2659,  4800, 29901,  1757, 10582,   284,    13,  5634,
            13,    13,    13, 22550, 29901,   518, 29914, 25580, 29962]
        ]
)
paired_attention_mask = torch.tensor([
    [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1],
    [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1]
])

paired_dataset = TensorDataset(paired_input_ids, paired_attention_mask)

dataloader = DataLoader(
    dataset=paired_dataset,
    batch_size=1,  # Process one pair at a time
    shuffle=False,
)
dataloader = accelerator.prepare(dataloader)

for batch_input_ids, batch_attention_mask in dataloader:
    with torch.no_grad():
        model.forward(input_ids=batch_input_ids, attention_mask=batch_attention_mask)
    with FSDP.summon_full_params(model, recurse=False):
        outputs = model.generate(
            input_ids=batch_input_ids,
            attention_mask=batch_attention_mask, 
            tokenizer=tokenizer,
            synced_gpus=True,
        )

accelerate_config.yaml:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: "no"
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: "no"
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Expected behavior

The script should generate text output, but instead it throws the following error:

The expanded size of the tensor (105) must match the existing size (104) at non-singleton dimension 3.  Target sizes: [1, 40, 1, 105].  Tensor sizes: [1, 1, 1, 104]
  File "/usr/local/venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 648, in forward
    attn_output = torch.nn.functional.scaled_dot_product_attention(
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 718, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 849, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 978, in forward
    layer_outputs = decoder_layer(
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1174, in forward
    outputs = self.model(
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/venv/lib/python3.10/site-packages/transformers/generation/utils.py", line 2651, in _sample
    outputs = self(
  File "/usr/local/venv/lib/python3.10/site-packages/transformers/generation/utils.py", line 1914, in generate
    result = self._sample(
  File "/usr/local/venv/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/root/robust-llm/pairs.py", line 66, in <module>
    outputs = model.generate(
RuntimeError: The expanded size of the tensor (105) must match the existing size (104) at non-singleton dimension 3.  Target sizes: [1, 40, 1, 105].  Tensor sizes: [1, 1, 1, 104]

Hypothesis: In transformers/generation/utils.py::GenerationMixin._sample(), inside the while self._has_unfinished_sequences() loop, we continue if synced_gpus and this_peer_finished. That continue skips the concatenation of next_tokens onto input_ids (and the attention_mask update), whereas the forward pass still updates the past_key_value cache in transformers/models/llama/modeling_llama.py::LlamaSdpaAttention.forward(). Therefore, when one process finishes generation before the other, the finished process continues to expand the key-value cache but stops expanding the input tensors, leading to a shape mismatch. Maybe a simple fix would be to forcibly set past_key_value to None once this_peer_finished is set to True?
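
To make the divergence concrete, here is a standalone toy simulation of the hypothesised mechanism (an illustrative sketch only, not the actual transformers code path; the function name and step counts are made up). The forward pass keeps running on the finished peer, so the cache length keeps growing, while the continue skips extending input_ids and attention_mask:

# Toy model of the hypothesised failure mode (illustrative only, not transformers code).
def simulate_finished_peer(prompt_len=99, extra_steps=5, finish_at_step=0):
    mask_len = prompt_len   # length of input_ids / attention_mask fed to the model
    cache_len = 0           # length of the past_key_value cache
    this_peer_finished = False
    for step in range(extra_steps):
        # The forward pass always runs, so the cache absorbs whatever was fed in.
        cache_len += mask_len if cache_len == 0 else 1
        if this_peer_finished:
            continue        # skips the torch.cat on input_ids and the mask update
        mask_len += 1       # attention_mask extended for the newly sampled token
        if step >= finish_at_step:
            this_peer_finished = True
    return mask_len, cache_len

print(simulate_finished_peer())  # (100, 103): the cache outruns the inputs, as in the error above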

ojh31 commented 1 month ago

I can confirm the same error after upgrading to accelerate 0.33.0 and transformers 4.44.0

ojh31 commented 4 weeks ago

I wrote a slightly inefficient fix here.
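
The linked change is not reproduced in this thread. For anyone hitting the same error, one caller-side workaround consistent with the hypothesis above (an illustration only, not necessarily the same approach as the linked fix) is to disable the KV cache for the generate call, so there is no cache left to outgrow the frozen inputs on a finished peer, at the cost of recomputing attention over the full sequence each step:

# Workaround sketch (assumption, not the linked fix): generate without a KV cache.
with FSDP.summon_full_params(model, recurse=False):
    outputs = model.generate(
        input_ids=batch_input_ids,
        attention_mask=batch_attention_mask,
        tokenizer=tokenizer,
        synced_gpus=True,
        use_cache=False,  # no past_key_values to fall out of sync across ranks
    )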

gante commented 1 week ago

👋 Hi @ojh31, thank you for opening this issue!

I believe this issue is the same as in #32885. I'd like the fix to be slightly different from the one you proposed, mostly due to an ongoing refactor on our end. Have a look at my comment here.

(Redirecting to the other thread to avoid multiple parallel discussions; I know this issue is older, but I take issues in a LIFO queue 🤗)