ml-explore / mlx-examples

Examples in the MLX framework
MIT License

Reinforcement Learning from Human Feedback (RLHF) examples: Direct Preference Optimization (DPO) #513

Open danilopeixoto opened 6 months ago

danilopeixoto commented 6 months ago

Introduce a Reinforcement Learning from Human Feedback (RLHF) example, such as the Direct Preference Optimization (DPO) method.

Paper

Direct Preference Optimization: Your Language Model is Secretly a Reward Model

Notes

Direct Preference Optimization (DPO): A Simplified Explanation by João Lages

Implementation examples

Possible MLX implementation

Policy and reference log probabilities:

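import mlx.core as mx
import mlx.nn as nn

def get_batched_logps(model, inputs, targets):
    # Assumes the batch holds the chosen sequences first and the rejected ones second.
    logits, _ = model(inputs)
    logits = logits.astype(mx.float32)

    # Token id 0 is treated as padding and excluded from the per-sequence sum.
    loss_mask = targets != 0
    per_token_logps = mx.take_along_axis(nn.log_softmax(logits), targets[..., None], axis=2).squeeze(2)

    # Sum token log-probabilities per sequence, then split into (chosen, rejected) halves.
    return tuple((per_token_logps * loss_mask).sum(-1).split(2))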
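
To make the masking and the chosen/rejected split concrete, here is a minimal pure-Python sketch of the same computation on nested lists. It is not MLX code: the helper names (`log_softmax`, `sequence_logp`, `batched_logps`) and the `pad_id` parameter are hypothetical, and it assumes, like the snippet above, that padding uses token id 0 and that the first half of the batch is the chosen half.

```python
import math

def log_softmax(row):
    # Numerically stable log-softmax over one vocabulary-sized row of logits.
    m = max(row)
    lse = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - lse for x in row]

def sequence_logp(logits, targets, pad_id=0):
    # Sum the log-probability of each target token, skipping padding positions.
    total = 0.0
    for row, target in zip(logits, targets):
        if target == pad_id:
            continue
        total += log_softmax(row)[target]
    return total

def batched_logps(batch_logits, batch_targets, pad_id=0):
    # Assumption: first half of the batch is "chosen", second half is "rejected".
    logps = [sequence_logp(l, t, pad_id) for l, t in zip(batch_logits, batch_targets)]
    half = len(logps) // 2
    return logps[:half], logps[half:]

# Two sequences of length 3 over a vocabulary of 4, with uniform logits.
uniform = [[0.0] * 4 for _ in range(3)]
chosen, rejected = batched_logps([uniform, uniform], [[1, 2, 0], [3, 0, 0]])
print(chosen, rejected)
```

With uniform logits every real token contributes log(1/4), so the chosen sequence (two real tokens) sums to twice the value of the rejected one (one real token).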

Loss:

def dpo_loss(model, beta, label_smoothing, reference_chosen_logps, reference_rejected_logps, inputs, targets):
    chosen_logps, rejected_logps = get_batched_logps(model, inputs, targets)

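    # Log-ratio of chosen over rejected, under the policy and under the reference model.
    pi_logratios = chosen_logps - rejected_logps
    reference_logratios = reference_chosen_logps - reference_rejected_logps

    # Conservative DPO loss; reduces to the standard DPO loss when label_smoothing is 0.
    logits = pi_logratios - reference_logratios
    losses = -nn.log_sigmoid(beta * logits) * (1.0 - label_smoothing) - nn.log_sigmoid(-beta * logits) * label_smoothing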

    chosen_rewards = beta * (chosen_logps - reference_chosen_logps)
    rejected_rewards = beta * (rejected_logps - reference_rejected_logps)
    reward_accuracies = (chosen_rewards > rejected_rewards).astype(mx.float32)
    reward_margins = chosen_rewards - rejected_rewards

    ntoks = (inputs != 0).sum()

    return (
        losses.mean(),
        chosen_rewards.mean(),
        rejected_rewards.mean(),
        reward_accuracies.mean(),
        reward_margins.mean(),
        ntoks,
    )
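
Written out for a single preference pair, the quantities computed in dpo_loss correspond to the expressions below. The symbols are notation introduced here for illustration: y_w is the chosen response, y_l the rejected one, \varepsilon stands for label_smoothing, and each \log\pi(y \mid x) is the summed log-probability of the non-padding tokens.

```latex
\begin{aligned}
h &= \bigl[\log \pi_\theta(y_w \mid x) - \log \pi_\theta(y_l \mid x)\bigr]
   - \bigl[\log \pi_{\mathrm{ref}}(y_w \mid x) - \log \pi_{\mathrm{ref}}(y_l \mid x)\bigr] \\
\mathcal{L} &= -(1 - \varepsilon)\,\log \sigma(\beta h) - \varepsilon\,\log \sigma(-\beta h) \\
r_w &= \beta \bigl[\log \pi_\theta(y_w \mid x) - \log \pi_{\mathrm{ref}}(y_w \mid x)\bigr], \qquad
r_l = \beta \bigl[\log \pi_\theta(y_l \mid x) - \log \pi_{\mathrm{ref}}(y_l \mid x)\bigr]
\end{aligned}
```

Here h matches `logits` in the code, r_w and r_l match `chosen_rewards` and `rejected_rewards`, and the reward margin r_w - r_l equals \beta h.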

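Beta: The temperature parameter for the DPO loss, typically set in the range of 0.1 to 0.5. The reference model is ignored as beta approaches 0. In the loss above, beta equal to exactly 0 makes every loss term the constant log 2, so nothing is learned.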

Label smoothing: This parameter controls how conservative the DPO loss is. It assumes that preference labels are noisy and are flipped with probability label_smoothing.

Note: setting label_smoothing > 0 gives the Conservative DPO loss.
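
To see how the two parameters act, here is a small scalar sketch of the same loss for one preference pair, in plain Python rather than MLX. The function names and the example log-probabilities are made up for illustration; only the formula is taken from the loss above.

```python
import math

def log_sigmoid(x):
    # Numerically stable log(sigmoid(x)).
    return min(x, 0.0) - math.log1p(math.exp(-abs(x)))

def dpo_pair_loss(chosen_logp, rejected_logp, ref_chosen_logp, ref_rejected_logp,
                  beta=0.1, label_smoothing=0.0):
    # Same formula as the batched loss above, for a single preference pair.
    logits = (chosen_logp - rejected_logp) - (ref_chosen_logp - ref_rejected_logp)
    return (-log_sigmoid(beta * logits) * (1.0 - label_smoothing)
            - log_sigmoid(-beta * logits) * label_smoothing)

# Hypothetical summed log-probabilities: policy (chosen, rejected), then reference.
pair = (-5.0, -9.0, -6.0, -6.5)

for beta in (0.0, 0.1, 0.5):
    print("beta", beta, "loss", dpo_pair_loss(*pair, beta=beta))

for eps in (0.0, 0.1, 0.5):
    print("label_smoothing", eps, "loss", dpo_pair_loss(*pair, beta=0.5, label_smoothing=eps))
```

In this example the policy already prefers the chosen response more strongly than the reference does, so a larger beta lowers the loss, while a larger label_smoothing raises it by penalizing overconfident margins. At beta = 0 the loss is log 2 regardless of the inputs.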

awni commented 6 months ago

@danilopeixoto I've been thinking about having this in MLX LM recently. Any interest in sending a PR?

It might make sense to do it after we have a more manageable config (https://github.com/ml-explore/mlx-examples/pull/503), but that should land soon!

awni commented 6 months ago

To be more concrete, I'm envisioning that you just set the loss in the config, e.g. cross_entropy or dpo.

ivanfioravanti commented 5 months ago

This would be an awesome addition to mlx_examples! 🔥

N8python commented 5 months ago

I'm very, very excited for this! I don't have the technical expertise to implement DPO directly, but I would love to help in other ways (config, code cleanup) if necessary!

lin72h commented 5 months ago

That would make MLX really useful for production, not just a research tool!

kishoretvk commented 5 months ago

+500 waiting for this

developerlin commented 3 months ago

Waiting for this. When will DPO training be supported?