
[169] Direct Preference Optimization: Your Language Model is Secretly a Reward Model #188

Open long8v opened 3 weeks ago

long8v commented 3 weeks ago

paper

TL;DR

Details

Preliminaries

Framing this as a binary classification problem over preference pairs:

(screenshot: reward modeling under the Bradley-Terry preference model)
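Reconstructed here in LaTeX for readability (notation: $r^*$ is the latent "true" reward, $\sigma$ the logistic function, $y_w$/$y_l$ the chosen/rejected completions). The Bradley-Terry model assumes

$$
p^*(y_1 \succ y_2 \mid x) = \frac{\exp\big(r^*(x, y_1)\big)}{\exp\big(r^*(x, y_1)\big) + \exp\big(r^*(x, y_2)\big)} = \sigma\big(r^*(x, y_1) - r^*(x, y_2)\big)
$$

so fitting a reward model $r_\phi$ reduces to binary classification with the negative log-likelihood

$$
\mathcal{L}_R(r_\phi, \mathcal{D}) = -\,\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\Big[\log \sigma\big(r_\phi(x, y_w) - r_\phi(x, y_l)\big)\Big]
$$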

DPO

Rewriting the objective above:

(screenshots: the RLHF objective rewritten, and the closed-form optimal policy with partition function $Z(x)$)
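For readability, the derivation in those screenshots (reconstructed from the paper): the KL-constrained reward maximization objective

$$
\max_{\pi_\theta}\; \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_\theta(y \mid x)}\big[r(x, y)\big] \;-\; \beta\, \mathbb{D}_{\mathrm{KL}}\big[\pi_\theta(y \mid x) \,\|\, \pi_{\mathrm{ref}}(y \mid x)\big]
$$

has the closed-form solution

$$
\pi_r(y \mid x) = \frac{1}{Z(x)}\, \pi_{\mathrm{ref}}(y \mid x)\, \exp\!\left(\frac{1}{\beta} r(x, y)\right),
\qquad
Z(x) = \sum_{y} \pi_{\mathrm{ref}}(y \mid x)\, \exp\!\left(\frac{1}{\beta} r(x, y)\right)
$$

Taking the log and rearranging expresses the reward through the policy:

$$
r(x, y) = \beta \log \frac{\pi_r(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)} + \beta \log Z(x)
$$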

Is the partition function's role just to make this a proper probability distribution? (Yes: dividing by $Z(x)$ normalizes $\pi_{\mathrm{ref}}(y \mid x)\exp(r(x, y)/\beta)$ so it sums to 1 over $y$.)

For the optimal policy, the Bradley-Terry model yields the following preference probability:

(screenshot: the preference probability expressed through the optimal policy)
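Reconstructed: plugging the reward expression above into the Bradley-Terry model, the intractable $\beta \log Z(x)$ terms cancel, leaving

$$
p^*(y_1 \succ y_2 \mid x) = \sigma\!\left(\beta \log \frac{\pi^*(y_1 \mid x)}{\pi_{\mathrm{ref}}(y_1 \mid x)} \;-\; \beta \log \frac{\pi^*(y_2 \mid x)}{\pi_{\mathrm{ref}}(y_2 \mid x)}\right)
$$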

Since we have human preference data, expressing this from the policy's perspective as a maximum-likelihood (MLE) objective gives:

(screenshot: the DPO objective)
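Reconstructed, this is the DPO loss:

$$
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta;\, \pi_{\mathrm{ref}}) = -\,\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\left[\log \sigma\!\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} \;-\; \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\right)\right]
$$

Note that this is exactly the reward-model loss from the preliminaries with the implicit reward $\beta \log \frac{\pi_\theta(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)}$ substituted in, which is the sense in which "your language model is secretly a reward model."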

What does the DPO update do?

(screenshot: gradient of the DPO loss)
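Reconstructed from the paper's analysis, writing the implicit reward as $\hat{r}_\theta(x, y) = \beta \log \frac{\pi_\theta(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)}$:

$$
\nabla_\theta \mathcal{L}_{\mathrm{DPO}} = -\beta\, \mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\Big[\sigma\big(\hat{r}_\theta(x, y_l) - \hat{r}_\theta(x, y_w)\big)\,\big[\nabla_\theta \log \pi_\theta(y_w \mid x) - \nabla_\theta \log \pi_\theta(y_l \mid x)\big]\Big]
$$

So each update increases the likelihood of the chosen response and decreases that of the rejected one, weighted by how badly the implicit reward currently mis-ranks the pair.
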
long8v commented 2 weeks ago

https://github.com/huggingface/trl/blob/main/trl/trainer/dpo_trainer.py#L1100

            pi_logratios = policy_chosen_logps - policy_rejected_logps
            if self.reference_free:
                ref_logratios = torch.tensor([0], dtype=pi_logratios.dtype, device=pi_logratios.device)
            else:
                ref_logratios = reference_chosen_logps - reference_rejected_logps

            pi_logratios = pi_logratios.to(self.accelerator.device)
            ref_logratios = ref_logratios.to(self.accelerator.device)
            logits = pi_logratios - ref_logratios

            if self.f_divergence_type == FDivergenceType.JS_DIVERGENCE.value:
                # The js-divergence formula: log(2 * u / (1 + u))
                # The divergence difference between the chosen and rejected sample is:
                #     log(2 * u[w] / (1 + u[w])) - log(2 * u[l] / (1 + u[l]))
                #       = log(u[w]) - log(u[l]) - (log(1 + u[w]) - log(1 + u[l]))
                # where u[w] and u[l] are the policy/reference probability ratios
                # for the chosen and rejected samples, respectively.
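                # Note: chosen_logratios (= log u[w]) and rejected_logratios (= log u[l]) are
                # computed earlier in the same function and are not shown in this excerpt; since
                # softplus(log u) = log(1 + u), the line below subtracts exactly the
                # log(1 + u[w]) - log(1 + u[l]) term from the formula above.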
                logits -= F.softplus(chosen_logratios) - F.softplus(rejected_logratios)

        # The beta is a temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5.
        # We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
        # calculates a conservative DPO loss.
        if self.loss_type == "sigmoid":
            losses = (
                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
            )
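
For reference, a minimal self-contained sketch of the same `sigmoid` DPO loss (my own re-implementation, not the TRL code; the `reference_free` and f-divergence branches are dropped, and the tensor names mirror the snippet above):

```python
import torch
import torch.nn.functional as F


def dpo_sigmoid_loss(
    policy_chosen_logps: torch.Tensor,       # log pi_theta(y_w | x), summed over tokens, shape (B,)
    policy_rejected_logps: torch.Tensor,     # log pi_theta(y_l | x), shape (B,)
    reference_chosen_logps: torch.Tensor,    # log pi_ref(y_w | x), shape (B,)
    reference_rejected_logps: torch.Tensor,  # log pi_ref(y_l | x), shape (B,)
    beta: float = 0.1,
    label_smoothing: float = 0.0,
) -> torch.Tensor:
    # The margin inside the sigmoid:
    #   log(pi(y_w)/pi_ref(y_w)) - log(pi(y_l)/pi_ref(y_l))
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps
    logits = pi_logratios - ref_logratios

    # Plain DPO when label_smoothing == 0; with label_smoothing > 0 this is the
    # conservative variant that assumes labels are flipped with that probability.
    losses = (
        -F.logsigmoid(beta * logits) * (1 - label_smoothing)
        - F.logsigmoid(-beta * logits) * label_smoothing
    )
    return losses.mean()


# toy check with random per-sequence log-probabilities
if __name__ == "__main__":
    b = 4
    print(dpo_sigmoid_loss(torch.randn(b), torch.randn(b), torch.randn(b), torch.randn(b)))
```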