huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

[Code Improvement] Support concatenated forward in reward trainer #1769

Open 1485840691 opened 1 week ago

1485840691 commented 1 week ago

This PR addresses an earlier code improvement suggestion for the reward trainer: borrow the idea from DPOTrainer of concatenating the chosen and rejected tokens into a single batch, which saves one model forward() call per step. The drawback of the concatenated forward is higher peak GPU memory usage, so this PR adds a flag to turn the feature on or off.
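A minimal sketch of the idea, for illustration only. It is not TRL's implementation or API: ToyRewardModel, reward_loss, PAD_ID and the concatenate_forward flag are hypothetical names (the PR text does not name its flag), and the "model" is a pure-Python stand-in that counts its forward calls. It assumes the usual pairwise reward loss, -log sigmoid(r_chosen - r_rejected). With the flag on, both halves are padded to a common length and stacked into one batch of 2N sequences, so one forward pass replaces two.

```python
import math
from typing import List, Sequence

PAD_ID = 0  # hypothetical padding token id


class ToyRewardModel:
    """Hypothetical stand-in for a reward model: one scalar reward per sequence."""

    def __init__(self) -> None:
        self.forward_calls = 0

    def __call__(self, batch: Sequence[Sequence[int]]) -> List[float]:
        self.forward_calls += 1
        rewards = []
        for seq in batch:
            tokens = [t for t in seq if t != PAD_ID]  # ignore padding, like an attention mask
            rewards.append(sum(tokens) / (10.0 * max(len(tokens), 1)))
        return rewards


def pad(seqs: Sequence[Sequence[int]], length: int) -> List[List[int]]:
    # Right-pad every sequence to the given length.
    return [list(s) + [PAD_ID] * (length - len(s)) for s in seqs]


def softplus(x: float) -> float:
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def pairwise_loss(chosen_r: List[float], rejected_r: List[float]) -> float:
    # -log(sigmoid(r_c - r_r)) == softplus(-(r_c - r_r)), averaged over pairs.
    losses = [softplus(-(c - r)) for c, r in zip(chosen_r, rejected_r)]
    return sum(losses) / len(losses)


def reward_loss(model, chosen, rejected, concatenate_forward: bool = True) -> float:
    if concatenate_forward:
        # One forward pass over a 2N batch padded to the global max length,
        # then split the rewards back into chosen and rejected halves.
        max_len = max(len(s) for s in list(chosen) + list(rejected))
        rewards = model(pad(chosen, max_len) + pad(rejected, max_len))
        n = len(chosen)
        chosen_r, rejected_r = rewards[:n], rewards[n:]
    else:
        # Two forward passes, each batch padded only to its own max length.
        chosen_r = model(pad(chosen, max(len(s) for s in chosen)))
        rejected_r = model(pad(rejected, max(len(s) for s in rejected)))
    return pairwise_loss(chosen_r, rejected_r)


chosen = [[5, 6, 7], [9, 9]]
rejected = [[1, 2], [3, 1, 1, 1]]
m_concat, m_split = ToyRewardModel(), ToyRewardModel()
loss_concat = reward_loss(m_concat, chosen, rejected, concatenate_forward=True)
loss_split = reward_loss(m_split, chosen, rejected, concatenate_forward=False)
print(m_concat.forward_calls, m_split.forward_calls)
```

Both paths produce the same loss; the concatenated path just makes one call instead of two. The memory cost mentioned above comes from that single call holding activations for all 2N sequences at once, each padded to the longest sequence across both halves, which is why keeping it behind an on/off flag is a reasonable choice.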

vwxyzjn commented 6 days ago

Looks like a great change! Thanks @1485840691 for the PR

HuggingFaceDocBuilderDev commented 6 days ago

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

vwxyzjn commented 6 days ago

Make sure you run make precommit.

1485840691-eng commented 4 days ago


I ran the precommit checks and they pass. Please help review. Thanks!