Negative Entropy in TRL PPOv2Trainer TLDR Example #2022

Open RylanSchaeffer opened 1 month ago

RylanSchaeffer commented 1 month ago

System Info

Information

Tasks

Reproduction

In TRL's PPOv2Trainer TLDR example, run the default command:

accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
    examples/scripts/ppo/ppo_tldr.py \
    --output_dir models/minimal/ppo_tldr \
    --learning_rate 3e-6 \
    --per_device_train_batch_size 16 \
    --gradient_accumulation_steps 4 \
    --total_episodes 1000000 \
    --model_name_or_path EleutherAI/pythia-1b-deduped \
    --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
    --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
    --local_rollout_forward_batch_size 16 \
    --non_eos_penalty \
    --stop_token eos

Expected behavior

Entropy for a discrete distribution (such as that of a language model) must be non-negative. However, when I run the official example, the entropy can be negative:

[Screenshot: W&B plot of objective/entropy from my run, showing negative values during training]

I don't think I'm making a mistake, because this negative entropy also appears in the official documentation. Specifically, look at the early part of training, around 20k episodes:

[Screenshot: objective/entropy curve from the PPOv2 documentation, dipping below zero early in training]

The documentation describes objective/entropy as "The mean entropy of the policy, indicating the randomness of the actions chosen by the policy." If that description is incorrect and some other quantity is actually being computed, then perhaps the documentation should be updated?
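
As a sanity check (this is not TRL code, just a minimal sketch with made-up shapes): computing the entropy directly as -Σ p log p on float32 softmax probabilities is always non-negative, which is what I'd expect objective/entropy to satisfy:

# Minimal sanity check, not taken from the TRL codebase: Shannon entropy of a
# softmax distribution, computed directly as -sum(p * log p) in float32.
import torch

torch.manual_seed(0)
logits = torch.randn(4, 8, 50_000)  # illustrative (batch, seq, vocab) shape
prob_dist = torch.softmax(logits.float(), dim=-1)
entropy = -(prob_dist * prob_dist.clamp_min(1e-12).log()).sum(dim=-1)
print(entropy.min())  # non-negative, up to float32 rounding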

RylanSchaeffer commented 1 month ago

I don't know if this is the culprit, but I noticed that the tutorial and I both use bf16, and in bf16 the following two quantities don't agree:

torch.einsum("bse,bse->bs", prob_dist, logits) - torch.sum(prob_dist * logits, dim=-1)

The difference is non-zero:

tensor([[ 0.0000,  0.1250, -0.1250,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.1250,  0.0000, ...0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000]], device='cuda:0',
       dtype=torch.bfloat16)
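
In case anyone wants to reproduce this without running the full training script, here is a self-contained sketch (the shapes and seed are arbitrary, not taken from the example):

# Standalone repro sketch of the bf16 mismatch above; shapes and seed are arbitrary.
import torch

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
logits = torch.randn(1, 64, 50_000, dtype=torch.bfloat16, device=device)
prob_dist = torch.nn.functional.softmax(logits, dim=-1)

einsum_version = torch.einsum("bse,bse->bs", prob_dist, logits)
sum_version = torch.sum(prob_dist * logits, dim=-1)
# The two reductions round differently in bf16, so the gap is typically non-zero.
print((einsum_version - sum_version).abs().max())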
RylanSchaeffer commented 1 month ago

Following this previous PR, it might be worthwhile to consider upcasting the tensors before computing logged quantities.

But I don't know if this explains how the entropy is becoming negative...
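
Concretely, if the logged entropy is computed as logsumexp(logits) - sum(prob_dist * logits) (which is what the snippet above is probing), then the upcasting could look roughly like this hypothetical helper (a sketch only, not the actual TRL implementation):

# Hypothetical helper, not the actual TRL code: upcast to float32 before
# computing the entropy that gets logged.
import torch

def entropy_fp32(logits: torch.Tensor) -> torch.Tensor:
    logits = logits.float()
    prob_dist = torch.nn.functional.softmax(logits, dim=-1)
    # H = logsumexp(logits) - sum(p * logits); mathematically >= 0, and float32
    # keeps the rounding error small enough that it stays non-negative in practice.
    return torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)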

RylanSchaeffer commented 1 month ago

On another PPOv2 run, I again observe negative entropy:

[Screenshot: objective/entropy plot from another PPOv2 run, again showing negative values]