huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

Drop `use_cache=False if training_args.gradient_checkpointing` #1798

Closed: qgallouedec closed this issue 2 months ago

qgallouedec commented 3 months ago

If I understand correctly, these lines:

https://github.com/huggingface/trl/blob/b6af2edc93b275afcee22a3eb71f9a5702ff9fd8/examples/scripts/dpo.py#L111

exist because, previously, using the cache together with gradient checkpointing was broken; see https://github.com/huggingface/trl/issues/145#issuecomment-1459735966.

Since https://github.com/huggingface/transformers/issues/21737 has been resolved, I think we can replace them with

    use_cache=model_args.use_cache,
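
In context, the change might look roughly like this (the surrounding kwargs, `model_kwargs`, and `torch_dtype` are illustrative placeholders, not copied from the script; only the `use_cache` line is the point):

    # Before: the cache is force-disabled whenever gradient checkpointing is on.
    model_kwargs = dict(
        torch_dtype=torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
    )

    # After: forward the user-provided setting and let transformers handle the
    # gradient-checkpointing interaction itself.
    model_kwargs = dict(
        torch_dtype=torch_dtype,
        use_cache=model_args.use_cache,
    )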

wdyt?

vwxyzjn commented 3 months ago

Oh, that would be really great! Could you test it out with some larger models to see if it indeed works? E.g., Mistral 7B or Gemma 2 24B.
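
Something along these lines could serve as a quick smoke test that the combination trains at all (the model id is just a placeholder; any causal LM works):

    # Quick smoke test: gradient checkpointing with use_cache left untouched.
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "mistralai/Mistral-7B-v0.1"  # placeholder; any causal LM works
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

    model.config.use_cache = True          # do NOT force it off
    model.gradient_checkpointing_enable()  # the previously problematic combination
    model.train()

    inputs = tokenizer("Hello, world!", return_tensors="pt")
    loss = model(**inputs, labels=inputs["input_ids"]).loss
    loss.backward()  # should complete without the old cache-related failure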

qgallouedec commented 2 months ago

Actually, these use_cache / gradient checkpointing warnings are still emitted all over the transformers modeling code:

https://github.com/huggingface/transformers/blob/1082361a1978d30db5c3932d1ee08914d74d9697/src/transformers/models/blip_2/modeling_blip_2.py#L970
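
The check is roughly the following (paraphrased; the exact message and logger call vary from model to model):

    # Paraphrased excerpt of the guard that appears in many modeling files:
    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. "
                "Setting `use_cache=False`..."
            )
            use_cache = False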

I'll investigate further.