huggingface / trl

Train transformer language models with reinforcement learning.
http://hf.co/docs/trl
Apache License 2.0

`OnPolicyConfig` - Rename or revise `num_sample_generations` #2024

Open RylanSchaeffer opened 2 months ago

RylanSchaeffer commented 2 months ago

Feature request

The OnPolicyConfig has the parameter num_sample_generations which controls how many times .generate_completions() will be called in a single run. This approach is unusual for two reasons:

  1. (Minor) The naming is confusing. I thought that num_sample_generations meant the number of (dataset) samples for which generations are drawn, not the number of times that .generate_completions() will be called.
  2. (Major) The parameter convention contradicts HuggingFace's Trainer() paradigm. Specifically, in Trainer(), there is a training argument eval_steps which specifies the number of update steps to perform between evaluations; in contrast, here we must specify the total number of times .generate_completions() will be called.

It's especially odd because the frequency is calculated internally, e.g., https://github.com/huggingface/trl/blob/main/trl/trainer/ppov2_trainer.py#L138. Why not let the user directly specify the frequency, in the same manner that TrainingArguments supports?
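To make the contrast concrete, here is a minimal sketch of the two conventions; the numbers are hypothetical, and the internal attribute names are an assumption based on the linked source line:

```python
# Current OnPolicyConfig convention: the user supplies a *count*, and the
# trainer derives the *frequency* internally (names/values are illustrative).
num_total_batches = 1000          # derived by the trainer from the config
num_sample_generations = 10       # what the user must specify today
sample_generations_freq = num_total_batches // num_sample_generations
assert sample_generations_freq == 100  # generations drawn every 100 batches

# Trainer/TrainingArguments convention: the user supplies the *frequency*
# directly, e.g. TrainingArguments(eval_steps=100) evaluates every 100 steps.
```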

I want this parameter to either (1) be renamed or (2) be alternatively controlled in a manner consistent with TrainingArguments's eval_steps.

Motivation

The current implementation is unnecessarily effortful. If I want to look at samples during my PPO runs, I need to work out how many update steps each config will execute, and then work out what choice of num_sample_generations will show me generated outputs at a reasonable frequency.

Your contribution

I could submit a PR if the maintainers agree.

RylanSchaeffer commented 2 months ago

I realized the calculation to set num_sample_generations is actually a little more effortful than I previously thought! One needs to:

  1. Compute the batch_size
  2. Compute the num_total_batches by dividing the total_episodes by the previously computed batch_size
  3. Compute sample_generations_freq as num_total_batches // num_sample_generations
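Under stated assumptions about how the batch size is composed (the exact factors depend on one's devices and accumulation settings), the three steps above can be sketched as:

```python
# Hypothetical config values; the batch_size composition below is an
# assumption, not TRL's exact formula.
total_episodes = 100_000
per_device_train_batch_size = 8
gradient_accumulation_steps = 4
num_devices = 2
num_sample_generations = 10

# 1. Compute the batch_size
batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_devices  # 64

# 2. Compute num_total_batches
num_total_batches = total_episodes // batch_size  # 1562

# 3. Compute sample_generations_freq
sample_generations_freq = num_total_batches // num_sample_generations  # 156
```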

Now, suppose I want a specific sample_generations_freq (e.g., 100): I need to back-solve for num_sample_generations. It would be much simpler if I could just specify sample_generations_freq, and this would be more consistent with TrainingArguments.
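The back-solve looks like this; the values are hypothetical placeholders for one's actual run config, and the direct sample_generations_freq parameter at the end is the proposed (not existing) API:

```python
# Back-solving num_sample_generations for a desired generation frequency.
total_episodes = 100_000
batch_size = 64
num_total_batches = total_episodes // batch_size  # 1562

desired_freq = 100
num_sample_generations = num_total_batches // desired_freq  # 15

# With a direct frequency parameter (the proposal in this issue), the whole
# calculation would collapse to something like:
#   config = OnPolicyConfig(sample_generations_freq=100)  # hypothetical
```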