huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Generate: `can_generate()` recursive check #33718

Status: Closed (gante closed this 1 month ago)

gante commented 1 month ago

What does this PR do?

The deprecation warning added in #33203 was incomplete: a class that inherits from a model class that had already received the update still triggered the warning. This PR makes `can_generate()` check parent classes recursively and adds tests for the warnings the function emits.
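A minimal, self-contained sketch of the recursive idea follows. It is not the library's code. `GenerationMixin`, `PreTrainedModel`, `ToyCausalLM` and `SubclassedToyCausalLM` here are hypothetical stand-ins. The direct-inheritance check and the legacy "overrides `prepare_inputs_for_generation`" check are assumptions about how detection works. The standard `warnings` module replaces whatever logging `transformers` actually uses. What matters is the ordering: a class whose parent can already generate is accepted before the legacy branch that warns is ever reached.

```python
import warnings


class GenerationMixin:
    """Hypothetical stand-in for the mixin that provides `generate()`."""

    def prepare_inputs_for_generation(self, **kwargs):
        return kwargs


class PreTrainedModel:
    """Hypothetical stand-in for the model base class (no GenerationMixin)."""

    @classmethod
    def can_generate(cls) -> bool:
        # Updated classes inherit GenerationMixin directly: no warning.
        if GenerationMixin in cls.__bases__:
            return True
        # Recursive check: a subclass of a generation-capable model class
        # can generate too, and must not reach the deprecation branch.
        for base in cls.__bases__:
            if issubclass(base, PreTrainedModel) and base.can_generate():
                return True
        # Legacy detection (assumption): the class resolves a custom
        # prepare_inputs_for_generation without inheriting GenerationMixin.
        own = getattr(cls, "prepare_inputs_for_generation", None)
        if own is not None and own is not GenerationMixin.prepare_inputs_for_generation:
            warnings.warn(
                f"{cls.__name__} can generate but does not inherit from "
                "GenerationMixin; this detection path is deprecated.",
                FutureWarning,
            )
            return True
        return False


# Hypothetical updated model class that, like many causal LMs, overrides
# prepare_inputs_for_generation, plus a plain subclass of it.
class ToyCausalLM(PreTrainedModel, GenerationMixin):
    def prepare_inputs_for_generation(self, **kwargs):
        return kwargs


class SubclassedToyCausalLM(ToyCausalLM):
    pass
```

With this ordering, `SubclassedToyCausalLM.can_generate()` returns `True` silently, while a legacy class that only overrides `prepare_inputs_for_generation` still gets the warning. Removing the recursive loop makes the subclass fall through to the legacy branch (it inherits `ToyCausalLM`'s override), so the warning fires: the same symptom the reproducer shows.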

Example of a script that throws the warning but shouldn't (fixed by this PR):

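```python
from transformers.models.llama import LlamaForCausalLM

# A trivial subclass of a model class that already received the #33203 update
class NewLlamaForCausalLM(LlamaForCausalLM):
    pass

model = NewLlamaForCausalLM.from_pretrained("hf-internal-testing/tiny-random-LlamaForCausalLM")
model.generate()  # emitted the deprecation warning before this PR
```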

Thank you @regisss for warning me about this one and providing a reproducer 🙏


regisss commented 1 month ago

> Ok! @regisss can you comment whether this fixes the issue you encountered?

Yep, it's all good on my side: the warning no longer appears with this fix. Thanks!