[Open] Vance0124 opened this issue 7 months ago
My main issue is: with the `TensorParallelTrainer` and reduced precision, the FSDP mixed-precision setting (`model.fsdp_policy_mp`) will have no effect. The reference wandb run used the command:

```
train.py model=pythia28 datasets=[hh] loss=sft exp_name=pythia28_hh_sft_bf16 gradient_accumulation_steps=2 batch_size=64 n_epochs=1 eval_batch_size=32 trainer=FSDPTrainer eval_every=5000 sample_during_eval=false model.fsdp_policy_mp=bfloat16
```

Can you try again with the `FSDPTrainer`?
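If it helps clarify why the trainer choice matters: `model.fsdp_policy_mp=bfloat16` is only consumed when the model is wrapped in PyTorch FSDP. Below is a minimal sketch of how such a policy is built with the public FSDP API — an illustration of the mechanism, not the repo's actual `FSDPTrainer` code (the helper `make_mp_policy` is hypothetical):

```python
from typing import Optional

import torch
from torch.distributed.fsdp import MixedPrecision


def make_mp_policy(fsdp_policy_mp: Optional[str]) -> Optional[MixedPrecision]:
    """Translate a config string like "bfloat16" into an FSDP MixedPrecision
    policy; returns None (full precision) when the option is unset.
    (Hypothetical helper, for illustration only.)"""
    if fsdp_policy_mp is None:
        return None
    dtype = getattr(torch, fsdp_policy_mp)  # e.g. "bfloat16" -> torch.bfloat16
    # Cast parameters, gradient all-reduce, and buffers to the reduced dtype.
    return MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)


policy = make_mp_policy("bfloat16")
print(policy.param_dtype)  # torch.bfloat16
```

The policy is then passed to the `FullyShardedDataParallel` wrapper via its `mixed_precision` argument; a trainer that never constructs that wrapper (e.g. a tensor-parallel path) silently ignores the flag.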
I replicated the experiments of pythia28 on hh (Anthropic/hh-rlhf) using the open-source code. Here are some of the experimental results:
SFT1: trained with the SFT command above.
But the result doesn't seem very promising.
I also implemented three other versions with `batch_size=4`, and then evaluated the models using GPT-4.

SFT2:
The result:
This also doesn't seem good.
SFT3 (bfloat16):
The result:
![sft1](https://github.com/eric-mitchell/direct-preference-optimization/assets/56477668/23345daf-f6a1-4dba-94cb-91884df229fe)
SFT4 (bfloat16):
The result:
![sft2](https://github.com/eric-mitchell/direct-preference-optimization/assets/56477668/98afad75-f8e9-43c1-a5db-bb07a28b5b7e)
The evaluations of SFT3 (bfloat16) and SFT4 (bfloat16) seem to be even worse.
Based on SFT1 (which I think is probably the best among these), I trained DPO1 with the following command:
The result:
The highest win rate is only 50%, but I don't know what I did wrong or whether I missed something.
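For what it's worth, the win rate in this kind of GPT-4 evaluation is just the fraction of pairwise judgments that the model's sample wins. A minimal sketch of the aggregation — the `'win'/'tie'/'lose'` label format and the tie-counts-as-half convention are my assumptions, not necessarily what the paper's eval script does:

```python
def win_rate(judgments):
    """Aggregate pairwise judge verdicts ('win'/'tie'/'lose' for the model's
    sample vs. the reference response) into a single win rate.
    Counting a tie as half a win is one common convention (an assumption here)."""
    score = {"win": 1.0, "tie": 0.5, "lose": 0.0}
    return sum(score[j] for j in judgments) / len(judgments)


print(win_rate(["win", "lose", "tie", "win"]))  # 0.625
```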
I also ran another DPO experiment, DPO2, based on SFT2:
Compared to the training results in the public WandB run, my own experimental results did not meet expectations.
The training results "rewards_train/accuracies" on the open WandB:
While the training results "rewards_train/accuracies" of mine:
The "rewards_train/accuracies" in the public WandB run can even reach more than 70%, but mine only reaches around 60% at most.
And the evaluation results:
"rewards_eval/accuracies" on the open WandB:
My "rewards_eval/accuracies":
The eval results also show a gap (close to 10%).
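To make sure we are comparing the same quantity: as I understand it, `rewards_*/accuracies` in DPO-style training is the fraction of preference pairs where the implicit reward of the chosen response, beta * (policy logp - reference logp), exceeds that of the rejected one. A pure-Python sketch of that metric, illustrating my understanding rather than the repo's exact code:

```python
def reward_accuracy(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta=0.1):
    """Fraction of pairs where the chosen response's implicit DPO reward
    beta * (log pi_policy - log pi_ref) beats the rejected response's."""
    wins = sum(
        1
        for pc, pr, rc, rr in zip(policy_chosen, policy_rejected, ref_chosen, ref_rejected)
        if beta * (pc - rc) > beta * (pr - rr)
    )
    return wins / len(policy_chosen)


# Toy batch of per-sequence log-probs (made up): the policy prefers the
# chosen response in 3 of the 4 pairs.
acc = reward_accuracy(
    policy_chosen=[-10.0, -8.0, -12.0, -9.0],
    policy_rejected=[-11.0, -9.0, -10.0, -9.5],
    ref_chosen=[-10.5, -8.5, -11.0, -10.0],
    ref_rejected=[-10.5, -8.5, -11.0, -10.0],
)
print(acc)  # 0.75
```

Note that any beta > 0 cancels out of the comparison, so a train/eval accuracy gap reflects the log-probability margins themselves, not the choice of beta.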
All other parameters use the defaults (e.g. `lr=5e-7`). I'm not sure if I made a mistake somewhere or what needs to be modified to close this gap. Please help me; I sincerely appreciate it.