axolotl-ai-cloud / axolotl

Go ahead and axolotl questions
https://axolotl-ai-cloud.github.io/axolotl/
Apache License 2.0
7.48k stars 808 forks

Implements SPPO Alignment Algorithm #1735

Open kaykyr opened 1 month ago

kaykyr commented 1 month ago

Implements SPPO Alignment Algorithm

Description

This pull request implements the Self-Play Preference Optimization (SPPO) algorithm for language model alignment. The SPPO algorithm, as described in the paper "Self-Play Preference Optimization for Language Model Alignment" (available at https://arxiv.org/abs/2405.00675), uses a self-play mechanism to optimize language models based on preference probabilities. This implementation leverages the code from the original repository at https://github.com/uclaml/SPPO and integrates it into the Axolotl framework.
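For readers new to the method, here is a rough sketch of what "optimizing based on preference probabilities" means. Per the paper, each SPPO iteration regresses the log-ratio log(π_θ(y|x) / π_t(y|x)) toward η·(P̂(y ≻ π_t | x) − 1/2). Here π_t is the current policy and P̂ is the estimated probability that response y beats a response sampled from π_t. The sketch assumes K ≥ 2 sampled responses per prompt and a given matrix of pairwise preference probabilities (the paper obtains these from the PairRM preference model). `estimate_win_probs` and `sppo_loss` are hypothetical names, and this is not the code from this PR or from the uclaml repository.

```python
from typing import List, Sequence


def estimate_win_probs(pref: Sequence[Sequence[float]]) -> List[float]:
    """Estimate P(y_i beats the current policy) for each of K >= 2 sampled responses.

    pref[i][j] is an assumed pairwise probability that response i is preferred
    over response j. Skipping the self-comparison is an illustrative choice.
    """
    k = len(pref)
    return [sum(pref[i][j] for j in range(k) if j != i) / (k - 1) for i in range(k)]


def sppo_loss(policy_logps: Sequence[float],
              ref_logps: Sequence[float],
              win_probs: Sequence[float],
              eta: float) -> float:
    """Mean squared error between log-ratios and eta * (P_hat - 1/2)."""
    total = 0.0
    for logp, ref_logp, p in zip(policy_logps, ref_logps, win_probs):
        log_ratio = logp - ref_logp   # log pi_theta(y|x) - log pi_t(y|x)
        target = eta * (p - 0.5)      # > 0 when y beats pi_t more than half the time
        total += (log_ratio - target) ** 2
    return total / len(win_probs)


# Hypothetical preference matrix for three sampled responses to one prompt.
pref = [[0.5, 0.8, 0.6],
        [0.2, 0.5, 0.3],
        [0.4, 0.7, 0.5]]
win_probs = estimate_win_probs(pref)
```

Responses the preference model favors get a positive target and are pushed up relative to π_t. Disfavored ones get a negative target and are pushed down.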

Motivation and Context

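This change aims to improve the alignment of language models with human preferences, which bears on the reliability, safety, and ethical behavior of their outputs. Traditional RLHF approaches rely on parametric preference models such as Bradley-Terry, which fall short in capturing intransitive or irrational human preferences. SPPO instead works directly with preference probabilities. It treats alignment as a two-player constant-sum game and approximates the game's Nash equilibrium through iterative self-play policy updates.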

How has this been tested?

The implementation has been tested using a variety of prompts from the UltraFeedback dataset, with the model evaluated on AlpacaEval 2.0 and MT-Bench. The tests assessed the log-likelihood of chosen responses and compared the model's win rates against state-of-the-art models. The changes were also checked for adverse effects on other areas of the codebase.

Social Handles (Optional)

GitHub: @kaykyr HuggingFace: https://huggingface.co/kaykyramos Discord: kaykyramos

kaykyr commented 1 month ago (quoting @winglian):

hi @kaykyr, Thanks for submitting this technique. I'd love to see this integrated into axolotl, but my main concern is the amount of duplicated code we're going to have to maintain. I'm happy to help refactor the pieces in the trainer_builder, but I think it would be ideal if we could extract the necessary SPPO changes from DPOTrainer so we have a smaller footprint to maintain.

re: tests, it would be good to have some tests to spot-check the functionality. I'm happy to help with this as well, e.g. by setting up some e2e tests that run a small model for about 10-20 steps to verify that the trainer works.

Hey @winglian, I'll do my best to submit an improved pull request with a cleaner approach to the SPPO integration.
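One way to read the "smaller footprint" suggestion is that in a DPO-style trainer the SPPO-specific part can be confined to the per-pair loss. Both methods consume the same policy and reference log-probabilities for the chosen and rejected responses. The sketch compares the standard DPO sigmoid loss with the hard-probability SPPO variant (winner P = 1, loser P = 0, η = 1/β) on scalar log-ratios. The function names are hypothetical, and this is not the code from this PR. Recent TRL releases also document a `loss_type="sppo_hard"` option for `DPOTrainer`. Whether the TRL version axolotl pins includes it would need checking.

```python
import math


def dpo_sigmoid_loss(chosen_logratio: float, rejected_logratio: float, beta: float) -> float:
    """Standard DPO loss: -log sigmoid(beta * (chosen - rejected))."""
    z = beta * (chosen_logratio - rejected_logratio)
    # -log(sigmoid(z)) == log(1 + exp(-z)); branch so exp() never overflows
    if z >= 0:
        return math.log1p(math.exp(-z))
    return -z + math.log1p(math.exp(z))


def sppo_hard_loss(chosen_logratio: float, rejected_logratio: float, beta: float) -> float:
    """SPPO with hard probabilities: target +0.5/beta for the winner, -0.5/beta for the loser."""
    return (chosen_logratio - 0.5 / beta) ** 2 + (rejected_logratio + 0.5 / beta) ** 2
```

The design difference shows in the minima. The DPO loss keeps decreasing as the chosen-minus-rejected margin grows. The SPPO hard loss is minimized at fixed targets ±0.5/β, so log-ratios are pulled toward finite values instead of being pushed apart indefinitely.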

kaykyr commented 1 month ago

I am also running 3 iterations and I'll upload the resulting models to Hugging Face for comparison. At the moment I am running iter2 on my homelab.
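In the paper's Algorithm 1, each of the 3 iterations samples responses from the current policy π_t, estimates their win probabilities, and trains π_{t+1}. That model then serves as π_t for the next round. The toy below runs the idealized update that the training objective approximates, π_{t+1}(y|x) ∝ π_t(y|x)·exp(η·P(y ≻ π_t | x)). It applies the update exactly over three hypothetical responses with a made-up preference matrix. It illustrates the iteration under those assumptions and is not the run whose logs follow.

```python
import math
from typing import List, Sequence


def sppo_exact_step(policy: Sequence[float],
                    pref: Sequence[Sequence[float]],
                    eta: float) -> List[float]:
    """One exact multiplicative-weights update over a finite response set.

    policy[i] is pi_t(y_i | x); pref[i][j] is an assumed P(y_i preferred over y_j | x).
    """
    n = len(policy)
    # P(y_i beats pi_t) = sum_j pi_t(y_j) * P(y_i preferred over y_j)
    win = [sum(policy[j] * pref[i][j] for j in range(n)) for i in range(n)]
    unnorm = [policy[i] * math.exp(eta * win[i]) for i in range(n)]
    z = sum(unnorm)
    return [u / z for u in unnorm]


# Hypothetical preferences: y_0 beats y_1 and y_2; y_2 beats y_1.
pref = [[0.5, 0.8, 0.6],
        [0.2, 0.5, 0.3],
        [0.4, 0.7, 0.5]]
policy = [1 / 3, 1 / 3, 1 / 3]
history = [policy]
for _ in range(3):  # three self-play iterations
    policy = sppo_exact_step(policy, pref, eta=2.0)
    history.append(policy)
```

Because y_0 is preferred over both alternatives, its probability rises at every iteration while the least-preferred y_1 loses mass.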

{'loss': 0.4766, 'grad_norm': 5.125, 'learning_rate': 0.00015594912061278626, 'rewards/chosen': -1.9593877792358398, 'rewards/rejected': -3.4484448432922363, 'rewards/accuracies': 0.5, 'rewards/margins': 1.4890570640563965, 'logps/rejected': -156.06246948242188, 'logps/chosen': -163.50051879882812, 'logits/rejected': -0.08261816203594208, 'logits/chosen': 0.01163027435541153, 'epoch': 1.0}
{'loss': 0.1187, 'grad_norm': 0.85546875, 'learning_rate': 0.000155500908021347, 'rewards/chosen': -1.4073667526245117, 'rewards/rejected': -7.409327983856201, 'rewards/accuracies': 1.0, 'rewards/margins': 6.0019612312316895, 'logps/rejected': -302.47320556640625, 'logps/chosen': -217.1146697998047, 'logits/rejected': -0.17893444001674652, 'logits/chosen': -0.3086361289024353, 'epoch': 1.01}
{'loss': 0.2105, 'grad_norm': 3.84375, 'learning_rate': 0.00015505107827058036, 'rewards/chosen': -2.149808645248413, 'rewards/rejected': -5.619155406951904, 'rewards/accuracies': 0.75, 'rewards/margins': 3.469346761703491, 'logps/rejected': -212.09442138671875, 'logps/chosen': -163.01132202148438, 'logits/rejected': -0.03775382041931152, 'logits/chosen': -0.1259421706199646, 'epoch': 1.01}
{'loss': 0.1979, 'grad_norm': 1.2890625, 'learning_rate': 0.00015459964446741382, 'rewards/chosen': -0.6550925374031067, 'rewards/rejected': -4.351650714874268, 'rewards/accuracies': 0.625, 'rewards/margins': 3.6965579986572266, 'logps/rejected': -196.5653839111328, 'logps/chosen': -149.3184814453125, 'logits/rejected': -0.23472216725349426, 'logits/chosen': -0.27307409048080444, 'epoch': 1.02}
{'loss': 0.2405, 'grad_norm': 1.984375, 'learning_rate': 0.00015414661976551302, 'rewards/chosen': -0.46041208505630493, 'rewards/rejected': -5.224635601043701, 'rewards/accuracies': 0.75, 'rewards/margins': 4.764223098754883, 'logps/rejected': -244.72647094726562, 'logps/chosen': -190.26235961914062, 'logits/rejected': -0.25164929032325745, 'logits/chosen': -0.09058046340942383, 'epoch': 1.02}
{'loss': 0.0668, 'grad_norm': 1.984375, 'learning_rate': 0.0001536920173648984, 'rewards/chosen': -2.3507869243621826, 'rewards/rejected': -6.221563816070557, 'rewards/accuracies': 1.0, 'rewards/margins': 3.870777130126953, 'logps/rejected': -314.01568603515625, 'logps/chosen': -270.7308044433594, 'logits/rejected': 0.017346393316984177, 'logits/chosen': 0.05637218803167343, 'epoch': 1.03}
{'loss': 0.227, 'grad_norm': 1.8125, 'learning_rate': 0.0001532358505115607, 'rewards/chosen': -0.6020192503929138, 'rewards/rejected': -5.3337531089782715, 'rewards/accuracies': 0.875, 'rewards/margins': 4.731733798980713, 'logps/rejected': -214.8163299560547, 'logps/chosen': -182.3211212158203, 'logits/rejected': -0.021296918392181396, 'logits/chosen': 0.01770481839776039, 'epoch': 1.03}
{'loss': 0.1773, 'grad_norm': 2.6875, 'learning_rate': 0.00015277813249707487, 'rewards/chosen': -1.7499217987060547, 'rewards/rejected': -5.255713939666748, 'rewards/accuracies': 1.0, 'rewards/margins': 3.5057921409606934, 'logps/rejected': -310.7301330566406, 'logps/chosen': -271.0190734863281, 'logits/rejected': 0.045894794166088104, 'logits/chosen': -0.04848558083176613, 'epoch': 1.04}
{'loss': 0.4042, 'grad_norm': 3.546875, 'learning_rate': 0.000152318876658213, 'rewards/chosen': -1.4743281602859497, 'rewards/rejected': -4.679446697235107, 'rewards/accuracies': 0.875, 'rewards/margins': 3.205118179321289, 'logps/rejected': -271.888671875, 'logps/chosen': -239.9329833984375, 'logits/rejected': 0.010095290839672089, 'logits/chosen': -0.08229245990514755, 'epoch': 1.04}
{'loss': 0.2753, 'grad_norm': 2.671875, 'learning_rate': 0.0001518580963765555, 'rewards/chosen': -3.2833800315856934, 'rewards/rejected': -6.391260147094727, 'rewards/accuracies': 0.875, 'rewards/margins': 3.1078805923461914, 'logps/rejected': -283.3193054199219, 'logps/chosen': -248.41119384765625, 'logits/rejected': -0.036375753581523895, 'logits/chosen': -0.032573096454143524, 'epoch': 1.05}
{'loss': 0.0712, 'grad_norm': 1.40625, 'learning_rate': 0.00015139580507810119, 'rewards/chosen': -0.22441568970680237, 'rewards/rejected': -4.1824116706848145, 'rewards/accuracies': 1.0, 'rewards/margins': 3.957995891571045, 'logps/rejected': -218.41976928710938, 'logps/chosen': -175.13992309570312, 'logits/rejected': 0.10275314003229141, 'logits/chosen': 0.039625994861125946, 'epoch': 1.05}
{'loss': 0.066, 'grad_norm': 1.984375, 'learning_rate': 0.00015093201623287631, 'rewards/chosen': -1.5608439445495605, 'rewards/rejected': -7.547214508056641, 'rewards/accuracies': 1.0, 'rewards/margins': 5.98637056350708, 'logps/rejected': -326.3816223144531, 'logps/chosen': -270.4262390136719, 'logits/rejected': -0.08117158710956573, 'logits/chosen': -0.03967729210853577, 'epoch': 1.06}
{'loss': 0.1853, 'grad_norm': 2.515625, 'learning_rate': 0.0001504667433545419, 'rewards/chosen': -0.02940535545349121, 'rewards/rejected': -5.676110744476318, 'rewards/accuracies': 0.875, 'rewards/margins': 5.646705150604248, 'logps/rejected': -262.2330017089844, 'logps/chosen': -205.32958984375, 'logits/rejected': -0.13373667001724243, 'logits/chosen': -0.18218691647052765, 'epoch': 1.06}
{'loss': 0.1211, 'grad_norm': 1.2265625, 'learning_rate': 0.00015000000000000001, 'rewards/chosen': -1.7262533903121948, 'rewards/rejected': -5.780290603637695, 'rewards/accuracies': 1.0, 'rewards/margins': 4.054037094116211, 'logps/rejected': -259.7177734375, 'logps/chosen': -185.06295776367188, 'logits/rejected': -0.01202734187245369, 'logits/chosen': -0.19560036063194275, 'epoch': 1.07}
{'loss': 0.1652, 'grad_norm': 1.796875, 'learning_rate': 0.00014953179976899878, 'rewards/chosen': -2.2642948627471924, 'rewards/rejected': -6.6951165199279785, 'rewards/accuracies': 0.75, 'rewards/margins': 4.430821418762207, 'logps/rejected': -271.29461669921875, 'logps/chosen': -222.35409545898438, 'logits/rejected': -0.15404176712036133, 'logits/chosen': -0.19560746848583221, 'epoch': 1.07}
 36%|█████████████████████████████████████████████████████████████▍                                                                                                              | 215/602 [1:53:26<3:32:56, 33.01s/it]
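The logged keys match the metric names used by TRL's `DPOTrainer`, which suggests this run reuses its bookkeeping. In that convention, rewards/chosen and rewards/rejected are β-scaled policy-minus-reference log-probabilities. rewards/margins is their difference, and rewards/accuracies is the fraction of pairs where the chosen reward exceeds the rejected one. The sketch below assumes that convention, and `pair_metrics` is a hypothetical helper. It recomputes the margin from the first logged step as a consistency check. The log only reports batch means, so any per-example rewards fed to `pair_metrics` are made up.

```python
from typing import Dict, Sequence


def pair_metrics(chosen_rewards: Sequence[float],
                 rejected_rewards: Sequence[float]) -> Dict[str, float]:
    """Batch means in the style of DPO trainer logs (reward = beta * log-ratio)."""
    n = len(chosen_rewards)
    margins = [c - r for c, r in zip(chosen_rewards, rejected_rewards)]
    return {
        "rewards/chosen": sum(chosen_rewards) / n,
        "rewards/rejected": sum(rejected_rewards) / n,
        "rewards/margins": sum(margins) / n,
        "rewards/accuracies": sum(m > 0 for m in margins) / n,
    }


# Consistency check on the first logged step: margins == chosen - rejected.
step = {'rewards/chosen': -1.9593877792358398,
        'rewards/rejected': -3.4484448432922363,
        'rewards/margins': 1.4890570640563965}
gap = step['rewards/chosen'] - step['rewards/rejected']
print(abs(gap - step['rewards/margins']) < 1e-5)  # True
```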