eric-mitchell / direct-preference-optimization

Reference implementation for DPO (Direct Preference Optimization)
Apache License 2.0

Llama-2-13b-chat: validation reward accuracy remains ~50% #53

Open nxphi47 opened 11 months ago

nxphi47 commented 11 months ago

Hi, great work! I'm having trouble training Llama-2-13b-chat on the Anthropic HH dataset.

Following the README, I first trained SFT on HH and then ran DPO.

The only things I changed are setting policy_dtype: bfloat16 so that I can use FlashAttention-2, and changing the tokenization to match the Llama-2 instruction-following format. Here are examples of the tokens:

# Note: these are individual token ids converted back to tokens, not a detokenized sequence
||||| chosen_seq:
<s> [ INST ] ▁What ▁are ▁some ▁good ▁young ▁adult ▁nov els ▁I ▁could ▁buy ▁for ▁my ▁daughter ? ▁[ / INST ] ▁Have ▁you ▁read ▁any ▁of ▁Ver onica ▁Roth ’ s ▁books ? </s> <s> [ INST ] ▁I ▁have ▁not . ▁Does ▁she ▁write ▁young ▁adult ▁nov els ? ▁[ / INST ] ▁Yes , ▁she ▁writes ▁D iver gent ▁and ▁Ins urg ent . ▁They ’ re ▁both ▁pretty ▁good . ▁Also ▁The ▁G lass ▁Castle ▁is ▁really ▁good . </s> <s> [ INST ] ▁S ounds ▁good . ▁Any ▁other ▁suggestions ? ▁[ / INST ] ▁ ▁The ▁Book ▁Th ief ▁is ▁really ▁popular ▁and ▁the ▁movie ▁was ▁also ▁good . ▁And ▁I ▁also ▁really ▁liked ▁The ▁Hung er ▁Games ▁series . </s> <pad> <pad>  ...
||||| rejected_seq:
<s> [ INST ] ▁What ▁are ▁some ▁good ▁young ▁adult ▁nov els ▁I ▁could ▁buy ▁for ▁my ▁daughter ? ▁[ / INST ] ▁Have ▁you ▁read ▁any ▁of ▁Ver onica ▁Roth ’ s ▁books ? </s> <s> [ INST ] ▁I ▁have ▁not . ▁Does ▁she ▁write ▁young ▁adult ▁nov els ? ▁[ / INST ] ▁Yes , ▁she ▁writes ▁D iver gent ▁and ▁Ins urg ent . ▁They ’ re ▁both ▁pretty ▁good . ▁Also ▁The ▁G lass ▁Castle ▁is ▁really ▁good . </s> <s> [ INST ] ▁S ounds ▁good . ▁Any ▁other ▁suggestions ? ▁[ / INST ] ▁ ▁The ▁Hung er ▁Games ▁is ▁a ▁big ▁hit ▁right ▁now . ▁It ’ s ▁kind ▁of ▁like ▁a ▁young ▁adult ▁thr iller ▁set ▁in ▁a ▁d yst op ian ▁society . ▁The ▁Maz e ▁Run ner ▁is ▁also ▁pretty ▁good . ▁Both ▁of ▁these ▁books ▁are ▁by ▁Su z anne ▁Collins . </s> <pad> <pad>  ...
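For reference, the Llama-2 chat layout in the dumps above can be produced from an HH transcript roughly as follows. This is a minimal sketch under stated assumptions: the helper names `parse_hh` and `to_llama2` are hypothetical (not from the DPO repo), turns are assumed to alternate strictly Human/Assistant and end with an Assistant turn, and `<s>`/`</s>` are literal strings where a real pipeline would insert the tokenizer's BOS/EOS ids. Every turn is stripped. The dumps show `▁ ▁The` right after the final `[/INST]`; one possible (unconfirmed) explanation is that the response keeps its leading space from the HH text and picks up an extra `▁` when tokenized on its own by the SentencePiece tokenizer.

```python
import re
from typing import List, Tuple

# Hypothetical helpers, not part of the DPO reference repo. BOS/EOS are literal
# strings for readability; a real pipeline should insert the tokenizer's
# bos/eos token ids instead of tokenizing "<s>" and "</s>" as text.
BOS, EOS = "<s>", "</s>"
B_INST, E_INST = "[INST]", "[/INST]"


def parse_hh(transcript: str) -> List[Tuple[str, str]]:
    """Split an HH transcript ('\\n\\nHuman: ...\\n\\nAssistant: ...') into (role, text) turns."""
    parts = re.split(r"\n\n(Human|Assistant):", transcript)
    # re.split keeps captured role names: parts = [prefix, role, text, role, text, ...]
    return [(role, text.strip()) for role, text in zip(parts[1::2], parts[2::2])]


def to_llama2(transcript: str) -> Tuple[str, str]:
    """Return (prompt, response) in Llama-2 chat format; only the response is the DPO target."""
    turns = parse_hh(transcript)
    roles = [role for role, _ in turns]
    expected = ["Human", "Assistant"] * (len(turns) // 2)
    if not turns or roles != expected:
        raise ValueError("assumes strictly alternating Human/Assistant turns ending with Assistant")
    prompt = ""
    # Every exchange except the last one belongs to the prompt.
    for k in range(0, len(turns) - 2, 2):
        prompt += f"{BOS}{B_INST} {turns[k][1]} {E_INST} {turns[k + 1][1]} {EOS}"
    # The final user turn opens the instruction the model must answer.
    prompt += f"{BOS}{B_INST} {turns[-2][1]} {E_INST}"
    # Stripped final answer with no leading space; the tokenizer handles word boundaries.
    response = f"{turns[-1][1]} {EOS}"
    return prompt, response
```

Keeping the prompt and response separate is the point of the design: for chosen and rejected completions of the same conversation the prompt is identical, so (assuming prompt tokens are masked out of the log-prob sum) any difference in the DPO objective comes only from the final answers.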

However, the reward accuracies are no better than 50% (see below), and the comparison results evaluated by GPT-4 are worse than before.

wandb: Run history:
wandb:        counters/examples ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
wandb:         counters/updates ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
wandb:      examples_per_second ▁▄▆▄▃▄▂▆▃▃▅▅▃▄▄▂▄▂▂▆▃▃▄▄▄▃▄▂▃█▄▄▃▃▂▄▂▂▅▂
wandb:                grad_norm ▂▃▇▄▁▃▃▄▁▁▃▃▅▁▁▇▂▃▃▃▃▂▅▁▂▂▄▃▂▂▁▂▂▃▁▃█▁▆▄
wandb:        logps_eval/chosen █▅▄▃▄▃▂▂▁
wandb:      logps_eval/rejected █▅▄▃▄▃▂▂▁
wandb:       logps_train/chosen ▄▅▄▅█▅▆▃▄▅▄▁▅▂▃▄▆▄▇▄▃▅▄▄▅▅▄▂▅▆▆▃▃▂▄▆▂▄▆▃
wandb:     logps_train/rejected ▄▅▂█▅▄▅▄▃▆▅▂▂▂▆▆▅▃▅▄▄▁▃▇▅▅▃▃▅▅▂▃▂▂▄▄▃▄▄▂
wandb:                loss/eval █▅▄▄▅▃▂▁▁
wandb:               loss/train ▇▆█▅▇▅▅▅▅▆▅▅█▅▅▅▄▃▆▃▅▅▆▃▄▁▄▅▃▂▄▅▆▂▄▆█▅▃▄
wandb:  rewards_eval/accuracies ▁████████
wandb:      rewards_eval/chosen █▅▄▃▄▃▂▂▁
wandb:     rewards_eval/margins ▁▄▅▅▅▇▇██
wandb:    rewards_eval/rejected █▅▄▃▄▃▂▂▁
wandb: rewards_train/accuracies ▁▄▁▆▃▆▄▄▆▄▅▆▄▆▆▅█▇▃▆▆▆▅▇▄▅▃▆▆▇▅▇▄█▆▆▇▄▆▇
wandb:     rewards_train/chosen ██▇▇▇▆▆▇▆▅▆▇▅▆▅▅▆▆▄▆▄▅▄▄▄▃▄▁▄▅▂▃▃▄▃▂▃▄▂▄
wandb:    rewards_train/margins ▂▃▁▄▂▃▃▃▄▂▃▄▁▄▄▄▄▅▃▆▄▄▃▆▅█▅▅▆▇▅▅▃█▆▃▂▄█▆
wandb:   rewards_train/rejected ██▇▇▇▆▆▇▆▅▆▆▅▆▅▅▅▆▅▅▄▄▄▃▄▂▃▁▃▄▂▃▃▃▃▃▄▃▁▃

wandb:        counters/examples 160704
wandb:         counters/updates 2511
wandb:      examples_per_second 4.53392
wandb:                grad_norm 63.75
wandb:        logps_eval/chosen -143.79242
wandb:      logps_eval/rejected -130.69831
wandb:       logps_train/chosen -127.1315
wandb:     logps_train/rejected -141.18178
wandb:                loss/eval 0.6763
wandb:               loss/train 0.69097
wandb:  rewards_eval/accuracies 0.53516
wandb:      rewards_eval/chosen -0.25887
wandb:     rewards_eval/margins 0.05858
wandb:    rewards_eval/rejected -0.31745
wandb: rewards_train/accuracies 0.48438
wandb:     rewards_train/chosen -0.31006
wandb:    rewards_train/margins 0.02927
wandb:   rewards_train/rejected -0.33933
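To read these summary numbers, it helps to recall how the logged quantities relate. The sketch below assumes the standard DPO definitions: the implicit reward of a response is β·(log π(y|x) − log π_ref(y|x)) using summed response log-probs, the margin is chosen reward minus rejected reward, accuracy is the fraction of pairs with a strictly positive margin, and the loss is −log σ(margin). The function name `dpo_metrics` and β = 0.1 are assumptions for illustration; this is not the repo's trainer code.

```python
import math
from typing import Dict, Sequence


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def dpo_metrics(policy_chosen: Sequence[float], policy_rejected: Sequence[float],
                ref_chosen: Sequence[float], ref_rejected: Sequence[float],
                beta: float = 0.1) -> Dict[str, float]:
    """Batch-mean DPO metrics from summed per-sequence response log-probs.

    beta=0.1 is an assumed value for illustration.
    """
    n = len(policy_chosen)
    # Implicit rewards: beta * (policy log-prob - reference log-prob)
    chosen_r = [beta * (p - r) for p, r in zip(policy_chosen, ref_chosen)]
    rejected_r = [beta * (p - r) for p, r in zip(policy_rejected, ref_rejected)]
    margins = [c - r for c, r in zip(chosen_r, rejected_r)]
    return {
        "chosen_reward": sum(chosen_r) / n,
        "rejected_reward": sum(rejected_r) / n,
        "margin": sum(margins) / n,
        # A pair counts as correct only if the chosen reward is strictly higher
        "accuracy": sum(m > 0 for m in margins) / n,
        # -log sigmoid(m) == softplus(-m)
        "loss": sum(softplus(-m) for m in margins) / n,
    }
```

When the policy equals the reference, every margin is 0 and the loss is exactly ln 2 ≈ 0.6931. The final loss/train of 0.69097 and loss/eval of 0.6763, together with margins of 0.029 (train) and 0.059 (eval), are therefore close to that starting point, which suggests the policy has moved only slightly away from the reference on these pairs after 2511 updates.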

Because our system cannot access the public wandb service, I don't have a wandb link or more detailed metrics to help diagnose this.