OpenLLMAI / OpenRLHF

An Easy-to-use, Scalable and High-performance RLHF Framework (70B+ PPO Full Tuning & Iterative DPO & LoRA & Mixtral)
https://openrlhf.readthedocs.io/
Apache License 2.0

Implement KTO into OpenRLHF #201

Closed Dylancer1998 closed 5 months ago

Dylancer1998 commented 5 months ago

Referencing the HALOs implementation, the KTO algorithm has been integrated into this branch. It supports both balanced (referred to as the "vanilla" version) and unbalanced (the "non-vanilla" version) handling of positive and negative samples within a batch: the vanilla version requires each batch to contain an equal number of positive and negative samples, while the non-vanilla version drops that requirement.
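
For reference, the per-sample objective under discussion looks roughly like the sketch below. This is a hedged illustration and not the branch's actual code: the function and argument names are made up for this example, and the KL reference point is approximated here by a detached batch mean, whereas HALOs estimates it from mismatched completions.

```python
# Minimal PyTorch sketch of a KTO-style loss (Ethayarajh et al.),
# NOT the exact code in this branch; names are illustrative.
import torch

def kto_loss(
    policy_logps: torch.Tensor,   # summed log-probs of completions under the policy
    ref_logps: torch.Tensor,      # same completions under the frozen reference model
    labels: torch.Tensor,         # 1 = desirable (positive), 0 = undesirable (negative)
    beta: float = 0.1,
    desirable_weight: float = 1.0,
    undesirable_weight: float = 1.0,
) -> torch.Tensor:
    log_ratio = policy_logps - ref_logps
    chosen = labels.bool()

    # Crude KL reference point: detached batch mean, clamped at zero.
    # (HALOs uses mismatched prompt/completion pairs instead.)
    kl = log_ratio.mean().clamp(min=0).detach()

    # Desirable samples are pushed above the KL reference, undesirable
    # samples below it. Because the loss is per-sample rather than
    # pairwise, the batch may contain any mix of positives and negatives
    # (the "non-vanilla" case); the "vanilla" variant would balance the
    # batch before this point.
    losses = torch.where(
        chosen,
        desirable_weight * (1 - torch.sigmoid(beta * (log_ratio - kl))),
        undesirable_weight * (1 - torch.sigmoid(beta * (kl - log_ratio))),
    )
    return losses.mean()
```

Unlike DPO, nothing here pairs a chosen and a rejected completion for the same prompt, which is what makes the unbalanced (non-vanilla) batching possible in the first place.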

A lightweight dataset was selected for algorithm validation, comparing DPO, vanilla KTO, non-vanilla KTO, and the baseline. The results are as follows:

| model | Writing | Roleplay | Reasoning | Math | Coding | Extraction | STEM | Humanities | Average |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| baseline | 7.125 | 7.425 | 4.05 | 2.6 | 2.85 | 4.475 | 7.475 | 8.475 | 5.559 |
| DPO | 7.4 | 7.39 | 3.9 | 3.05 | 2.475 | 4.875 | 7.2 | 9.075 | 5.670 |
| KTO_with_vanilla_loss | 7.225 | 7.325 | 4.025 | 2.3 | 3.475 | 5.525 | 7.184 | 9.075 | 5.715 |
| KTO | 7.145 | 7.273 | 4.112 | 2.666 | 2.790 | 5.212 | 8.315 | 8.479 | 5.799 |

MT-Bench scores

* baseline model is "OpenLLMAI/Llama-2-7b-sft-model-ocra-500k"

hijkzzz commented 5 months ago

Thank you for your contribution; we will review it as soon as possible.