huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Llama Attention Call should not pass **kwargs #30523

Closed kiddyboots216 closed 1 month ago

kiddyboots216 commented 3 months ago

System Info

Who can help?

@ArthurZucker

Information

Tasks

Reproduction

Wrapping a LlamaModel with FSDP results in the following error during a forward pass:

  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 849, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1196, in forward
    outputs = self.model(
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1016, in forward
    layer_outputs = decoder_layer(
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 849, in forward
    output = self._fsdp_wrapped_module(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 168, in forward
    return self.checkpoint_fn(  # type: ignore[misc]
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 489, in checkpoint
    ret = function(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 739, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/scratch/gpfs/ashwinee/envs/align/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: LlamaSdpaAttention.forward() got an unexpected keyword argument 'offload_to_cpu'

This occurs because **kwargs are passed here https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L749 to a forward() that does not accept **kwargs https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L608
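
Here is a minimal, self-contained sketch of the failure mode (the class names and signatures are illustrative only, not the actual transformers code):

class ToyAttention:
    def forward(self, hidden_states, attention_mask=None):  # no **kwargs
        return hidden_states

class ToyDecoderLayer:
    def __init__(self):
        self.self_attn = ToyAttention()

    def forward(self, hidden_states, **kwargs):
        # Blindly forwarding **kwargs means any extra keyword injected by an
        # outer wrapper (e.g. offload_to_cpu from checkpoint_wrapper) reaches
        # a method whose signature does not accept it.
        return self.self_attn.forward(hidden_states, **kwargs)

ToyDecoderLayer().forward("x", offload_to_cpu=False)
# TypeError: ToyAttention.forward() got an unexpected keyword argument 'offload_to_cpu'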

If we use another model, e.g. Mistral, this issue does not occur, because **kwargs are not passed there https://github.com/huggingface/transformers/blob/main/src/transformers/models/mistral/modeling_mistral.py#L757C63-L757C77

Expected behavior

Either remove the **kwargs from the call on line 749, or add **kwargs to the attention forward() signature.
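
For illustration, the second option would look roughly like this; the parameter list is abbreviated to the kwargs visible in this thread, so treat it as a sketch rather than the exact LlamaSdpaAttention.forward() signature:

def forward(
    self,
    hidden_states,
    attention_mask=None,
    position_ids=None,
    past_key_value=None,
    output_attentions=False,
    use_cache=False,
    cache_position=None,
    **kwargs,  # swallow wrapper-injected kwargs such as offload_to_cpu
):
    ...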

ArthurZucker commented 3 months ago

Kwargs should indeed not be passed. I would need a reproducer but feel free to open a PR for a fix! 😉

kiddyboots216 commented 3 months ago

> Kwargs should indeed not be passed. I would need a reproducer but feel free to open a PR for a fix! 😉

I will open a PR after cataloguing all the models that have this issue; GPT-NeoX has it as well. The reproducer is to wrap a model in FSDP and then run a forward pass on any data.

vikram71198 commented 3 months ago

Yep, can confirm I also see the same issue with LLaMA-3-8b-Instruct with FSDP + Gradient Checkpointing.

The Yi series of models also have this issue, I just checked. And it makes perfect sense since they follow the LLaMA architecture.

ArthurZucker commented 3 months ago

We'll remove the kwargs! cc @zhenglongjiepheonix, who is working on something related! We can open a separate PR for this and link this issue.

vikram71198 commented 3 months ago

@ArthurZucker are there any updates on this? I don't see a PR for this yet.

ArthurZucker commented 3 months ago

#30743 closed this 😉

vikram71198 commented 3 months ago

@ArthurZucker there's a slightly different error this time around with transformers==4.41.0

LlamaDecoderLayer.forward() got an unexpected keyword argument 'offload_to_cpu'

So the stray kwarg is still being injected, just one level higher now.

ArthurZucker commented 2 months ago
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

That seems pretty impossible TBH; this call no longer passes any **kwargs or anything of the sort. Make sure you are on 4.41.1.

vikram71198 commented 2 months ago

Yeah, even with transformers == 4.41.1, I get the error:

TypeError: LlamaDecoderLayer.forward() got an unexpected keyword argument 'offload_to_cpu'

And given how this layer's forward() is called in your message above, it doesn't take an offload_to_cpu input, so this error message actually makes sense.

Does this mean LLaMA models are not compatible with Gradient Checkpointing when training with FSDP, given that offload_to_cpu is a parameter that comes from the Gradient Checkpointing wrapper?

ArthurZucker commented 2 months ago

cc @younesbelkada

younesbelkada commented 2 months ago

On it!

vikram71198 commented 2 months ago

Okay, with transformers == 4.41.1, I'm now getting this same aforementioned error with Mistral models as well:

MistralDecoderLayer.forward() got an unexpected keyword argument 'offload_to_cpu'

This wasn't happening with Mistral before.

Again, this is when I use Gradient Checkpointing with FSDP.

FYI, this is a rough template of how I apply Gradient Checkpointing with FSDP:

from functools import partial

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer

# only wrap the decoder layers with activation checkpointing
check_fn = lambda submodule: isinstance(submodule, MistralDecoderLayer)

non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    offload_to_cpu=False,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)

apply_activation_checkpointing(
    model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
)

This is the method suggested by the official FSDP tutorials from the PyTorch team here.

With transformers==4.40.2, this issue does not appear with Mistral models, but with LLaMA models (including Yi) I see the error described at the very start of this issue.
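
As a possible workaround in the meantime (an assumption on my side, not verified across torch versions): the traceback at the top shows offload_to_cpu flowing from checkpoint_wrapper through checkpoint_fn into the wrapped module's forward(), so omitting it from the partial should keep the stray kwarg from being injected at all:

# hypothetical workaround: drop offload_to_cpu so checkpoint_wrapper has no
# extra kwarg to forward into the wrapped layer's forward()
non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)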

younesbelkada commented 2 months ago

I was able to repro:

from functools import partial
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import Accelerator

# verify we have FSDP activation support ready by importing:
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)

from transformers.models.llama.modeling_llama import LlamaDecoderLayer

model_id = "HuggingFaceM4/tiny-random-Llama3ForCausalLM"

model = AutoModelForCausalLM.from_pretrained(model_id)

model.train()
model.gradient_checkpointing_enable()

accelerator = Accelerator()
model = accelerator.prepare(model)

check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer)

non_reentrant_wrapper = partial(
    checkpoint_wrapper,
    offload_to_cpu=False,
    checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)

apply_activation_checkpointing(
    model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn
)

print(model)

rand_input = torch.LongTensor([[0, 1, 0, 1]]).to(0)

model(rand_input)

#31161 fixes the bug