tlc4418 / llm_optimization

A repo for RLHF training and BoN over LLMs, with support for reward model ensembles.
https://arxiv.org/abs/2310.02743
MIT License

My BoN result is different from the result (Figure 3a) shown in the paper. #13

Open yiwan-rl opened 1 month ago

yiwan-rl commented 1 month ago

I tried to reproduce the BoN result in the paper but could not. Has anyone reproduced it successfully? Below I list the steps I used to obtain my BoN result.

First, let me show my BoN result.

[Screenshots of my BoN result, taken 2024-08-06]

This result should be compared with the single-RM curves in Figure 3a of the paper. It differs from the paper's result in two ways: 1) the initial gold RM scores differ, and 2) the final gold RM scores differ.

Here are the changes that I made to the codebase.

  1. Add residual_dropout_lima: false to https://github.com/tlc4418/llm_optimization/blob/f8a9ae6c5ff907deb206efffe9ddb3e62d279e85/configs/config_rm.yaml#L49 (suggested in https://github.com/tlc4418/llm_optimization/issues/5).
  2. Change https://github.com/tlc4418/llm_optimization/blob/f8a9ae6c5ff907deb206efffe9ddb3e62d279e85/configs/config_rm.yaml#L54 from alpaca_farm_pref to custom_hf_pref, because the paper used the latter dataset for RM training.
  3. The RM training loss is invariant to offsets: shifting all of an RM's outputs by the same constant does not change the loss. The resulting RMs therefore have different offsets, which should be removed. In addition, in https://github.com/tlc4418/llm_optimization/issues/12 the author suggested further dividing the centered rewards by the estimated standard deviation. I did this normalization in this repo rather than in the open-assistant codebase as the author suggested; I believe the two approaches are effectively equivalent. Specifically, I replaced the get_reward function in https://github.com/tlc4418/llm_optimization/blob/f8a9ae6c5ff907deb206efffe9ddb3e62d279e85/src/reward_modeling/scoring/score.py#L18 with the following code:

    # Assumes math, torch, and MAX_LEN are already imported/defined in score.py.
    def get_reward(
        samples,
        reward_models,
        reward_tokenizer,
        reward_device,  # needed?
        batch_size,
        objective_function=None,
        weight=None,
        is_alpacafarm_rm=False,
        normalize_reward=True,
    ):
        if not isinstance(reward_models, list):
            reward_models = [reward_models]

        # Tokenize all samples at once.
        inputs = reward_tokenizer(
            samples,
            padding=True,
            truncation=True,
            max_length=MAX_LEN,
            return_tensors="pt",
        ).to(reward_device)

        # Score the samples with each RM, batch by batch.
        all_rewards = []
        for reward_model in reward_models:
            out = []
            for i in range(math.ceil(len(samples) / batch_size)):
                batch_ixs = slice(i * batch_size, (i + 1) * batch_size)
                input_ids = inputs.input_ids[batch_ixs]
                attention_mask = inputs.attention_mask[batch_ixs]
                output = reward_model(input_ids, attention_mask)
                rewards = output.rewards if is_alpacafarm_rm else output.logits[:, 0]
                out.extend(rewards)
            all_rewards.append(torch.hstack(out))

        if len(all_rewards) == 1:
            all_rewards = all_rewards[0]
            # Added: center and scale by this RM's stored mean and std.
            if normalize_reward:
                all_rewards = (all_rewards - reward_models[0].config.mean) / reward_models[0].config.std
            return all_rewards, torch.empty_like(all_rewards)

        # Added: center and scale each ensemble member by its own mean and std.
        if normalize_reward:
            for i in range(len(reward_models)):
                all_rewards[i] = (all_rewards[i] - reward_models[i].config.mean) / reward_models[i].config.std
        all_rewards = torch.stack(all_rewards, 0)
        var = torch.var(all_rewards, dim=0)  # variance across ensemble members
        if objective_function:
            all_rewards = objective_function(all_rewards, weight)

        return all_rewards, var

    In addition, the RM scores should not be normalized when training RMs, so I changed the call at https://github.com/tlc4418/llm_optimization/blob/f8a9ae6c5ff907deb206efffe9ddb3e62d279e85/src/reward_modeling/training/trainer_rm.py#L341 to rewards, _ = get_reward(samples, model, tokenizer, model.device, batch_size=128, normalize_reward=False).

  4. To avoid the error mentioned in https://github.com/tlc4418/llm_optimization/issues/7, change gold_labelled_generations.map(_truncate_answers) to gold_labelled_generations.map(_truncate_answers, batched=True, batch_size=10) in https://github.com/tlc4418/llm_optimization/blob/f8a9ae6c5ff907deb206efffe9ddb3e62d279e85/src/bon/run_bon_pipeline.py#L77
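The offset argument in change 3 can be checked numerically. The sketch below assumes the standard Bradley-Terry pairwise loss, -log sigmoid(r_chosen - r_rejected); the repo's exact loss is not shown in this issue, so treat that as an assumption. The reward values are hypothetical. It shows that adding a constant to every reward leaves the loss unchanged, then applies the same centering and scaling as the modified get_reward. There, mean and std come from config.mean and config.std; here they are simply estimated from the hypothetical scores.

```python
import math
import statistics


def pairwise_loss(r_chosen, r_rejected):
    """Mean Bradley-Terry loss over pairs: -log sigmoid(r_chosen - r_rejected)."""
    # log1p(exp(-d)) equals -log(sigmoid(d)); adequate for moderate d.
    losses = [math.log1p(math.exp(-(c - r))) for c, r in zip(r_chosen, r_rejected)]
    return sum(losses) / len(losses)


def normalize(scores, mean, std):
    """Center and scale scores, mirroring the change made in get_reward."""
    return [(s - mean) / std for s in scores]


# Hypothetical rewards from one RM on three preference pairs.
chosen = [0.4, 1.2, -0.3]
rejected = [-0.6, 0.1, -0.5]

# Shifting every output by the same constant leaves the loss unchanged.
offset = 5.0
base_loss = pairwise_loss(chosen, rejected)
shifted_loss = pairwise_loss([c + offset for c in chosen], [r + offset for r in rejected])
print(base_loss, shifted_loss)

# Estimate mean/std on (hypothetical) reference scores, then normalize.
reference = chosen + rejected
mu = statistics.mean(reference)
sigma = statistics.stdev(reference)
z = normalize(reference, mu, sigma)
print(statistics.mean(z), statistics.stdev(z))
```

Because only reward differences enter the loss, two RMs trained with different seeds can agree on every ranking while reporting very different absolute scores. Per-RM normalization puts them on a comparable scale before their scores are plotted or combined.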

My steps to obtain the result:

  1. Training RMs: run accelerate launch --config_file configs/accelerate_config.yaml src/reward_modeling/training/trainer_rm.py --configs defaults_rm rm-pythia-44m --rng_seed <seed> five times, with <seed> set to 1, 2, 3, 4, and 5. The final summary for seed 1 is:
    wandb: Run summary:
    wandb:           eval/custom_hf_pref_accuracy 0.6835
    wandb:         eval/custom_hf_pref_kendalltau 0.367
    wandb:               eval/custom_hf_pref_loss 0.53293   
    wandb:          eval/custom_hf_pref_neg_score -0.59062  
    wandb:          eval/custom_hf_pref_pos_score 0.35762   
    wandb:            eval/custom_hf_pref_runtime 8.5303
    wandb: eval/custom_hf_pref_samples_per_second 234.459   
    wandb:         eval/custom_hf_pref_score_diff 0.94824   
    wandb:   eval/custom_hf_pref_steps_per_second 29.307
    wandb:                            train/epoch 5.0
    wandb:                      train/global_step 7185
    wandb:                    train/learning_rate 0.0
    wandb:                             train/loss 0.512
    wandb:                       train/total_flos 0.0
    wandb:                       train/train_loss 0.56537   
    wandb:                    train/train_runtime 1976.5772 
    wandb:         train/train_samples_per_second 116.363   
    wandb:           train/train_steps_per_second 3.635
  2. Running BoN: Run python src/bon/run_bon_pipeline.py models/rm-pythia-44m_seed{seed} --seeds 1,2,3,4,5 --ensembles
  3. Visualization code:
    
    import json
    import math

    import matplotlib.pyplot as plt

    seed_fn_prefix = 'runs/bon_sampling_2024-08-05_05-49-25/seed'
    seed_fn_suffix = '/bon_sampled_results.json'
    seeds = [1, 2, 3, 4, 5]

    # proxy scores
    for s in seeds:
        with open(seed_fn_prefix + str(s) + seed_fn_suffix) as f:
            bon_res = json.load(f)
        xs, ys_proxy = [], []
        for entry in bon_res:
            # KL of best-of-n sampling: log(n) - (n - 1) / n
            xs.append(math.log(entry['n']) - (entry['n'] - 1) / entry['n'])
            ys_proxy.append(entry['proxy_score'])
        plt.plot(xs, ys_proxy, label="seed " + str(s))
    plt.xlabel("KL")
    plt.ylabel("proxy reward")
    plt.legend()
    plt.show()

    # gold scores
    for s in seeds:
        with open(seed_fn_prefix + str(s) + seed_fn_suffix) as f:
            bon_res = json.load(f)
        xs, ys_gold = [], []
        for entry in bon_res:
            xs.append(math.log(entry['n']) - (entry['n'] - 1) / entry['n'])
            ys_gold.append(entry['gold_score'])
        plt.plot(xs, ys_gold, label="seed " + str(s))
    plt.ylim(0, 1.1)
    plt.xlabel("KL")
    plt.ylabel("gold reward")
    plt.legend()
    plt.show()
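The x-axis in the visualization code is the analytical KL divergence commonly used for best-of-n sampling, KL = log n - (n - 1)/n, which is 0 at n = 1 and grows roughly like log n. The sketch below computes it for a few values of n and, for illustration only, picks the best-of-n sample by proxy score and reports that sample's gold score. The best_of_n helper and the scores are hypothetical; this is not necessarily how run_bon_pipeline.py selects or averages samples.

```python
import math


def bon_kl(n):
    """Analytical KL of best-of-n sampling: log(n) - (n - 1) / n."""
    return math.log(n) - (n - 1) / n


def best_of_n(proxy_scores, gold_scores):
    """Return (proxy, gold) for the sample with the highest proxy score."""
    best = max(range(len(proxy_scores)), key=lambda i: proxy_scores[i])
    return proxy_scores[best], gold_scores[best]


# KL grows slowly with n; n = 1 corresponds to the initial policy (KL = 0).
for n in [1, 2, 4, 16, 64]:
    print(n, round(bon_kl(n), 3))

# Hypothetical proxy/gold scores for n = 4 candidate answers to one prompt.
proxy = [0.1, 0.9, 0.4, 0.7]
gold = [0.2, 0.3, 0.5, 0.6]
print(best_of_n(proxy, gold))  # → (0.9, 0.3)
```

In this toy case, the sample the proxy RM prefers has a lower gold score than two of the alternatives. At n = 1 no selection happens, so the leftmost point of each curve reflects the average score of unselected samples.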