Closed: peterjc123 closed this issue 5 months ago.
For the `ExactDataset`, the `energy_labels` are set here:
https://github.com/haozheji/exact-optimization/blob/7d56003e96f389284f707a105c2a9ffc162cd07b/src/align_stage/data.py#L168
and, according to the code below, they are set to 1 for preferred samples and 0 for rejected samples:
https://github.com/haozheji/exact-optimization/blob/7d56003e96f389284f707a105c2a9ffc162cd07b/src/utils/data/pref_to_rw.py#L22
But then the logic for computing the DPO loss looks odd, because the loss for rejected samples is zero since their `energy_labels_group` is zero:
https://github.com/haozheji/exact-optimization/blob/7d56003e96f389284f707a105c2a9ffc162cd07b/src/align_stage/loss.py#L53
Did I misread the code?
The `DPODataset` is not used in the training code and should be deprecated.
Thanks for the prompt response. Yes, it does seem to be unused. But the DPO+pref loss still looks odd when the `ExactDataset` is used; could you please elaborate on that?
This is an equivalent implementation of the original DPO loss, adapted to work with the soft-label version. The loss for chosen samples is exactly the original DPO loss.
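One way to read "soft-label version": the debugging script later in this thread contains an `"rw"` branch that turns `energy_labels_group` into a distribution with `(energy_labels_group / beta).softmax(0)`. The sketch below assumes that branch is what is meant and is only an illustration. For hard labels `[1, 0]`, the soft target moves toward `[1, 0]` as `beta` shrinks, so under this reading the hard-label DPO case sits at the limit of the soft-label one.

```python
import torch

# Hard preference labels for one (chosen, rejected) group, shaped like energy_labels.
energy_labels_group = torch.tensor([[1.0], [0.0]])

# "rw"-style soft targets, mirroring the `if "rw" in loss_type` branch of the
# debugging script in this thread: softmax of energy / beta over the group.
soft_labels = {}
for beta in (1.0, 0.1, 0.01):
    soft_labels[beta] = (energy_labels_group / beta).softmax(0).squeeze(1)
    print(beta, soft_labels[beta].tolist())
```

Every soft target is a proper distribution over the group. Only its sharpness depends on `beta`.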
So the DPO loss here is `-(F.log_softmax(beta * (policy_chosen_logps - ref_chosen_logps)) + F.log_softmax(policy_rejected_logps - ref_rejected_logps)) / 2` for preference data, right?
The original DPO loss is `-F.logsigmoid(beta * (policy_chosen_logps - ref_chosen_logps) - beta * (policy_rejected_logps - ref_rejected_logps))`.
This implementation uses hard labels `labels = [1, 0]`, where 1 is for the chosen sample and 0 is for the rejected sample, and the loss becomes `-(labels * F.log_softmax(beta * (policy_logps - ref_logps), dim=0)).sum(0)`, where `policy_logps = [policy_chosen_logps, policy_rejected_logps]` and `ref_logps = [ref_chosen_logps, ref_rejected_logps]`. Since the label for the rejected sample is 0, the loss is actually `-F.log_softmax(beta * (policy_logps - ref_logps), dim=0)[0]`, which is exactly the original DPO loss: the rejected sample still enters through the softmax denominator.
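A quick numerical sanity check of this equivalence, as a sketch on random inputs. The tensors below are made-up stand-ins for summed sequence log-probs, not outputs of the repository code. The check compares the original DPO formula against the hard-label `log_softmax` form with labels `[1, 0]`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
beta = 0.1
# Made-up summed log-probs for a batch of 8 preference pairs.
policy_chosen_logps = torch.randn(8)
policy_rejected_logps = torch.randn(8)
ref_chosen_logps = torch.randn(8)
ref_rejected_logps = torch.randn(8)

# Original DPO loss, one value per pair.
dpo = -F.logsigmoid(beta * (policy_chosen_logps - ref_chosen_logps)
                    - beta * (policy_rejected_logps - ref_rejected_logps))

# Hard-label form: stack [chosen, rejected] along dim 0 and normalize over it.
policy_logps = torch.stack([policy_chosen_logps, policy_rejected_logps])  # (2, 8)
ref_logps = torch.stack([ref_chosen_logps, ref_rejected_logps])           # (2, 8)
labels = torch.tensor([1.0, 0.0]).unsqueeze(1)                            # (2, 1)
hard_label = -(labels * F.log_softmax(beta * (policy_logps - ref_logps), dim=0)).sum(0)

print(torch.allclose(dpo, hard_label, atol=1e-6))
```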
But according to the code implementation, it is something like this:

```python
logps = torch.stack([policy_chosen_logps - ref_chosen_logps,
                     policy_rejected_logps - ref_rejected_logps])
loss = -(labels * F.log_softmax(beta * logps, dim=0)).sum(0)
```

which, with `labels = [1, 0]`, eventually becomes

```python
loss = -F.log_softmax(beta * logps, dim=0)[0]
```
I also stepped through the loss function with some dummy inputs:
```python
import torch
import torch.nn as nn


def dpo_loss(ref_logits,
             model_logits,
             attention_mask,
             y_ids,
             prompt_lens,
             energy_labels,
             N=2,
             beta=1.0,
             beta_model=1.0,
             loss_type="dpo"):
    # prepare
    bsz = ref_logits.size(0)
    logsigmoid = nn.LogSigmoid()
    logsm = nn.LogSoftmax(-1)
    device = ref_logits.device

    # comment out unrelated logic
    # for i in range(attention_mask.size(0)):
    #     attention_mask[i, :prompt_lens[i][0]] = 0

    # Per-token log-probs of the next token y_ids[:, 1:] under each model, masked
    model_logprobs = torch.gather(logsm(model_logits)[:, :-1, :], 2, y_ids[:, 1:].unsqueeze(2)).squeeze(2) * attention_mask[:, 1:]
    ref_logprobs = torch.gather(logsm(ref_logits)[:, :-1, :], 2, y_ids[:, 1:].unsqueeze(2)).squeeze(2) * attention_mask[:, 1:]
    # Sequence-level estimate: sum over tokens of (policy - reference) log-probs
    estimated_rewards_prefix = (model_logprobs - ref_logprobs).sum(1, keepdim=True)
    print('logps:')
    print(estimated_rewards_prefix)

    loss = 0.
    count = 0
    # Each group of N consecutive rows is one (chosen, rejected, ...) set
    for estimated_rewards_prefix_group, energy_labels_group in zip(estimated_rewards_prefix.split(N),
                                                                   energy_labels.split(N)):
        # prepare label
        if "rw" in loss_type:
            energy_labels_group = (energy_labels_group / beta).softmax(0)
        print('current group logps:')
        print(estimated_rewards_prefix_group)
        # num_contrastive * num_draw
        log_est_rewards_prefix_draw = (beta_model * estimated_rewards_prefix_group).log_softmax(0)
        print('current group logps after softmax:')
        print(log_est_rewards_prefix_draw)
        print('current energy labels group:')
        print(energy_labels_group)
        cur_loss = (-energy_labels_group * log_est_rewards_prefix_draw).sum(0).mean()
        print("current loss:")
        print(cur_loss)
        print('-' * 60)
        loss = loss + cur_loss
        count += 1
    return loss / count


if __name__ == '__main__':
    ref_logits = torch.rand(4, 10, 256)
    model_logits = torch.rand(4, 10, 256)
    y_ids = torch.zeros(4, 10, dtype=torch.long)
    attention_mask = torch.ones(4, 10, dtype=torch.long)
    prompt_lens = None
    energy_labels = torch.tensor([1, 0, 1, 0]).view(4, 1)
    dpo_loss(ref_logits,
             model_logits,
             attention_mask,
             y_ids,
             prompt_lens,
             energy_labels)
```
And the output of the code is shown below.
```
logps:
tensor([[-0.2783],
        [ 0.3665],
        [-0.0128],
        [-0.9451]])
current group logps:
tensor([[-0.2783],
        [ 0.3665]])
current group logps after softmax:
tensor([[-1.0667],
        [-0.4218]])
current energy labels group:
tensor([[1],
        [0]])
current loss:
tensor(1.0667)
------------------------------------------------------------
current group logps:
tensor([[-0.0128],
        [-0.9451]])
current group logps after softmax:
tensor([[-0.3319],
        [-1.2643]])
current energy labels group:
tensor([[1],
        [0]])
current loss:
tensor(0.3319)
```
As you can see, the second item is totally ignored.
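A small check on the printed numbers, as a sketch assuming `beta_model = 1.0` as in the script. The values are copied from the output above and carry its 4-decimal rounding. Each group loss matches `-logsigmoid(chosen - rejected)` to within that rounding. The rejected entry also receives a non-zero gradient, so its value affects the objective even though its label is 0.

```python
import torch
import torch.nn.functional as F

# (chosen, rejected) logps per group, copied from the debug output above.
groups = [(-0.2783, 0.3665), (-0.0128, -0.9451)]
printed_losses = [1.0667, 0.3319]

# With labels [1, 0], each group loss should equal -logsigmoid(chosen - rejected),
# which depends on the rejected value through the difference.
recomputed = [-F.logsigmoid(torch.tensor(c - r)).item() for c, r in groups]
print(recomputed, printed_losses)

# Gradient check on the first group: the rejected entry gets a non-zero gradient.
logps = torch.tensor([[-0.2783], [0.3665]], requires_grad=True)
labels = torch.tensor([[1.0], [0.0]])
loss = (-labels * logps.log_softmax(0)).sum(0).mean()
loss.backward()
print(logps.grad.squeeze(1).tolist())
```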
OK, the code is correct. We have `F.log_softmax(logps, dim=0)[0] == F.logsigmoid(logps[0] - logps[1])`, so the rejected sample still contributes through the normalization. Thanks for the detailed explanation.
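P.S.
`F.log_softmax(logps, dim=0)[0] = log(exp(logps[0]) / (exp(logps[0]) + exp(logps[1])))`
`F.logsigmoid(logps[0] - logps[1]) = log(1 / (1 + exp(logps[1] - logps[0])))`
Dividing the numerator and denominator of the first expression by `exp(logps[0])` gives the second, so they are equal.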
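The same identity with the `beta` scaling from the loss, written out with $r_c$ = `policy_chosen_logps - ref_chosen_logps` and $r_r$ = `policy_rejected_logps - ref_rejected_logps` (shorthand introduced here for brevity). Differentiating it also shows that the rejected reward gets a non-zero gradient even though its label is 0:

$$
\mathcal{L} = -\log\frac{e^{\beta r_c}}{e^{\beta r_c} + e^{\beta r_r}} = -\log\frac{1}{1 + e^{-\beta (r_c - r_r)}} = -\log\sigma\big(\beta (r_c - r_r)\big)
$$

$$
\frac{\partial \mathcal{L}}{\partial r_r} = \beta\,\sigma\big(\beta (r_r - r_c)\big), \qquad \frac{\partial \mathcal{L}}{\partial r_c} = -\beta\,\sigma\big(\beta (r_r - r_c)\big)
$$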
It seems that the `energy_labels` are not set for the DPO samples:
https://github.com/haozheji/exact-optimization/blob/7d56003e96f389284f707a105c2a9ffc162cd07b/src/align_stage/data.py#L63
https://github.com/haozheji/exact-optimization/blob/7d56003e96f389284f707a105c2a9ffc162cd07b/src/align_stage/data.py#L94
but they are used in the loss implementation below:
https://github.com/haozheji/exact-optimization/blob/7d56003e96f389284f707a105c2a9ffc162cd07b/src/align_stage/loss.py#L53
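If a dataset does not provide `energy_labels`, one hypothetical workaround is to build hard labels on the fly before calling the loss. This assumes batches are laid out as consecutive chosen/rejected rows with `N = 2`, as in the debugging script above. `make_pair_labels` is an illustrative name, not a function in the repository.

```python
import torch


def make_pair_labels(batch_size: int, n: int = 2) -> torch.Tensor:
    """Hypothetical helper: hard labels [1, 0, ..., 0] for each group of n
    consecutive rows (chosen row first), shaped (batch_size, 1) like energy_labels."""
    if batch_size % n != 0:
        raise ValueError("batch_size must be a multiple of n")
    group = torch.zeros(n)
    group[0] = 1.0  # the first row of every group is the chosen sample
    return group.repeat(batch_size // n).view(batch_size, 1)


labels = make_pair_labels(4)
print(labels.squeeze(1).tolist())
```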