huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

Set gradient_checkpointing_kwargs in the yaml #2334

Open · Galaxy-Husky opened this issue 2 weeks ago

Galaxy-Husky commented 2 weeks ago

Feature request

Hi,

Currently, gradient_checkpointing_kwargs is set inside the example script itself: https://github.com/huggingface/trl/blob/ac77c092235e1218917d53a6832ac2b8ca48198c/examples/scripts/bco.py#L109. I was wondering whether we could set it in a config YAML instead, because I noticed a deprecated function that deals with it: https://github.com/huggingface/trl/blob/ac77c092235e1218917d53a6832ac2b8ca48198c/trl/commands/cli_utils.py#L171-L192
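To make the difference concrete, here is a minimal sketch. It assumes a plain dataclass stands in for the real training-arguments class and that the YAML file has already been loaded into a dict (what a loader such as `yaml.safe_load` would return). `TrainingArgsSketch` and `args_from_config` are hypothetical names, not TRL or transformers APIs, and the hard-coded value is only illustrative.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Optional


@dataclass
class TrainingArgsSketch:
    # Hypothetical stand-in for a training-arguments dataclass.
    output_dir: str = "out"
    gradient_checkpointing: bool = False
    gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None


def args_from_config(config: Dict[str, Any]) -> TrainingArgsSketch:
    """Build the arguments from an already-parsed config, ignoring unknown keys."""
    known = {f.name for f in fields(TrainingArgsSketch)}
    return TrainingArgsSketch(**{k: v for k, v in config.items() if k in known})


# Script-side style: the value is assigned in code after parsing,
# so nothing in the config file can change it.
hardcoded = TrainingArgsSketch(gradient_checkpointing=True)
hardcoded.gradient_checkpointing_kwargs = {"use_reentrant": True}

# Config-side style: the value comes from the loaded YAML, which could read
#   gradient_checkpointing: true
#   gradient_checkpointing_kwargs:
#     use_reentrant: false
config = {
    "gradient_checkpointing": True,
    "gradient_checkpointing_kwargs": {"use_reentrant": False},
}
from_yaml = args_from_config(config)
print(from_yaml.gradient_checkpointing_kwargs)  # {'use_reentrant': False}
```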

Motivation

I think configuring training would be easier if gradient_checkpointing_kwargs could be set in the YAML, alongside the other training arguments.

Your contribution

If you agree, I would be happy to submit a PR that supports this feature by modifying the function above.
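One possible shape for such a change, sketched under assumptions: a post-processing step (the hypothetical `resolve_checkpointing_kwargs`) keeps any gradient_checkpointing_kwargs the user set in the YAML and only falls back to a separate legacy `use_reentrant` flag when none was given. The function name, the legacy flag, and the precedence rule are illustrative; they are not taken from the linked `cli_utils.py` code.

```python
from typing import Any, Dict, Optional


def resolve_checkpointing_kwargs(
    yaml_kwargs: Optional[Dict[str, Any]],
    legacy_use_reentrant: Optional[bool] = None,
) -> Optional[Dict[str, Any]]:
    """Return the gradient_checkpointing_kwargs to use (hypothetical precedence).

    1. An explicit dict from the YAML config wins.
    2. Otherwise a legacy use_reentrant flag, if given, is wrapped in a dict.
    3. Otherwise None, leaving the library default untouched.
    """
    if yaml_kwargs is not None:
        # Copy so that mutating the result does not alter the loaded config.
        return dict(yaml_kwargs)
    if legacy_use_reentrant is not None:
        return {"use_reentrant": legacy_use_reentrant}
    return None


print(resolve_checkpointing_kwargs({"use_reentrant": False}, legacy_use_reentrant=True))
# {'use_reentrant': False}
print(resolve_checkpointing_kwargs(None, legacy_use_reentrant=True))
# {'use_reentrant': True}
print(resolve_checkpointing_kwargs(None))
# None
```

Letting the explicit YAML value win keeps existing configs that rely on the older flag working, while new configs can pass any keyword the checkpointing function accepts.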

Galaxy-Husky commented 2 days ago

@qgallouedec hi, do you have any suggestions?