haozheji / exact-optimization

ICML 2024 - Official Repository for EXO: Towards Efficient Exact Optimization of Language Model Alignment
https://arxiv.org/abs/2402.00856
MIT License

The DPO loss implementation seems incomplete #2

Closed: peterjc123 closed this issue 5 months ago

peterjc123 commented 5 months ago

It seems that the energy_labels are not set for the samples in DPO (https://github.com/haozheji/exact-optimization/blob/7d56003e96f389284f707a105c2a9ffc162cd07b/src/align_stage/data.py#L63 and https://github.com/haozheji/exact-optimization/blob/7d56003e96f389284f707a105c2a9ffc162cd07b/src/align_stage/data.py#L94), but they are used in the loss implementation here: https://github.com/haozheji/exact-optimization/blob/7d56003e96f389284f707a105c2a9ffc162cd07b/src/align_stage/loss.py#L53

peterjc123 commented 5 months ago

For the ExactDataset, the energy_labels are set https://github.com/haozheji/exact-optimization/blob/7d56003e96f389284f707a105c2a9ffc162cd07b/src/align_stage/data.py#L168 and according to the code below they are set to 1 for preferred samples and 0 for the rejected samples https://github.com/haozheji/exact-optimization/blob/7d56003e96f389284f707a105c2a9ffc162cd07b/src/utils/data/pref_to_rw.py#L22

But then the logic in computing the DPO loss becomes weird because the loss for rejected samples is zero since their energy_labels_group is zero. https://github.com/haozheji/exact-optimization/blob/7d56003e96f389284f707a105c2a9ffc162cd07b/src/align_stage/loss.py#L53

Did I misread the code?

haozheji commented 5 months ago

The DPODataset is not used in the training code and should be deprecated.

peterjc123 commented 5 months ago

Thanks for the prompt response. Yeah, it seems that it is not used. But the DPO+pref loss seems weird if the ExactDataset is used, could you please elaborate on that?

haozheji commented 5 months ago

This is an equivalent implementation of the original DPO loss to adapt with the soft label version. The loss for chosen samples is exactly the original DPO loss.

peterjc123 commented 5 months ago

So the DPO loss here is -(F.log_softmax(beta * (policy_chosen_logps - ref_chosen_logps)) + F.log_softmax(policy_rejected_logps - ref_rejected_logps)) / 2 for preference data, right?

haozheji commented 5 months ago

The original DPO loss is -F.logsigmoid(beta * (policy_chosen_logps - ref_chosen_logps) - beta * (policy_rejected_logps - ref_rejected_logps)).

This implementation uses hard labels, labels = [1, 0], where 1 is for the chosen sample and 0 is for the rejected sample. The loss becomes -(labels * F.log_softmax(beta * (policy_logps - ref_logps))).sum(0), where policy_logps = [policy_chosen_logps, policy_rejected_logps] and ref_logps = [ref_chosen_logps, ref_rejected_logps]. Since the label for the rejected sample is 0, the loss is actually -F.log_softmax(beta * (policy_logps - ref_logps))[0], which is exactly the original DPO loss. The rejected sample still enters the loss through the softmax normalizer.
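The equivalence claimed above can be checked numerically. The following is a minimal sketch in plain Python (standard library only, not the repository's code). The helper names log_softmax, logsigmoid, original_dpo_loss, and softmax_label_loss, as well as the sample log-probabilities and beta, are made up for illustration. It also shows how the same expression accepts soft labels, which the sigmoid form does not express directly.

```python
import math

def log_softmax(xs):
    """Numerically stable log-softmax over a list of floats."""
    m = max(xs)
    log_z = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - log_z for x in xs]

def logsigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    return min(x, 0.0) - math.log1p(math.exp(-abs(x)))

def original_dpo_loss(policy_chosen, policy_rejected, ref_chosen, ref_rejected, beta):
    """Sigmoid form: -logsigmoid(beta * chosen log-ratio - beta * rejected log-ratio)."""
    return -logsigmoid(beta * (policy_chosen - ref_chosen)
                       - beta * (policy_rejected - ref_rejected))

def softmax_label_loss(policy_logps, ref_logps, labels, beta):
    """Softmax form: -(labels * log_softmax(beta * (policy - ref))).sum()."""
    scaled = [beta * (p - r) for p, r in zip(policy_logps, ref_logps)]
    return -sum(l * s for l, s in zip(labels, log_softmax(scaled)))

# Made-up sequence log-probabilities, ordered as [chosen, rejected].
policy = [-12.3, -15.1]
ref = [-13.0, -14.2]
beta = 0.1

hard = softmax_label_loss(policy, ref, [1.0, 0.0], beta)   # hard labels [1, 0]
orig = original_dpo_loss(policy[0], policy[1], ref[0], ref[1], beta)
soft = softmax_label_loss(policy, ref, [0.8, 0.2], beta)   # soft labels, illustrative only

print(f"softmax form, hard labels: {hard:.6f}")
print(f"original sigmoid form:     {orig:.6f}")
print(f"softmax form, soft labels: {soft:.6f}")
```

With hard labels the first two printed values should agree up to floating-point error. The soft-label value differs because the rejected term then carries nonzero weight.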

peterjc123 commented 5 months ago

But according to the code implementation, it is something like below.

logps = [policy_chosen_logps - ref_chosen_logps, policy_rejected_logps - ref_rejected_logps]
loss = - (labels * F.log_softmax(beta * logps)).sum(0)

which eventually becomes

loss = - ([1, 0] * F.log_softmax(beta * [policy_chosen_logps - ref_chosen_logps, policy_rejected_logps - ref_rejected_logps])).sum(0)
= - F.log_softmax(beta * [policy_chosen_logps - ref_chosen_logps, policy_rejected_logps - ref_rejected_logps])[0]

I also debugged over the loss function using some dummy inputs.

import torch
import torch.nn as nn

def dpo_loss(ref_logits, 
            model_logits, 
            attention_mask, 
            y_ids, 
            prompt_lens, 
            energy_labels, 
            N=2, 
            beta=1.0, 
            beta_model=1.0,
            loss_type="dpo"):

    # log-softmax over the vocabulary dimension
    logsm = nn.LogSoftmax(-1)

    # prompt masking is commented out because it is unrelated to this test
    # for i in range(attention_mask.size(0)):
    #     attention_mask[i, :prompt_lens[i][0]] = 0

    # per-token log-probabilities of the target tokens, with masked positions zeroed out
    model_logprobs = torch.gather(logsm(model_logits)[:, :-1, :], 2, y_ids[:, 1:].unsqueeze(2)).squeeze(2) * attention_mask[:, 1:]
    ref_logprobs = torch.gather(logsm(ref_logits)[:, :-1, :], 2, y_ids[:, 1:].unsqueeze(2)).squeeze(2) * attention_mask[:, 1:]

    # sequence-level log-ratio between model and reference, shape (bsz, 1)
    estimated_rewards_prefix = (model_logprobs - ref_logprobs).sum(1, keepdim=True)

    print('logps:')
    print(estimated_rewards_prefix)

    loss = 0.
    count = 0
    for estimated_rewards_prefix_group, energy_labels_group in zip(estimated_rewards_prefix.split(N), 
                                                                    energy_labels.split(N)):

        # prepare label
        if "rw" in loss_type:
            energy_labels_group = (energy_labels_group / beta).softmax(0)

        print('current group logps:')
        print(estimated_rewards_prefix_group)

        # num_contrastive * num_draw
        log_est_rewards_prefix_draw = (beta_model * estimated_rewards_prefix_group).log_softmax(0)

        print('current group logps after softmax:')
        print(log_est_rewards_prefix_draw)

        print('current energy labels group:')
        print(energy_labels_group)

        cur_loss = ( - energy_labels_group * log_est_rewards_prefix_draw ).sum(0).mean()

        print("current loss:")
        print(cur_loss)
        print('-' * 60)

        loss = loss + cur_loss

        count += 1

    return loss / count

if __name__ == '__main__':

    ref_logits = torch.rand(4, 10, 256)
    model_logits = torch.rand(4, 10, 256)
    y_ids = torch.zeros(4, 10, dtype=torch.long)

    attention_mask = torch.ones(4, 10, dtype=torch.long)
    prompt_lens = None

    energy_labels = torch.tensor([1, 0, 1, 0]).view(4, 1)

    dpo_loss(ref_logits, 
                model_logits, 
                attention_mask, 
                y_ids, 
                prompt_lens, 
                energy_labels)

And the output of the code is shown below.

logps:
tensor([[-0.2783],
        [ 0.3665],
        [-0.0128],
        [-0.9451]])
current group logps:
tensor([[-0.2783],
        [ 0.3665]])
current group logps after softmax:
tensor([[-1.0667],
        [-0.4218]])
current energy labels group:
tensor([[1],
        [0]])
current loss:
tensor(1.0667)
------------------------------------------------------------
current group logps:
tensor([[-0.0128],
        [-0.9451]])
current group logps after softmax:
tensor([[-0.3319],
        [-1.2643]])
current energy labels group:
tensor([[1],
        [0]])
current loss:
tensor(0.3319)

As you can see, the second item is totally ignored.

peterjc123 commented 5 months ago

The code is okay. For a group of two, F.log_softmax(logps)[0] = F.logsigmoid(logps[0] - logps[1]), so the rejected sample is not ignored: it enters the loss through the softmax normalizer. Thanks for the detailed explanation.

P.S. Dividing the numerator and denominator by exp(logps[0]):

F.log_softmax(logps)[0] = log(exp(logps[0]) / (exp(logps[0]) + exp(logps[1])))
                        = log(1 / (1 + exp(logps[1] - logps[0])))

F.logsigmoid(logps[0] - logps[1]) = log(1 / (1 + exp(logps[1] - logps[0])))

So the two are equal.
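As a sanity check, the identity can be applied to the numbers printed in the debug output above. This is a small sketch in plain Python (standard library only). The function names are hypothetical, and the inputs are the rounded per-group values from the log, so agreement with the printed losses is only up to rounding. The finite-difference step at the end also suggests that the rejected sample is not ignored: the loss changes when the rejected log-ratio changes.

```python
import math

def loss_via_log_softmax(chosen, rejected):
    """-log_softmax([chosen, rejected])[0], written out explicitly."""
    m = max(chosen, rejected)
    log_z = m + math.log(math.exp(chosen - m) + math.exp(rejected - m))
    return -(chosen - log_z)

def loss_via_logsigmoid(chosen, rejected):
    """-logsigmoid(chosen - rejected) = log(1 + exp(rejected - chosen))."""
    d = rejected - chosen
    return max(d, 0.0) + math.log1p(math.exp(-abs(d)))

# (chosen, rejected) log-ratios copied from the "current group logps" printouts.
groups = [(-0.2783, 0.3665), (-0.0128, -0.9451)]

losses = []
for c, r in groups:
    a = loss_via_log_softmax(c, r)
    b = loss_via_logsigmoid(c, r)
    losses.append((a, b))
    print(f"log_softmax form: {a:.4f}   logsigmoid form: {b:.4f}")

# Central finite difference of the first group's loss w.r.t. the rejected log-ratio.
eps = 1e-6
c, r = groups[0]
grad_rejected = (loss_via_log_softmax(c, r + eps)
                 - loss_via_log_softmax(c, r - eps)) / (2 * eps)
print(f"d(loss)/d(rejected) ~ {grad_rejected:.4f}")
```

Both forms should come out close to the 1.0667 and 0.3319 reported in the log, and the derivative with respect to the rejected log-ratio should be clearly nonzero.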