AI678 closed this issue 4 years ago.
Hello, this depends on the decoder you use to initialize the encoder-decoder model. What decoder do you use?
I use 'bert-base-uncased', just like this:
model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased')
I'm not sure this is the recommended way to load the models as it gives the following result:
Some weights of BertLMHeadModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', [...]
and so on for pretty much all of the cross-attention weights.
Will ping @patrickvonplaten for advice.
Hey @AI678,
1) The model should be initialized just as you did, with:
model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased')
It's normal that none of the cross-attention layers are initialized from the checkpoint, because pretrained BERT does not have any; they have to be learned during fine-tuning down the road.
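If you want to convince yourself of this, here is a minimal sketch (the 'crossattention' substring matches the parameter names in the warning quoted above):

```python
from transformers import EncoderDecoderModel

model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    'bert-base-uncased', 'bert-base-uncased'
)

# The decoder gains cross-attention modules that the plain BERT checkpoint
# never contained, so these weights start out randomly initialized.
cross_attn = [n for n, _ in model.decoder.named_parameters() if 'crossattention' in n]
print(len(cross_attn))  # > 0: exactly the weights listed in the warning
```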
2) To train a Bert2Bert model, you are also correct in doing:
outputs = model(input_ids=src, attention_mask=mask, decoder_input_ids=dst, labels=dst, return_dict=True)
loss, logits = outputs.loss, outputs.logits
because BERT automatically shifts the labels for you, see: https://github.com/huggingface/transformers/blob/901e9b8eda2fe88af717f960ddc05cac1803679b/src/transformers/modeling_bert.py#L1060
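In other words, the loss is the standard autoregressive cross-entropy over the target tokens: loss = -(1/N) * sum_t log p(y_t | y_<t, x), where x is the source sequence and N is the number of target tokens that are not masked out. The linked BertLMHeadModel code condenses to roughly the following sketch (with prediction_scores being the raw logits of shape (batch, seq_len, vocab_size)):

```python
from torch.nn import CrossEntropyLoss

def lm_loss(prediction_scores, labels, vocab_size):
    # Shift so that tokens < t are used to predict token t:
    # drop the last logit and the first label, then flatten.
    shifted_scores = prediction_scores[:, :-1, :].contiguous()
    shifted_labels = labels[:, 1:].contiguous()
    # Token-level cross-entropy; positions labelled -100 are ignored by default.
    loss_fct = CrossEntropyLoss()
    return loss_fct(shifted_scores.view(-1, vocab_size), shifted_labels.view(-1))
```

So passing labels=dst directly is fine; the shifting is done for you and you don't need to offset dst yourself.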
I'll also publish a more detailed notebook about "Leveraging Encoder-Decoder models" soon. This model card could also be helpful: https://huggingface.co/patrickvonplaten/bert2bert-cnn_dailymail-fp16#bert2bert-summarization-with-%F0%9F%A4%97-encoderdecoder-framework
Thank you very much!
❓ Questions & Help
Details
Hey, I want to ask the following questions: How is the loss calculated in EncoderDecoderModel? What is the mathematical formula of the loss function? I just wrote the code like this:
outputs = model(input_ids=src, attention_mask=mask, decoder_input_ids=dst, labels=dst, return_dict=True)
loss, logits = outputs.loss, outputs.logits