allenai / reward-bench

RewardBench: the first evaluation tool for reward models.
https://huggingface.co/spaces/allenai/reward-bench
Apache License 2.0

Prompt Repeated in DPO `tokenize_row` (not actually sure if this is an issue) #140

Closed · PootieT closed this issue 3 weeks ago

PootieT commented 3 weeks ago

Just to preface this issue: I only spent two hours looking at the code, so I'm not entirely sure whether this is a user error or intentional.

But in the dpo.py tokenize_row function, the string variables chosen and rejected seem to already contain the prompt. Is that expected? (If not, it's probably user error.)

If it is expected, then it seems to cause issues downstream: chosen_sequence_tokens["input_ids"] ends up containing the prompt twice, and chosen_sequence_tokens["labels"] is a vector of the same shape, but with only the first prompt masked (not the repeated one).
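For context, here is a minimal sketch of the concatenation step inside tokenize_row (paraphrased and simplified; this is not the exact reward-bench/TRL code, and the variable names are illustrative). It shows how the prompt gets duplicated if chosen already starts with it:

# Minimal sketch of tokenize_row's concatenation (paraphrased; assumes
# `prompt` and `chosen` strings and a HF `tokenizer` are in scope).
prompt_tokens = tokenizer(prompt, add_special_tokens=False)
chosen_tokens = tokenizer(chosen, add_special_tokens=False)  # here `chosen` already starts with the prompt

chosen_sequence_tokens = {
    "input_ids": prompt_tokens["input_ids"] + chosen_tokens["input_ids"],
    "attention_mask": prompt_tokens["attention_mask"] + chosen_tokens["attention_mask"],
}
# Only the leading prompt span is masked in the labels, so the second
# (duplicated) prompt inside `chosen` still counts toward the log-probs.
chosen_sequence_tokens["labels"] = (
    [-100] * len(prompt_tokens["input_ids"]) + chosen_tokens["input_ids"]
)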

I looked at TRL and the alignment-handbook and checked the dataset variables there; in their tokenize_row function, the string variables chosen and rejected do not contain the prompt.

I called python scripts/run_dpo.py directly, and loaded the default dataset using

dataset, subsets = load_eval_dataset(
    core_set=not args.pref_sets,   # where args.pref_sets=False
    conv=conv,
    tokenizer=tokenizer,
    logger=logger,
    keep_columns=["text_chosen", "text_rejected", "id", "prompt"],
)
natolambert commented 3 weeks ago

@PootieT can you say more about the delta you're seeing? This tokenize_row function is copied almost verbatim from TRL (there's a comment at the top of the file, which we can make clearer).

I went through this rabbit hole once too, but I think it's okay?

PootieT commented 3 weeks ago

Yeah, sorry, my original post wasn't very clear.

To reproduce, I run this (really, any model works; we just need to load the default dataset):

python scripts/run_dpo.py --model=HuggingFaceH4/zephyr-7b-beta --batch_size=8 --debug=true

Then at this line (after the dataset is loaded, but before tokenize_row is called), I take a look at the loaded dataset:

(Pdb) dataset["prompt"][0]
'<|user|>\nHow do I detail a car?</s>\n'
(Pdb) dataset["text_chosen"][0]
"<|user|>\nHow do I detail a car?</s>\n<|assistant|>\nDetailing a car involves a thorough cleaning inside and out, as well as polishing and waxing to protect the vehicle's surfaces. Here's a step-by-step ...

The text_chosen field contains the prompt!

Whereas in DPOTrainer.__init__() from TRL, right before tokenize_row() is applied to the dataset, the chosen field does NOT contain the prompt, so it looks something like:

(Pdb) dataset["prompt"][0]
'<|user|>\nHow do I detail a car?</s>\n'
(Pdb) dataset["text_chosen"][0]
"<|assistant|>\nDetailing a car involves a thorough cleaning inside and out, as well as polishing and waxing to protect the vehicle's surfaces. Here's a step-by-step ...

Since the tokenize_row function is pretty much the same between reward-bench and TRL, I think this is not expected, and the prompt should have been removed from text_chosen. If it is not removed, the prompt gets appended to text_chosen again inside tokenize_row, and the truncation logic with max_length and max_prompt_length no longer makes sense (the length budget counts the prompt once, but it occupies the sequence twice).
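To make the truncation point concrete, here is a paraphrase of that length check (illustrative variable names, not the exact code); the budget assumes the prompt occupies the sequence exactly once:

# Paraphrased length/truncation check (illustrative names, not the exact code).
# max_prompt_length budgets for ONE copy of the prompt; with the prompt
# duplicated inside text_chosen, the real sequence can exceed max_length
# by up to len(prompt_tokens["input_ids"]).
if len(prompt_tokens["input_ids"]) + longer_response_length > max_length:
    prompt_tokens = {k: v[:max_prompt_length] for k, v in prompt_tokens.items()}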

natolambert commented 3 weeks ago

Let's go through this. We definitely looked at it before, but it still may be wrong. If we look at a specific index, we need to see whether the input tokens to the DPO method contain repeated tokens with an attention mask of 1's.

If you look at the tokenized chosen text, you can quickly see that the attention mask is all ones:

{'chosen_input_ids': [151645, 151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 4340, 653, 358, 7716, 264, 1803, 30, 151645, 198, 151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 4340, 653, 358, 7716, 264, 1803, 30, 151645, 198, 151644, 77091, 198, 10649, 287, 264, 1803, 17601, 264, 17423, 15826, 4766, 323, 700, 11, 438, 1632, 438, 84463, 323, 36023, 287, 311, 6016, 279, 7310, 594, 26431, 13, 5692, 594, 264, 3019, 14319, 29208, 8474, 311, 44193, 264, 1803, 1447, 334, 840, 40544, 25771, 287, 25, 56177, 16, 13, 3070, 54, 988, 279, 3261, 25, 1019, 256, 481, 57368, 325, 279, 1803, 448, 3015, 311, 4057, 20174, 25284, 624, 256, 481, 5443, 264, 1803, 11369, 26785, 323, 8003, 82945, 11369, 47332, 311, 4240, 279, 1803, 504, 1909, 311, 5622, 624, 256, 481, 9590, 279, 22696, 323, 30700, 448, 264, 15625, 323, 264, 13284, 31881, 624, 256, 481, 57368, 325, 279, 1803, 26364, 311, 4057, 678, 26785, 382, 17, 13, 3070, 85215, 279, 3261, 25, 1019, 256, 481, 5443, 264, 8003, 82945, 42613, 476, 264, 521, 21100, 285, 311, 9058, 279, 1803, 311, 5358, 3015, 18773, 382, 18, 13, 3070, 5066, 352, 4716, 30869, 25, 1019, 256, 481, 5443, 264, 36048, 3619, 448, 264, 53394, 517, 311, 4057, 22864, 7329, 87859, 504, 279, 6177, 382, 19, 13, 3070, 14658, 10976, 25, 1019, 256, 481, 20552, 1803, 44029, 448, 264, 18579, 25059, 1471, 38572, 476, 553, 1424, 311, 4396, 6177, 16772, 24384, 323, 1855, 264, 10876, 7329, 382, 20, 13, 3070, 54, 706, 287, 25, 1019, 256, 481, 20552, 264, 22875, 315, 36023, 476, 6177, 25349, 517, 311, 6016, 279, 6177, 323, 2968, 432, 264, 73056, 6248, 382, 21, 13, 3070, 13164, 323, 14268, 31686, 25, 1019, 256, 481, 9590, 279, 11030, 323, 40485, 448, 264, 8991, 31881, 323, 264, 8003, 82945, 42613, 382, 22, 13, 3070, 51, 554, 323, 44376, 28218, 287, 25, 1019, 256, 481, 20552, 264, 27287, 31523, 311, 279, 30700, 369, 264, 41199, 6248, 624, 256, 481, 5443, 264, 11013, 2732, 14827, 476, 6016, 517, 389, 12188, 323, 22674, 5479, 311, 5358, 58517, 382, 334, 85125, 25771, 287, 25, 56177, 16, 13, 3070, 13021, 70623, 25, 1019, 256, 481, 12023, 700, 894, 22854, 323, 4057, 4345, 3589, 504, 279, 1803, 382, 17, 13, 3070, 81789, 20434, 25, 1019, 256, 481, 75010, 279, 16312, 11, 88241, 11, 6422, 61056, 11, 323, 37311, 624, 256, 481, 5443, 264, 15625, 19984, 369, 279, 26967, 323, 6006, 21285, 382, 18, 13, 3070, 2016, 47595, 48777, 1415, 323, 98905, 337, 75970, 25, 1019, 256, 481, 5443, 264, 27854, 31881, 323, 264, 15625, 311, 4240, 279, 88241, 323, 95688, 624, 256, 481, 1752, 17553, 73464, 11, 990, 264, 17553, 31881, 323, 64324, 382, 19, 13, 3070, 27529, 11232, 8082, 7605, 25, 1019, 256, 481, 467, 3444, 1495, 678, 2588, 26431, 320, 18641, 11, 4126, 2339, 11, 6006, 21285, 11, 4992, 6138, 448, 264, 23034, 678, 58238, 31881, 323, 264, 8003, 82945, 27292, 382, 20, 13, 3070, 13164, 323, 14268, 31686, 25, 1019, 256, 481, 9590, 279, 14791, 3108, 315, 11030, 323, 40485, 382, 21, 13, 3070, 25693, 647, 805, 323, 356, 7282, 1216, 25, 1019, 256, 481, 5443, 264, 44193, 15625, 476, 30649, 3720, 311, 4240, 700, 3720, 80207, 323, 2588, 4686, 5504, 610, 1884, 85, 1216, 382, 22, 13, 3070, 19357, 19338, 288, 25, 1019, 256, 481, 20552, 264, 6016, 517, 311, 279, 26967, 323, 1008, 12188, 6813, 624, 256, 481, 29558, 3720, 7722, 18223, 421, 4362, 382, 334, 29019, 25704, 25, 56177, 12, 5547, 304, 279, 27501, 476, 264, 7010, 11, 1632, 12, 684, 92483, 19277, 311, 5358, 3871, 504, 45379, 2238, 6157, 323, 9380, 48132, 624, 12, 5443, 8651, 42112, 369, 27686, 323, 96039, 287, 311, 
5648, 34422, 1095, 279, 4240, 3015, 448, 25284, 624, 12, 23240, 990, 21700, 11, 2477, 38030, 12784, 533, 7236, 323, 74865, 11689, 6188, 369, 39408, 990, 311, 5648, 33346, 26431, 624, 12, 14561, 304, 264, 36438, 1616, 311, 5978, 498, 1513, 944, 3116, 894, 18773, 382, 1359, 2701, 1493, 7354, 11, 498, 3278, 2968, 697, 1803, 264, 17423, 4240, 429, 537, 1172, 3643, 432, 1401, 2244, 714, 1083, 8609, 304, 20337, 1181, 897, 13, 19881, 11, 5792, 44193, 646, 5358, 9850, 323, 17576, 323, 2506, 697, 1803, 3330, 501, 369, 1635, 311, 2525, 13, 151645, 198, 151645], 
'chosen_attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
}
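(Equivalently, a quick programmatic check, assuming the tokenized example lives in tokenized_dataset as in the snippet further down:)

mask = tokenized_dataset[0]['chosen_attention_mask']
assert all(m == 1 for m in mask)  # no padding or masking: every token is attended to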

Next, let's tokenize just the prompt:

tokenizer(dataset[0]['prompt'])
{'input_ids': [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198, 151644, 872, 198, 4340, 653, 358, 7716, 264, 1803, 30, 151645, 198], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

Now, let's see if the sequence repeats.

full_tokens = tokenized_dataset[0]['chosen_input_ids']
prompt_tokens = tokenizer(dataset[0]['prompt'])['input_ids']
# count how many times the prompt token sequence occurs in the full sequence
n = len(prompt_tokens)
repeats = sum(full_tokens[i:i + n] == prompt_tokens for i in range(len(full_tokens) - n + 1))
print(repeats)  # 2 -> the prompt appears twice
# or just decode and eyeball it
print(tokenizer.decode(full_tokens))

So yeah, decoding the full sequence (output below) shows the system prompt and user turn repeated back-to-back. Looks wrong.

"<|im_end|><|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHow do I detail a car?<|im_end|>\n<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHow do I detail a car?<|im_end|>\n<|im_start|>assistant\nDetailing a car involves a thorough cleaning inside and out, as well as polishing and waxing to protect the vehicle's surfaces. Here's a step-by-step guide to detailing a car:\n\n**Exterior Detailing:**\n\n1. **Wash the Car:**\n   - Rinse the car with water to remove loose dirt.\n   - Use a car wash soap and microfiber wash mitt to clean the car from top to bottom.\n   - Clean the wheels and tires with a brush and a wheel cleaner.\n   - Rinse the car thoroughly to remove all soap.\n\n2. **Dry the Car:**\n   - Use a microfiber towel or a chamois to dry the car to prevent water spots.\n\n3. **Clay Bar Treatment:**\n   - Use a clay bar with a lubricant to remove embedded surface contaminants from the paint.\n\n4. **Polishing:**\n   - Apply car polish with a dual-action polisher or by hand to correct paint imperfections and create a smooth surface.\n\n5. **Waxing:**\n   - Apply a coat of wax or paint sealant to protect the paint and give it a glossy finish.\n\n6. **Windows and Mirrors:**\n   - Clean the windows and mirrors with a glass cleaner and a microfiber towel.\n\n7. **Tire and Trim Dressing:**\n   - Apply a tire dressing to the tires for a shiny finish.\n   - Use a trim restorer or protectant on plastic and rubber parts to prevent fading.\n\n**Interior Detailing:**\n\n1. **Remove Trash:**\n   - Clear out any trash and remove personal items from the car.\n\n2. **Vacuum:**\n   - Vacuum the seats, carpets, floor mats, and trunk.\n   - Use a brush attachment for the dashboard and door panels.\n\n3. **Shampoo Carpets and Upholstery:**\n   - Use a carpet cleaner and a brush to clean the carpets and upholstery.\n   - For leather interiors, use a leather cleaner and conditioner.\n\n4. **Clean Hard Surfaces:**\n   - Wipe down all hard surfaces (dashboard, center console, door panels, etc.) with a mild all-purpose cleaner and a microfiber cloth.\n\n5. **Windows and Mirrors:**\n   - Clean the interior side of windows and mirrors.\n\n6. **Air Vents and Crevices:**\n   - Use a detailing brush or compressed air to clean out air vents and hard-to-reach crevices.\n\n7. **Final Touches:**\n   - Apply a protectant to the dashboard and other plastic components.\n   - Replace air fresheners if needed.\n\n**Additional Tips:**\n\n- Work in the shade or a cool, well-ventilated garage to prevent products from drying too quickly and leaving residue.\n- Use separate buckets for washing and rinsing to avoid contaminating the clean water with dirt.\n- Always use gentle, non-abrasive materials and cleaners specifically designed for automotive use to avoid damaging surfaces.\n- Move in a systematic way to ensure you don't miss any spots.\n\nBy following these steps, you'll give your car a thorough clean that not only makes it look great but also helps in maintaining its value. Remember, regular detailing can prevent wear and tear and keep your car looking new for years to come.<|im_end|>\n<|im_end|>"

Am submitting a fix and re-running all the evals. Thankfully, given the DPO math, the repeated tokens in the logprobs shouldn't affect the computation much: the duplicated prompt span is identical in the chosen and rejected sequences, so its contribution largely cancels when the two rewards are compared. We'll confirm soon!
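A toy illustration of that cancellation, with hypothetical numbers (the real terms are summed per-token log-prob differences):

# Hypothetical numbers to show why the shared, duplicated prompt span
# mostly washes out of the chosen-vs-rejected comparison.
prompt_term = -4.2    # summed (log pi - log pi_ref) over the duplicated prompt tokens
chosen_resp = -1.0    # response-only term for the chosen completion
rejected_resp = -3.5  # response-only term for the rejected completion

reward_chosen = prompt_term + chosen_resp
reward_rejected = prompt_term + rejected_resp

# The shared prompt term drops out of the comparison:
assert (reward_chosen > reward_rejected) == (chosen_resp > rejected_resp)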