qgallouedec closed this 2 months ago
Oh, that would be really great! Could you test it with some larger models to see if it indeed works, e.g. Mistral 7B or Gemma 2 27B?
Actually, there are still these warnings everywhere:
I'll investigate further.
If I understand correctly, these lines:
https://github.com/huggingface/trl/blob/b6af2edc93b275afcee22a3eb71f9a5702ff9fd8/examples/scripts/dpo.py#L111
are there because, previously, using the cache together with gradient checkpointing was broken; see https://github.com/huggingface/trl/issues/145#issuecomment-1459735966.
Since https://github.com/huggingface/transformers/issues/21737 has been resolved, I think we can simplify these lines.
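Roughly, and assuming the lines in question set `use_cache=False` whenever gradient checkpointing is enabled (I'm sketching from memory, so the model name and arguments below are just placeholders, not the actual script code), something like:

```python
from transformers import AutoModelForCausalLM, TrainingArguments

training_args = TrainingArguments(output_dir="out", gradient_checkpointing=True)

# Current workaround (assumed): turn the KV cache off ourselves whenever
# gradient checkpointing is enabled, because the two used to clash.
model = AutoModelForCausalLM.from_pretrained(
    "gpt2",  # placeholder model
    use_cache=False if training_args.gradient_checkpointing else True,
)

# Possible simplification now that the transformers issue is fixed: just load
# the model and enable gradient checkpointing via TrainingArguments, without
# touching use_cache here.
model = AutoModelForCausalLM.from_pretrained("gpt2")
```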
wdyt?