huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

[EncoderDecoder] Make sure `use_cache` is set to `True` for all Bert2Bert, Roberta2Roberta by default #9456

Closed · patrickvonplaten closed this issue 3 years ago

patrickvonplaten commented 3 years ago

At the moment, loading a Bert2Bert model:

```python
from transformers import EncoderDecoderModel

model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "bert-base-cased")
```

does not automatically set `use_cache` to `True`, so the user "silently" ends up with much slower than optimal inference speed. Also, none of the Bert2Bert configs online have `use_cache` set to `True`. This should be changed, at least for the heavily used Bert2Bert models.
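For reference, a minimal sketch of the workaround being discussed, assuming a transformers version with `EncoderDecoderModel.generate()` support; the tokenizer, prompt, and token choices below are illustrative, not from this issue:

```python
from transformers import BertTokenizer, EncoderDecoderModel

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "bert-base-cased")

# Workaround until the configs are fixed: enable the key/value cache explicitly
# so each decoding step reuses past attention states instead of recomputing them.
model.decoder.config.use_cache = True

inputs = tokenizer("The issue is about decoding speed.", return_tensors="pt")
outputs = model.generate(
    inputs.input_ids,
    decoder_start_token_id=tokenizer.cls_token_id,  # BERT has no dedicated BOS token
    use_cache=True,  # can also be passed per-call instead of via the config
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```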

I'll try to take care of that in the next couple of days. Also pinging @patil-suraj for information. Thanks @Narsil for bringing up the topic.

patil-suraj commented 3 years ago

@patrickvonplaten

Can we instead set `use_cache` to `True` by default in `generate`? That way we won't need to rely on the config.

Right now, the `generate` docstring says that it defaults to `True`, but it's actually set to `None`:

https://github.com/huggingface/transformers/blob/28d74872cc049e0cbee3fafd15cbbabfe348ebd4/src/transformers/generation_utils.py#L618

patrickvonplaten commented 3 years ago

Hmm, that goes a bit against the philosophy, because we never "set" any variables in `generate()`. We should do it in `EncoderDecoderConfig` and in `from_encoder_decoder_pretrained`. Note that all args in `generate()` are set to `None` but default to the respective config values, which should be set correctly.
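Not the actual library code, but a sketch of the fallback pattern described above: every `generate()` argument defaults to `None` and is resolved against the model config at call time.

```python
def generate(self, input_ids, use_cache=None, max_length=None, **kwargs):
    # An explicitly passed argument wins; otherwise fall back to the value
    # stored on the model config (which is why the config defaults have to
    # be correct).
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    max_length = max_length if max_length is not None else self.config.max_length
    ...
```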

patil-suraj commented 3 years ago

Also, `use_cache` was newly introduced in the bert/roberta config and is `True` by default, so even if the model's config file online doesn't have `use_cache`, it should still be `True`, no?

Could you maybe provide an example where the above issue occurs?
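A quick way to check the claim about the default, assuming a transformers version where `BertConfig` carries `use_cache`:

```python
from transformers import BertConfig

# `use_cache` is filled in at construction time, so it is True even when
# the config file hosted online does not contain the key at all.
config = BertConfig.from_pretrained("bert-base-cased")
print(config.use_cache)  # True
```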

patrickvonplaten commented 3 years ago

@patil-suraj, you're 100% right!

I initially thought it was a problem because `EncoderDecoderConfig` does not have a `use_cache` param set to `True`, but that doesn't actually matter: `model.decoder.config.use_cache` will always be set to `True` by default, which forces `use_cache` to be `True` in the decoder, which in turn makes it return the `past_key_values`. So all good then - thanks a lot for double-checking this :-)
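A minimal check of this resolution, assuming a transformers version from around the time of this issue; the token ids are illustrative placeholders:

```python
import torch
from transformers import EncoderDecoderModel

model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "bert-base-cased")

# The decoder's own config carries the default, so caching is active even
# though the top-level EncoderDecoderConfig never sets `use_cache` itself.
print(model.decoder.config.use_cache)  # True

input_ids = torch.tensor([[101, 7592, 102]])  # [CLS] hello [SEP], illustrative
outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
print(outputs.past_key_values is not None)  # True: the cache is returned
```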