huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Which language is available for EncoderDecoderModel pre-trained model? #14381

Closed hansd410 closed 2 years ago

hansd410 commented 2 years ago

Hi, I have a question.

From the example code for EncoderDecoderModel, the pre-trained model 'bert-base-uncased' looks like it was trained on EN-FR sentence pairs. Is the pre-trained model for the decoder trained on French? How can I load a pre-trained model for an EN-EN pair? Am I misunderstanding something?

from transformers import EncoderDecoderModel, BertTokenizer
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert from pre-trained checkpoints

# training
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size

input_ids = tokenizer("Hello, my dog is cute", return_tensors="pt").input_ids
labels = tokenizer("Salut, mon chien est mignon", return_tensors="pt").input_ids
outputs = model(input_ids=input_ids, labels=labels)  # use the tokenized labels, not the inputs, as targets
loss, logits = outputs.loss, outputs.logits

# save and load from pretrained
model.save_pretrained("bert2bert")
model = EncoderDecoderModel.from_pretrained("bert2bert")

# generation
generated = model.generate(input_ids, decoder_start_token_id=tokenizer.cls_token_id)  # start from the CLS token, matching the decoder_start_token_id set above

NielsRogge commented 2 years ago

Hi,

Not sure what you mean; bert-base-uncased is the English pretrained BERT checkpoint. So in the code example above, we initialize the weights of the encoder and the decoder with the weights of BERT, while the weights of the decoder's cross-attention layers are randomly initialized. One should fine-tune this warm-started model on a downstream English task, such as summarization.
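
For example, an EN-EN setup looks just like the code above, with English labels tokenized by the same BERT tokenizer. A minimal sketch (the sentence pair below is made up for illustration):

from transformers import EncoderDecoderModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased')

model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size

# both inputs and labels are English, tokenized with the same BERT tokenizer
input_ids = tokenizer("The weather today is sunny with a light breeze.", return_tensors="pt").input_ids
labels = tokenizer("Sunny and breezy today.", return_tensors="pt").input_ids  # e.g. a summarization-style target

outputs = model(input_ids=input_ids, labels=labels)
loss = outputs.loss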

hansd410 commented 2 years ago

Thank you for the reply, @NielsRogge. Then let me ask: why are the labels in the example code French?

labels = tokenizer("Salut, mon chien est mignon", return_tensors="pt").input_ids

NielsRogge commented 2 years ago

Oh I get your confusion.

The EncoderDecoderModel framework is meant to be fine-tuned on sequence-to-sequence tasks, such as machine translation. Of course, if you want to fine-tune an EncoderDecoderModel to perform translation from English to French, it makes sense to warm-start the decoder with a pre-trained French checkpoint, e.g.:

model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'camembert-base')

I'll update the code example. Also note that you should then use CamemBERT's tokenizer to create the labels.
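
So the EN-FR training step would look roughly like this (a minimal sketch; using CamemBERT's <s> token as the decoder start token is my assumption for a RoBERTa-style decoder):

from transformers import EncoderDecoderModel, BertTokenizer, CamembertTokenizer

encoder_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
decoder_tokenizer = CamembertTokenizer.from_pretrained('camembert-base')

model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'camembert-base')
model.config.decoder_start_token_id = decoder_tokenizer.bos_token_id  # assumption: <s> as start token for the RoBERTa-style decoder
model.config.pad_token_id = decoder_tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size

# English inputs go through the encoder's tokenizer, French labels through the decoder's tokenizer
input_ids = encoder_tokenizer("Hello, my dog is cute", return_tensors="pt").input_ids
labels = decoder_tokenizer("Salut, mon chien est mignon", return_tensors="pt").input_ids

outputs = model(input_ids=input_ids, labels=labels)
loss = outputs.loss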

NielsRogge commented 2 years ago

@hansd410 do you mind opening a PR to fix this in the docs?

hansd410 commented 2 years ago

@NielsRogge Not at all. Please do so.