huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

EncoderDecoderModel loss function #7946

Closed: AI678 closed this issue 4 years ago

AI678 commented 4 years ago

❓ Questions & Help

Details

Hey, I want to ask the following questions: how is the loss calculated in EncoderDecoderModel, and what is the mathematical formula of the loss function? I just wrote the code like this:

outputs = model(input_ids=src, attention_mask=mask, decoder_input_ids=dst, labels=dst, return_dict=True)
loss, logits = outputs.loss, outputs.logits


LysandreJik commented 4 years ago

Hello, this depends on the decoder you use to initialize the encoder-decoder model. What decoder do you use?

AI678 commented 4 years ago

I use 'bert-base-uncased', just like this:

model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased')

LysandreJik commented 4 years ago

I'm not sure this is the recommended way to load the models, as it gives the following warning:

Some weights of BertLMHeadModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', [...]

with pretty much all model weights.
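
As a quick sanity check (an illustrative sketch on my part, not output taken from this thread), the newly initialized weights in that warning can be inspected by filtering the decoder's parameter names, since the cross-attention weights listed there do not exist in the plain 'bert-base-uncased' checkpoint:

from transformers import EncoderDecoderModel

model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased')

# Parameters whose names contain 'crossattention' are the ones BERT gains when it
# is used as a decoder; they are not present in the pretrained checkpoint.
new_weights = [name for name, _ in model.decoder.named_parameters() if 'crossattention' in name]
print(len(new_weights), new_weights[:2])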

Will ping @patrickvonplaten for advice.

patrickvonplaten commented 4 years ago

Hey @AI678,

1) The model should be initialized just as you did with

model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased')

It's normal that none of the cross-attention layers are initialized from the checkpoint, because BERT does not have any; they have to be fine-tuned down the road.

2) To train a Bert2Bert model, you are also correct in doing:

outputs = model(input_ids=src, attention_mask=mask, decoder_input_ids=dst, labels=dst, return_dict=True)
loss, logits = outputs.loss, outputs.logits

because BERT automatically shifts the labels for you, see: https://github.com/huggingface/transformers/blob/901e9b8eda2fe88af717f960ddc05cac1803679b/src/transformers/modeling_bert.py#L1060 (a short sketch of that computation follows below).
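
In other words, the loss is a standard token-level cross-entropy between the logits at position i and the label token at position i+1. Here is a minimal sketch of that computation, mirroring the linked modeling_bert.py lines (the function name and shapes are illustrative, not the exact library code):

from torch.nn import CrossEntropyLoss

def lm_loss(prediction_scores, labels, vocab_size):
    # Shift so that tokens < n predict token n: drop the last logit and the first label.
    shifted_logits = prediction_scores[:, :-1, :].contiguous()
    shifted_labels = labels[:, 1:].contiguous()
    # Token-level cross-entropy, averaged over all positions whose label is not -100.
    loss_fct = CrossEntropyLoss()  # ignore_index defaults to -100
    return loss_fct(shifted_logits.view(-1, vocab_size), shifted_labels.view(-1))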

Also, I'll soon publish a more detailed notebook about "Leveraging Encoder-Decoder models". This model card could also be helpful: https://huggingface.co/patrickvonplaten/bert2bert-cnn_dailymail-fp16#bert2bert-summarization-with-%F0%9F%A4%97-encoderdecoder-framework
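
For completeness, here is a minimal end-to-end sketch of one Bert2Bert training step under the setup above. The example texts and variable names are illustrative, and masking padded label positions with -100 (the default ignore_index of CrossEntropyLoss) is an extra practical detail, not something discussed in this thread:

from transformers import BertTokenizer, EncoderDecoderModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased')

src_texts = ['a long article to summarize']  # illustrative source/target pair
tgt_texts = ['a short summary']

enc = tokenizer(src_texts, padding=True, return_tensors='pt')
dec = tokenizer(tgt_texts, padding=True, return_tensors='pt')

# Ignore padded target positions when computing the loss.
labels = dec.input_ids.clone()
labels[dec.attention_mask == 0] = -100

outputs = model(
    input_ids=enc.input_ids,
    attention_mask=enc.attention_mask,
    decoder_input_ids=dec.input_ids,
    decoder_attention_mask=dec.attention_mask,
    labels=labels,
    return_dict=True,
)
outputs.loss.backward()  # loss is the shifted cross-entropy sketched above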

AI678 commented 4 years ago

Thank you very much.