huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Allow gradient for generate() #28319

Open whitejeep600 opened 10 months ago

whitejeep600 commented 10 months ago

Feature request

The generate function is decorated with @torch.no_grad() and thus can't be used for model training. It would be better to make gradient computation optional rather than impossible, so that the function can also be used for tuning. The simplest solution is to remove the decorator altogether, since users can wrap the call in torch.no_grad() themselves if they need to. Are there reasons to disallow such usage?
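A minimal sketch of the point being made, with toy tensors rather than transformers code: a `@torch.no_grad()` decorator severs the autograd graph inside the function no matter what the caller does, whereas an undecorated function leaves the choice to the caller, who can trivially opt out for inference.

```python
import torch

@torch.no_grad()
def decorated_forward(x, w):
    # Gradient tracking is forced off inside, regardless of caller context.
    return x * w

def plain_forward(x, w):
    # No decorator: the caller decides via torch.no_grad().
    return x * w

w = torch.tensor(2.0, requires_grad=True)
x = torch.tensor(3.0)

y1 = decorated_forward(x, w)
print(y1.requires_grad)          # False: the decorator blocked autograd

y2 = plain_forward(x, w)
print(y2.requires_grad)          # True: gradients flow by default

with torch.no_grad():
    y3 = plain_forward(x, w)
print(y3.requires_grad)          # False: opting out for inference is one line
```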

Motivation

Allow using generate for tuning

Your contribution

Removing the decorator is a very simple change. I can submit a PR.

ArthurZucker commented 10 months ago

The reason is that generate is for inference only; training requires custom sampling logic.

whitejeep600 commented 10 months ago

> The reason is that generate is for inference only, training requires a custom sampling logic

Thank you for the quick reply.

However, what if I just want to use the existing sampling algorithms from the generate function? That is a valid training strategy, and it would be convenient to have access to the ready-made code for this purpose.
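As one concrete example of such a strategy, here is a hedged sketch of REINFORCE-style training over a generate-like sampling loop. The toy "model" (an embedding plus a linear head) and the reward are hypothetical stand-ins, not transformers code; the point is that the discrete sampling step itself carries no gradient, but the log-probabilities of the sampled tokens do, so gradients reach the model parameters through a policy-gradient loss.

```python
import torch

torch.manual_seed(0)
vocab, dim = 10, 8
emb = torch.nn.Embedding(vocab, dim)   # toy stand-in for a language model
head = torch.nn.Linear(dim, vocab)

def step_logits(ids):
    # Next-token logits from the last token only (hypothetical toy model).
    return head(emb(ids[:, -1]))       # (batch, vocab)

ids = torch.zeros(2, 1, dtype=torch.long)   # batch of 2, "BOS" token 0
log_probs = []
for _ in range(4):                          # generate-style loop, grads on
    dist = torch.distributions.Categorical(logits=step_logits(ids))
    tok = dist.sample()                     # sampling is non-differentiable...
    log_probs.append(dist.log_prob(tok))    # ...but log-probs carry gradients
    ids = torch.cat([ids, tok[:, None]], dim=-1)

reward = (ids[:, 1:] == 3).float().mean(-1)            # toy reward signal
loss = -(torch.stack(log_probs, dim=-1).sum(-1) * reward).mean()
loss.backward()                             # gradients reach the parameters
print(emb.weight.grad is not None)          # True
```

With `@torch.no_grad()` on the sampling loop, `log_prob` would be detached from the graph and `loss.backward()` would have nothing to propagate, which is exactly why the decorator blocks this use case.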