huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Use models as Seq2Seq model #30529

Open Bachstelze opened 7 months ago

Bachstelze commented 7 months ago

Who can help?

@ArthurZucker @muellerzr @stevhliu

Reproduction

Many model documentation pages contain this snippet: "To be used in a Seq2Seq model, the model needs to be initialized with both the is_decoder=True and bidirectional=False arguments as well as add_cross_attention set to True; an encoder_hidden_states is then expected as an input to the forward pass."

I tried it like this for the MEGA model:

from transformers import AutoTokenizer, MegaConfig, MegaForCausalLM, MegaModel

# Config for a small seq2seq model like the one in the MEGA paper.
# vocabulary_size and context_length must be defined beforehand.
config = MegaConfig(
    vocab_size=vocabulary_size,
    max_position_embeddings=context_length,
    is_decoder=True,
    bidirectional=False,
    add_cross_attention=True,
)

# Decoder with cross-attention enabled, but no encoder attached.
model = MegaModel(config=config)
# Only the decoder-only causal LM seems to run:
# model = MegaForCausalLM(config=config)

The following error occurs when training it with Seq2SeqTrainer, Seq2SeqTrainingArguments, and DataCollatorForSeq2Seq:

transformers/models/mega/modeling_mega.py in forward(self, hidden_states, attention_mask, causal_mask, encoder_hidden_states, encoder_attention_mask, past_key_value, output_attentions, use_cache)
   1271         if self.cross_attn is not None:
   1272             if encoder_hidden_states is None:
-> 1273                 raise ValueError("Requested cross-attention without providing encoder hidden states")
   1274 
   1275             cross_attn_outputs = self.cross_attn(

ValueError: Requested cross-attention without providing encoder hidden states

Expected behavior

MEGA can be trained as a seq2seq model, as in the MEGA paper.

Bachstelze commented 6 months ago

@mnaylor5 do you have any advice on how to use MEGA as seq2seq? Or should the original fairseq implementation be used for this case?

mnaylor5 commented 6 months ago

@Bachstelze - I'm not working on this at the moment (I just contributed MEGA to transformers by translating the architecture from the original fairseq code), but it seems like you're trying to use an encoder-decoder model setup (seq2seq) without an encoder. As the error message states, a MEGA model using cross-attention will need encoder hidden states to be provided as an argument to the forward pass. For example, in standard seq2seq models like BART or T5, you would have an encoder portion (with bidirectional self-attention) which does a forward pass on the source/context sequence, and you would have a separate decoder portion (with unidirectional self-attention in addition to cross-attention) which does a forward pass on the target sequence and accepts the final encoder hidden states as an input for cross-attention.
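
As a rough illustration of the encoder/decoder split described above, here is a minimal sketch that wires the two halves together by hand. It uses BERT rather than MEGA, with tiny randomly initialized models; the sizes and the variable names (`shared`, `source_ids`, `target_ids`) are assumptions made up for the example and do not come from the issue.

```python
import torch
from transformers import BertConfig, BertLMHeadModel, BertModel

# Tiny illustrative sizes (assumptions, not values from the issue).
shared = dict(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
)

# Encoder: bidirectional self-attention over the source sequence.
encoder = BertModel(BertConfig(**shared))
# Decoder: causal self-attention plus cross-attention over the encoder output.
decoder = BertLMHeadModel(
    BertConfig(is_decoder=True, add_cross_attention=True, **shared)
)
encoder.eval()
decoder.eval()

source_ids = torch.randint(0, 100, (2, 8))  # batch of 2, source length 8
target_ids = torch.randint(0, 100, (2, 6))  # batch of 2, target length 6

with torch.no_grad():
    # Forward pass on the source sequence.
    encoder_states = encoder(input_ids=source_ids).last_hidden_state
    # Forward pass on the target sequence, attending to the encoder states.
    decoder_out = decoder(
        input_ids=target_ids,
        encoder_hidden_states=encoder_states,
    )

print(encoder_states.shape, decoder_out.logits.shape)
```

The point of the sketch is the `encoder_hidden_states` argument: it is the input that the MEGA traceback reports as missing when a cross-attention decoder is called on its own.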

The MEGA implementation is set up very similarly to BERT/RoBERTa, so any resources that demonstrate how to use those models in a seq2seq setting should also apply to MEGA. I haven't personally done much in the seq2seq space, so I'll defer to the fine folks at Hugging Face to provide guidance on that 😄
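
One such BERT-oriented resource is the `EncoderDecoderModel` wrapper in transformers, which pairs an encoder with a decoder and passes the encoder hidden states to the decoder internally. The sketch below uses tiny, randomly initialized BERT configs with made-up sizes; whether the same wrapper accepts MEGA configs is not confirmed anywhere in this thread, so treat it as a starting point rather than a verified recipe.

```python
import torch
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel

# Tiny illustrative sizes (assumptions, not values from the issue).
sizes = dict(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
)
encoder_config = BertConfig(**sizes)
decoder_config = BertConfig(**sizes)

# This helper marks the decoder config with is_decoder=True and
# add_cross_attention=True, so they do not need to be set by hand.
config = EncoderDecoderConfig.from_encoder_decoder_configs(
    encoder_config, decoder_config
)
model = EncoderDecoderModel(config=config)
model.eval()

source_ids = torch.randint(0, 100, (2, 8))  # batch of 2, source length 8
target_ids = torch.randint(0, 100, (2, 6))  # batch of 2, target length 6

with torch.no_grad():
    # The wrapper runs the encoder, then feeds its hidden states
    # to the decoder's cross-attention.
    outputs = model(input_ids=source_ids, decoder_input_ids=target_ids)

print(outputs.logits.shape)
```

For training with only `labels` (as a seq2seq data collator would supply), the wrapper derives the decoder inputs by shifting the labels, which as far as I know requires `decoder_start_token_id` and `pad_token_id` to be set on the model config first.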