gyin94 closed this issue 3 years ago.
Hi @gyin-ai,
Thank you for reporting the issue. `run_seq2seq.py` currently does not work for encoder-decoder models. This is because encoder-decoder models expect both `decoder_input_ids` and `labels`, whereas the script only passes `labels`, which causes the error above.
You could refer to this notebook to see how to use `Trainer` for encoder-decoder models. You can also easily adapt the `run_seq2seq.py` script for this: I think you'll only need to change the data collator here to return both `labels` and `decoder_input_ids`.
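A minimal sketch of that collator change, written as a hypothetical post-processing step on an already-padded batch of plain Python lists (`add_decoder_inputs` is an illustrative name, not part of `transformers`). It assumes the notebook's convention that `decoder_input_ids` carry the same target ids as `labels`, with the `-100` loss-mask positions turned back into pad tokens. Decoders that do not shift labels internally may need right-shifted inputs instead.

```python
from typing import Dict, List


def add_decoder_inputs(batch: Dict[str, List[List[int]]],
                       pad_token_id: int) -> Dict[str, List[List[int]]]:
    """Hypothetical collator step: derive decoder inputs from padded labels.

    Assumes decoder_input_ids should hold the same target ids as labels
    (the notebook's convention), except that ignored (-100) positions
    become real pad tokens the decoder can embed.
    """
    out = dict(batch)  # shallow copy so the caller's batch is left untouched
    labels = batch["labels"]
    out["decoder_input_ids"] = [
        [pad_token_id if tok == -100 else tok for tok in row] for row in labels
    ]
    # Attend only to real target positions, i.e. those not masked with -100.
    out["decoder_attention_mask"] = [
        [0 if tok == -100 else 1 for tok in row] for row in labels
    ]
    return out


batch = {"input_ids": [[101, 7592, 102]], "labels": [[101, 2088, 102, -100]]}
print(add_decoder_inputs(batch, pad_token_id=0))
```

In a real script this logic would run inside the data collator after `labels` have been padded with `-100`, so the `Trainer` receives both keys.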
@patil-suraj, can I ask whether `batch["decoder_input_ids"]` should be `inputs.input_ids` instead of `outputs.input_ids`?
```python
from transformers import BertTokenizerFast

# Assumed setup (defined elsewhere in the notebook); the lengths are example values
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
encoder_max_length = 512
decoder_max_length = 128

def process_data_to_model_inputs(batch):
    # The tokenizer automatically adds special tokens: [CLS] <text> [SEP] for BERT
    inputs = tokenizer(batch["document"], padding="max_length", truncation=True, max_length=encoder_max_length)
    outputs = tokenizer(batch["summary"], padding="max_length", truncation=True, max_length=decoder_max_length)
    batch["input_ids"] = inputs.input_ids
    batch["attention_mask"] = inputs.attention_mask
    batch["decoder_input_ids"] = outputs.input_ids
    batch["labels"] = outputs.input_ids.copy()
    # Mask padding in the loss: label positions set to -100 are ignored
    batch["labels"] = [
        [-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]
    ]
    batch["decoder_attention_mask"] = outputs.attention_mask
    return batch
```
Here is the example from the `EncoderDecoderModel` docs:
```python
>>> from transformers import EncoderDecoderModel, BertTokenizer
>>> import torch
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert from pre-trained checkpoints
>>> # forward
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
>>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
>>> # training
>>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids)
```
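`labels` and `decoder_input_ids` always correspond to the output (target) sequence, so it should be `outputs.input_ids`. The docstring example simply reuses the same sentence for the encoder input and the target.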
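To make the answer concrete, here is a toy sketch using a hypothetical whitespace tokenizer with a made-up vocabulary (`VOCAB`, `toy_encode`, and `build_example` are illustrative only, not `transformers` APIs). It shows the split: the encoder's `input_ids` come from the document, while both `decoder_input_ids` and `labels` come from the summary, with padding masked to `-100` in `labels` only.

```python
from typing import Dict, List

# Hypothetical toy vocabulary standing in for a real tokenizer (assumption).
VOCAB = {"[PAD]": 0, "[CLS]": 101, "[SEP]": 102,
         "my": 1, "dog": 2, "is": 3, "cute": 4, "pet": 5}
PAD_ID = VOCAB["[PAD]"]


def toy_encode(text: str, max_length: int) -> List[int]:
    """Wrap words in [CLS] ... [SEP], truncate, then pad to max_length."""
    ids = [VOCAB["[CLS]"]] + [VOCAB[w] for w in text.split()] + [VOCAB["[SEP]"]]
    ids = ids[:max_length]
    return ids + [PAD_ID] * (max_length - len(ids))


def build_example(document: str, summary: str) -> Dict[str, List[int]]:
    source = toy_encode(document, max_length=6)  # encoder side
    target = toy_encode(summary, max_length=5)   # decoder side
    return {
        "input_ids": source,
        # Both decoder-side keys come from the target (summary), never the source.
        "decoder_input_ids": target,
        "labels": [-100 if t == PAD_ID else t for t in target],
    }


print(build_example("my dog is cute", "cute pet"))
```

Feeding the document's ids as `decoder_input_ids` would ask the decoder to condition on the source text while being scored against the summary, which is why the target side must be used for both keys.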
Environment info
`transformers` version: 4.4.0.dev0
path_to_csv_or_jsonlines_file:
`t5-small` works perfectly, but the BertGeneration model raises the following error.
error: