This repo allows you to design new Human-Aware Loss Functions (HALOs) for aligning LLMs with offline human feedback at scale (read more in our technical report or our full paper). It was used to create Archangel, the largest-ever suite of human-feedback-aligned LLMs, and has been tested at scales from 1B to 30B.
This repo draws from the excellently written DPO repo and has preserved many design choices from the original. Some of the key changes we introduced are:
To first SFT a model, run a command like

```bash
python train.py loss=sft model=llama7b datasets=[shp,hh,oasst] exp_name=llama7b_sft mode=train ++cache_dir=/data/models
```

which will save a model to `/data/models/llama7b_sft/LATEST/policy.pt`. To then align a model with KTO, run a command like

```bash
python train.py loss=kto model=llama7b datasets=[shp,hh,oasst] exp_name=llama7b_kto mode=train ++cache_dir=/data/models ++model.load_from=llama7b_sft/LATEST/policy.pt
```

which will save a model to `/data/models/llama7b_kto/LATEST/policy.pt`.
Let's say we want to implement a new HALO called Kahneman-Tversky optimization (KTO). This is already implemented in this repo based on the details in our report, but let's pretend that it's not. What should we do?
First, create and activate the conda environment:

```bash
conda env create -f environment.yml
conda activate halos
```

If you can't create a conda environment, or you face some issue during installation, try doing

```bash
conda create -n halos3 python=3.10.12
pip3 install numpy==1.24.3 ninja==1.11.1.1 packaging==23.1
conda install pytorch==2.1.1 pytorch-cuda=12.1 -c pytorch -c nvidia
pip3 install flash-attn==2.3.3
pip3 install transformers==4.35.2 datasets hydra-core==1.3.2 wandb==0.15.3 openai==1.6.1 accelerate==0.21.0 tensor-parallel==1.2.4
```
Determine whether you need a new dataset. If you have a dataset called `foo`, add a function called `get_foo` to `dataloader.py` that will return a `Dataset` instance. This function should have the following signature, where the prefixes and suffixes determine how the dataset is formatted (see `config.yaml`) and `split` should be either `train` or `test`:

```python
def get_foo(split: str, human_prefix: str, human_suffix: str, assistant_prefix: str, assistant_suffix: str) -> Dataset:
```
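For intuition, here is a hypothetical sketch of how the prefix/suffix arguments control formatting. The `format_turns` helper and the chat-template strings below are made up for illustration; the actual `Dataset` construction happens in `dataloader.py`:

```python
# Hypothetical helper, for illustration only: shows how the human/assistant
# prefixes and suffixes wrap each turn of a dialogue before training.
def format_turns(turns, human_prefix, human_suffix, assistant_prefix, assistant_suffix):
    pieces = []
    for role, text in turns:
        if role == "human":
            pieces.append(human_prefix + text + human_suffix)
        else:
            pieces.append(assistant_prefix + text + assistant_suffix)
    return "".join(pieces)

print(format_turns(
    [("human", "What is 2+2?"), ("assistant", "4.")],
    human_prefix="\n<|user|>\n", human_suffix="",
    assistant_prefix="\n<|assistant|>\n", assistant_suffix="",
))
# → "\n<|user|>\nWhat is 2+2?\n<|assistant|>\n4."
```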
Determine whether you need a new dataloader. KTO doesn't use preference pairs, just knowledge of whether outputs are desirable or undesirable, so we use `dataloader.UnpairedPreferenceDataLoader`. Note, however, that this dataloader assumes you're working with datasets that originally contain preference pairs, like Anthropic HH or SHP. If you wanted a custom dataloader, you would implement it in the same Python file by extending the base `DataLoader` class.
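Conceptually, such a dataloader flattens each preference pair into two unpaired examples tagged as desirable or undesirable. A minimal sketch (the dictionary keys below are made up for illustration and are not the repo's actual batch format):

```python
# Illustrative only: turning preference pairs into unpaired examples with a
# desirability label, which is the kind of data KTO-style training consumes.
pairs = [("Prompt A", "good answer", "bad answer"),
         ("Prompt B", "helpful reply", "unhelpful reply")]

unpaired = []
for prompt, chosen, rejected in pairs:
    unpaired.append({"prompt": prompt, "output": chosen, "desirable": True})
    unpaired.append({"prompt": prompt, "output": rejected, "desirable": False})

print(len(unpaired))  # each pair yields two unpaired examples → 4
```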
Write a trainer in `trainers.py`. This should subclass either `UnpairedPreferenceTrainer` or `PairedPreferenceTrainer`, depending on whether the loss uses pairs of preferences or not. If you need highly custom behavior that fits neither, you can subclass `BasicTrainer` directly.
We can implement a simple version of KTO as follows (note that this is different from the proper version of KTO in `KTOTrainer`, which does not assume the existence of both chosen and rejected examples in each batch). To make `SimpleKTOTrainer`, we just subclass `trainers.UnpairedPreferenceTrainer` as `trainers.SimpleKTOTrainer` and overwrite the loss function definition. KTO has one hyperparameter, beta, which we can access via `self.config.loss.beta`:
```python
class SimpleKTOTrainer(UnpairedPreferenceTrainer):
    """A simple version of KTO meant to introduce you to the HALOs repo."""
    def loss(self,
             policy_chosen_logps: torch.FloatTensor,
             policy_rejected_logps: torch.FloatTensor,
             reference_chosen_logps: torch.FloatTensor,
             reference_rejected_logps: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """Compute the Kahneman-Tversky loss for a batch of policy and reference model log probabilities.

        For each batch of n/2 chosen examples and n/2 rejected examples (belonging to n different inputs),
        calculate the loss as follows.

        If generation y ~ p_chosen, where x' are the examples with rejected generations, we have the 'chosen' loss:
            L(x, y) := 1 - sigmoid(beta * (log p_policy(y|x) - log p_reference(y|x) - KL(p_policy(y_rejected|x') || p_reference(y_rejected|x'))))
        If generation y ~ p_rejected, where x' are the examples with chosen generations, we have the 'rejected' loss:
            L(x, y) := 1 - sigmoid(beta * (KL(p_policy(y_chosen|x') || p_reference(y_chosen|x')) - [log p_policy(y|x) - log p_reference(y|x)]))
        """
        # Estimate the KL terms from the batch-level mean log ratio, clamped at zero.
        chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
        rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0)

        chosen_logratios = policy_chosen_logps - reference_chosen_logps
        rejected_logratios = policy_rejected_logps - reference_rejected_logps

        # Concatenate the 'chosen' and 'rejected' losses from the docstring into one batch of losses.
        losses = torch.cat((1 - F.sigmoid(self.config.loss.beta * (chosen_logratios - rejected_KL)),
                            1 - F.sigmoid(self.config.loss.beta * (chosen_KL - rejected_logratios))), 0)

        chosen_rewards = self.config.loss.beta * chosen_logratios.detach()
        rejected_rewards = self.config.loss.beta * rejected_logratios.detach()

        return losses, chosen_rewards, rejected_rewards
```
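As a quick sanity check of the formulas in the docstring, here is a plain-Python version of the same computation (illustrative only; the real trainer operates on `torch` tensors, and `simple_kto_losses` is a made-up name):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def simple_kto_losses(pc, rc, pr, rr, beta=0.1):
    """pc/rc: policy/reference log-probs of chosen outputs; pr/rr: same for rejected.
    Plain-Python sketch of the SimpleKTOTrainer loss, for intuition only."""
    chosen_kl = max(sum(p - r for p, r in zip(pc, rc)) / len(pc), 0.0)    # estimated KL, chosen
    rejected_kl = max(sum(p - r for p, r in zip(pr, rr)) / len(pr), 0.0)  # estimated KL, rejected
    chosen = [1 - sigmoid(beta * ((p - r) - rejected_kl)) for p, r in zip(pc, rc)]
    rejected = [1 - sigmoid(beta * (chosen_kl - (p - r))) for p, r in zip(pr, rr)]
    return chosen + rejected

# If the policy matches the reference exactly, every loss is 1 - sigmoid(0) = 0.5:
print(simple_kto_losses([-10.0, -12.0], [-10.0, -12.0], [-11.0, -13.0], [-11.0, -13.0]))
# → [0.5, 0.5, 0.5, 0.5]
```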
Add a file to the `config/loss` folder specifying the details of the loss:

```yaml
name: kto-simple
beta: 0.1 # the temperature parameter for simple KTO; lower values mean we care less about the reference model
trainer: SimpleKTOTrainer # implemented in trainers.py
dataloader: UnpairedPreferenceDataLoader # already exists in dataloaders.py
use_reference_model: true # true because the loss definition includes a reference model
```
Now we can start training a model! Let's train a Llama-7B model on the SHP, Anthropic HH, and Open Assistant datasets. Since the corresponding entry for Llama-7B is `config/model/llama7b.yaml`, we run a command with Hydra:

```bash
python train.py loss=kto-simple model=llama7b datasets=[shp,hh,oasst] exp_name=kto-simple_llama7b mode=train ++cache_dir=/data/models
```

which will align a Llama-7B model from scratch. If we want to align a model that we've already finetuned with the HALOs repo, we can add something like `++model.load_from=/data/models/sft_llama7b/LATEST/policy.pt` to the end of the command.

That's it! Your model will be saved to `/data/models/kto-simple_llama7b/LATEST/policy.pt`.
Let's sample some generations from our newly trained model. The sampling configs are in either `config/config.yaml` or under `models/`. We can sample 512 generations from our newly trained model in batches of 32 with the following command, which will create a JSON file under `samples/{config.exp_name}.json`:

```bash
python eval.py --config-path=/data/models/kto-simple_llama7b/config.yaml ++mode=sample ++n_samples=512 ++model.eval_batch_size=32 ++samples_dir=samples/
```
After setting `OPENAI_API_KEY`, we can evaluate our aligned model with GPT-4 with the following command, which compares the aligned model's generations to the human-chosen response in the data:

```bash
python compare.py -f samples/kto-simple_llama7b.json -mc 512 -bk chosen -ck policy -r result.jsonl
```
Do you support multi-node training?
No, currently the repo only supports single-node training. Multi-node training will be added at some point in the future. Every model in the Archangel suite was trained with 8 x A100 GPUs on a single node.
How do I save intermediate checkpoints?
Set `intermediate_checkpoints` to true in `config/config.yaml` or on the command line with `++config.intermediate_checkpoints=true`. Every `config.eval_every` steps, a checkpoint will be saved in the experiment directory (`$cache_dir/$exp_name`).
Where do I find all the Archangel models?
They are all on the Huggingface Hub:
Model | PPO | DPO | KTO | SFT | SLIC | SFT+PPO | SFT+DPO | SFT+KTO | CSFT | SFT+CSFT |
---|---|---|---|---|---|---|---|---|---|---|
pythia1-4b | weights | weights | weights | weights | weights | weights | weights | weights | weights | weights |
pythia2-8b | weights | weights | weights | weights | weights | weights | weights | weights | weights | weights |
pythia6-9b | weights | weights | weights | weights | weights | weights | weights | weights | weights | weights |
pythia12-0b | weights | weights | weights | weights | weights | weights | weights | weights | weights | weights |
llama7b | weights | weights | weights | weights | weights | weights | weights | weights | weights | weights |
llama13b | weights | weights | weights | weights | weights | weights | weights | weights | weights | weights |
llama30b | weights | weights | weights | weights | weights | weights | weights | weights | weights | weights |
If you find this repo or the technical paper useful in your research, please feel free to cite our work:

```bibtex
@techreport{ethayarajh2023halos,
  author = {Ethayarajh, Kawin and Xu, Winnie and Jurafsky, Dan and Kiela, Douwe},
  title = {Human-Aware Loss Functions (HALOs)},
  institution = {Contextual AI},
  note = {https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf},
  year = {2023},
}
```