huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

unexpected keyword argument 'lm_labels' when using BertModel as Decoder with EncoderDecoderModel #4960

Closed: utkd closed this issue 4 years ago

utkd commented 4 years ago

The BertModel.forward() method does not accept lm_labels or masked_lm_labels arguments. Yet EncoderDecoderModel.forward() calls its decoder's forward() method with those arguments, which raises a TypeError when a BertModel is used as the decoder.

Am I using BertModel incorrectly? I can get rid of the error by modifying EncoderDecoderModel so that it does not pass those arguments to the decoder.

Exact Error:

File "/Users/utkarsh/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/utkarsh/Projects/ai4code/transformers/bert2bert/models.py", line 12, in forward
    dec_out, dec_cls, enc_out, enc_cls = self.bertmodel(input_ids=inputs, attention_mask=input_masks, decoder_input_ids=targets, decoder_attention_mask=target_masks)
  File "/Users/utkarsh/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/Users/utkarsh/anaconda3/envs/py37/lib/python3.7/site-packages/transformers/modeling_encoder_decoder.py", line 283, in forward
    **kwargs_decoder,
  File "/Users/utkarsh/anaconda3/envs/py37/lib/python3.7/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
TypeError: forward() got an unexpected keyword argument 'lm_labels'

Relevant part of the code:

from transformers import BertConfig, BertModel, EncoderDecoderModel

encoder = BertModel(enc_config)
dec_config = BertConfig(..., is_decoder=True)
decoder = BertModel(dec_config)
model = EncoderDecoderModel(encoder=encoder, decoder=decoder)

...
dec_out, dec_cls, enc_out, enc_cls = model(input_ids=inputs, attention_mask=input_masks, decoder_input_ids=targets, decoder_attention_mask=target_masks)
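For reference, a setup along these lines avoids the TypeError on recent transformers versions: give the decoder an LM head and pass labels rather than lm_labels. This is a minimal sketch, not the original poster's code; it assumes a version that has BertLMHeadModel and the add_cross_attention config flag, and the config sizes and dummy tensors are placeholders:

import torch
from transformers import BertConfig, BertLMHeadModel, BertModel, EncoderDecoderModel

# Tiny placeholder configs so the sketch runs quickly; real sizes will differ.
enc_config = BertConfig(hidden_size=64, num_hidden_layers=2,
                        num_attention_heads=2, intermediate_size=128)
encoder = BertModel(enc_config)

# The decoder needs an LM head for labels to produce a loss, plus
# is_decoder=True and add_cross_attention=True so it can attend to the encoder.
dec_config = BertConfig(hidden_size=64, num_hidden_layers=2,
                        num_attention_heads=2, intermediate_size=128,
                        is_decoder=True, add_cross_attention=True)
decoder = BertLMHeadModel(dec_config)

model = EncoderDecoderModel(encoder=encoder, decoder=decoder)

# Dummy batches of token ids for the encoder and decoder sides.
inputs = torch.randint(0, enc_config.vocab_size, (2, 8))
targets = torch.randint(0, dec_config.vocab_size, (2, 6))

outputs = model(input_ids=inputs, decoder_input_ids=targets, labels=targets)
print(outputs.loss)

For pretrained checkpoints, EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') wires up the same structure without building the configs by hand.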

gustavscholin commented 4 years ago

I'm facing the same problem. Since #4874 it looks like it should be just labels instead of lm_labels. According to the documentation this should compute a masked-language-modeling loss, but from my debugging it seems to actually compute a next-word-prediction loss.
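One way to check that observation is to recompute the loss by hand with a next-token shift and compare it to what the model returns. A minimal sketch, again assuming a transformers version that has BertLMHeadModel, with placeholder config sizes:

import torch
import torch.nn.functional as F
from transformers import BertConfig, BertLMHeadModel

# Tiny decoder-style BERT with an LM head; random weights are fine here,
# since we only compare two loss computations on the same logits.
config = BertConfig(hidden_size=64, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=128,
                    is_decoder=True)
model = BertLMHeadModel(config)

ids = torch.randint(0, config.vocab_size, (2, 6))
out = model(input_ids=ids, labels=ids)

# Next-word prediction: logits at position i are scored against token i+1.
manual = F.cross_entropy(
    out.logits[:, :-1].reshape(-1, config.vocab_size),
    ids[:, 1:].reshape(-1),
)
print(out.loss, manual)  # the two match if the loss is next-token prediction

If the two values agree, the labels loss is causal (next-word) rather than masked LM, matching the behavior described above.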

stale[bot] commented 4 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.