https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L1100
```python
pi_logratios = policy_chosen_logps - policy_rejected_logps
if self.reference_free:
    ref_logratios = torch.tensor([0], dtype=pi_logratios.dtype, device=pi_logratios.device)
else:
    ref_logratios = reference_chosen_logps - reference_rejected_logps

pi_logratios = pi_logratios.to(self.accelerator.device)
ref_logratios = ref_logratios.to(self.accelerator.device)
logits = pi_logratios - ref_logratios

if self.f_divergence_type == FDivergenceType.JS_DIVERGENCE.value:
    # The js-divergence formula: log(2 * u / (1 + u))
    # The divergence difference between the chosen and rejected sample is:
    #     log(2 * u[w] / (1 + u[w])) - log(2 * u[l] / (1 + u[l]))
    #   = log(u[w]) - log(u[l]) - (log(1 + u[w]) - log(1 + u[l]))
    # where u[w] and u[l] are the policy/reference probability ratios
    # for the chosen and rejected samples, respectively.
    logits -= F.softplus(chosen_logratios) - F.softplus(rejected_logratios)

# The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
# We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty
# about the labels and calculates a conservative DPO loss.
if self.loss_type == "sigmoid":
    losses = (
        -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
        - F.logsigmoid(-self.beta * logits) * self.label_smoothing
    )
```
- `pi_logratios`: $\log\left(\frac{\pi_\theta (y_w|x)}{\pi_\theta (y_l|x)}\right)$
- `ref_logratios`: $\log\left(\frac{\pi_{ref} (y_w|x)}{\pi_{ref} (y_l|x)}\right)$
- `logits` (`= pi_logratios - ref_logratios`): $\log\left(\frac{\pi_\theta (y_w|x)}{\pi_\theta (y_l|x)}\right) - \log\left(\frac{\pi_{ref} (y_w|x)}{\pi_{ref} (y_l|x)}\right)$; rearranging the terms inside the logs gives the same expression as the Loss derived above.
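A minimal, self-contained sketch (with made-up log-probs, not the TRL API) of the `sigmoid` branch above; with `label_smoothing = 0` the loss reduces to $-\log\sigma(\beta \cdot \text{logits})$:

```python
import torch
import torch.nn.functional as F

# Dummy per-sequence log-probs for a batch of two preference pairs
# (hypothetical numbers, not taken from a real model).
policy_chosen_logps = torch.tensor([-12.3, -45.1])
policy_rejected_logps = torch.tensor([-14.0, -44.2])
reference_chosen_logps = torch.tensor([-13.0, -45.0])
reference_rejected_logps = torch.tensor([-13.5, -44.5])

beta = 0.1
label_smoothing = 0.0

# logits = log[pi(y_w|x)/pi(y_l|x)] - log[pi_ref(y_w|x)/pi_ref(y_l|x)]
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps
logits = pi_logratios - ref_logratios

# "sigmoid" loss; with label_smoothing = 0 this is just -log sigmoid(beta * logits)
losses = (
    -F.logsigmoid(beta * logits) * (1 - label_smoothing)
    - F.logsigmoid(-beta * logits) * label_smoothing
)
print(losses)  # one loss value per preference pair
```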
## Paper

### TL;DR

DPO replaces the explicit reward model and RL stage of RLHF with a simple classification-style loss on preference pairs, by reparameterizing the reward in terms of the policy itself.

### Details

#### Preliminaries
**SFT.** Fine-tune on a small amount of high-quality data to obtain $\pi^{SFT}$.
**Reward modeling (Bradley-Terry model).**
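In the paper's notation, the Bradley-Terry model assumes preferences come from a latent reward $r^*(x, y)$:

$$
p^*(y_1 \succ y_2 \mid x) = \frac{\exp\left(r^*(x, y_1)\right)}{\exp\left(r^*(x, y_1)\right) + \exp\left(r^*(x, y_2)\right)}
$$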
Casting this as a binary classification problem:
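The Bradley-Terry probability of the observed preference is a sigmoid of the reward difference, so the reward model $r_\phi$ is fit by minimizing the negative log-likelihood over the preference dataset $\mathcal{D}$:

$$
p^*(y_w \succ y_l \mid x) = \sigma\!\left(r^*(x, y_w) - r^*(x, y_l)\right),
\qquad
\mathcal{L}_R(r_\phi, \mathcal{D}) = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\!\left[\log \sigma\!\left(r_\phi(x, y_w) - r_\phi(x, y_l)\right)\right]
$$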
#### DPO
Rewriting the objective above:
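The KL-constrained reward maximization objective used in the RLHF fine-tuning stage,

$$
\max_{\pi_\theta}\; \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_\theta(y \mid x)}\!\left[r(x, y)\right] - \beta\, \mathbb{D}_{\mathrm{KL}}\!\left[\pi_\theta(y \mid x)\,\|\,\pi_{ref}(y \mid x)\right],
$$

has the closed-form optimum

$$
\pi_r(y \mid x) = \frac{1}{Z(x)}\, \pi_{ref}(y \mid x)\, \exp\!\left(\frac{1}{\beta} r(x, y)\right).
$$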
Is the partition function's role just to normalize this into a proper probability distribution?
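Yes: $Z(x) = \sum_y \pi_{ref}(y \mid x) \exp\!\left(\frac{1}{\beta} r(x, y)\right)$ is exactly the normalizer that makes $\pi_r(\cdot \mid x)$ sum to 1, and it depends only on $x$ and $\pi_{ref}$. Taking the log of the optimal policy and rearranging expresses the reward in terms of the policy:

$$
r(x, y) = \beta \log \frac{\pi_r(y \mid x)}{\pi_{ref}(y \mid x)} + \beta \log Z(x)
$$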
For the optimal policy, the Bradley-Terry model yields the following preference probability:
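Substituting this reward into the Bradley-Terry model, the intractable $\beta \log Z(x)$ terms cancel because they are identical for $y_w$ and $y_l$:

$$
p^*(y_w \succ y_l \mid x) = \sigma\!\left(\beta \log \frac{\pi^*(y_w \mid x)}{\pi_{ref}(y_w \mid x)} - \beta \log \frac{\pi^*(y_l \mid x)}{\pi_{ref}(y_l \mid x)}\right)
$$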
Since we have human preference data, we can express this as an MLE objective from the policy's perspective:
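Maximizing the likelihood of the observed preferences under this parameterization gives the DPO objective, which is exactly what the `sigmoid` branch of the TRL code above computes (with $\beta$ = `self.beta`):

$$
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta; \pi_{ref}) = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\!\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{ref}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{ref}(y_l \mid x)}\right)\right]
$$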
#### What does the DPO update do?
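Writing the implicit reward as $\hat{r}_\theta(x, y) = \beta \log \frac{\pi_\theta(y \mid x)}{\pi_{ref}(y \mid x)}$, the gradient of the DPO loss is:

$$
\nabla_\theta \mathcal{L}_{\mathrm{DPO}}(\pi_\theta; \pi_{ref}) = -\beta\, \mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\!\left[\sigma\!\left(\hat{r}_\theta(x, y_l) - \hat{r}_\theta(x, y_w)\right)\left[\nabla_\theta \log \pi_\theta(y_w \mid x) - \nabla_\theta \log \pi_\theta(y_l \mid x)\right]\right]
$$

So each update increases the likelihood of $y_w$ and decreases the likelihood of $y_l$, weighted by how strongly the implicit reward model currently mis-ranks the pair.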