eric-mitchell / direct-preference-optimization

Reference implementation for DPO (Direct Preference Optimization)
Apache License 2.0

Question about _get_batch_logps of trainers.py #57

Closed wulaoshi closed 10 months ago

wulaoshi commented 11 months ago
import torch

def _get_batch_logps(logits: torch.FloatTensor, labels: torch.LongTensor, average_log_prob: bool = False) -> torch.FloatTensor:
    assert logits.shape[:-1] == labels.shape

    # Shift: logits at position t predict the label at position t+1.
    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    loss_mask = (labels != -100)

    # dummy token; we'll ignore the losses on these tokens later
    labels[labels == -100] = 0

    # Log-probability the model assigns to each target token.
    per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)

    if average_log_prob:
        return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
    else:
        return (per_token_logps * loss_mask).sum(-1)      # Why return the sum of the log probabilities of the (non-masked) tokens?

First of all, I'd like to express my gratitude for the author's amazing work. While going through the DPO code, I came across a point of confusion: I understand that π(y|x) represents the probability the model assigns to a generated response, but the code implementation applies an additional sum(-1) operation, and I don't see a summation in the formula, as shown in the image below:

[image: the DPO loss equation, written in terms of log π(y_w|x) and log π(y_l|x)]

Could you please help me understand the logic behind this implementation? Thank you!

eric-mitchell commented 10 months ago

Thanks for the kind words. In the equation you posted, the probability is the probability of the entire response sequence, conditioned on the input. The per-token logprobs need to be summed (along the sequence length dimension) to get the total log probability of the chosen/rejected sequences. The sum is implicit in the equation (i.e., the total log probabilities are the sum of the per-token log probabilities).
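
A minimal sketch, illustrative rather than from the repo, that checks this numerically. By the chain rule, log π(y|x) = Σ_t log π(y_t | x, y_<t), so the log of the product of per-token probabilities equals the sum of the per-token log-probabilities:

import torch

# Hypothetical shapes and values, just for the check.
torch.manual_seed(0)
vocab, seq_len = 5, 4
logits = torch.randn(1, seq_len, vocab)          # fake next-token logits
labels = torch.randint(0, vocab, (1, seq_len))   # fake target tokens

log_probs = logits.log_softmax(-1)               # per-token distributions
per_token = torch.gather(log_probs, 2, labels.unsqueeze(2)).squeeze(2)

seq_log_prob = per_token.sum(-1)                 # sum of per-token log-probs
prod_of_probs = per_token.exp().prod(-1)         # pi(y|x) as a product of probs

# log of the product == sum of the logs
assert torch.allclose(seq_log_prob, prod_of_probs.log(), atol=1e-5)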

Feel free to re-open if this didn't answer your question!

longbowzhang commented 10 months ago

Hi @eric-mitchell, should the variable lengths of the chosen/rejected sequences be taken into account in the loss (see the sketch below for what the current code already offers)? Any comments on this are highly appreciated.
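
For context, the snippet above does have a hook for this: with average_log_prob=True, the summed log-probs are divided by the number of non-masked tokens, i.e. length-normalized. A tiny sketch of the difference, with made-up numbers:

import torch

# Made-up per-token log-probs: row 0 has 4 real tokens, row 1 has 2 real
# tokens and 2 masked (padding) positions.
per_token_logps = torch.tensor([[-1.0, -1.0, -1.0, -1.0],
                                [-1.0, -1.0,  0.0,  0.0]])
loss_mask = torch.tensor([[1., 1., 1., 1.],
                          [1., 1., 0., 0.]])

summed = (per_token_logps * loss_mask).sum(-1)    # tensor([-4., -2.]): longer sequence scores lower
averaged = summed / loss_mask.sum(-1)             # tensor([-1., -1.]): length-normalized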

Gryff1ndor commented 6 months ago

Hi @eric-mitchell, in your formula (the image below), it seems that log[π(y|x)] is computed by applying logits.softmax(-1), then .log(), then .sum(-1).

[image: the formula]

But in your code (the image below), log[π(y|x)] is computed by applying .sum(-1) after logits.log_softmax(-1).

[image: the code from _get_batch_logps]

These two ways of computing log[π(y|x)] look different. Could you please tell me whether they conflict with each other?
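
A quick check (a sketch, not from the repo) suggests they agree numerically, since log(softmax(z)) and log_softmax(z) are the same function; the .sum(-1) is the separate product-to-sum step discussed above:

import torch

torch.manual_seed(0)
logits = torch.randn(2, 3, 7)    # arbitrary (batch, seq, vocab) logits

a = logits.softmax(-1).log()     # the formula as written: softmax, then log
b = logits.log_softmax(-1)       # the code as written: fused log-softmax

# Same values; log_softmax is just the numerically stabler computation.
# The .sum(-1) in the code then converts per-token log-probs into the
# sequence log-prob (log of a product = sum of logs).
assert torch.allclose(a, b, atol=1e-6)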