wzhouad / WPO

Code and models for paper "WPO: Enhancing RLHF with Weighted Preference Optimization"

How to calculate equation 2 efficiently? #1

Open peterjc123 opened 1 month ago

peterjc123 commented 1 month ago

As can be seen in the paper, equation 2 requires sum_v(pi_theta(v|x,y_<t)^2), but from the forward pass we only have log(pi_theta(v|x,y_<t)).

wzhouad commented 1 month ago

My implementation first computes p = pi_theta(v|x,y_<t) with F.softmax, then computes sum_v(pi_theta(v|x,y_<t)^2) as (p ** 2).sum(-1).
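A minimal sketch of that computation (assuming `logits` of shape [B, S, V] from the forward pass; the variable names are just illustrative):

import torch
import torch.nn.functional as F

# probabilities over the vocabulary at every position: pi_theta(v | x, y_<t)
p = F.softmax(logits, dim=-1)     # [B, S, V]

# sum over the vocabulary of the squared probabilities
sum_sq = (p ** 2).sum(-1)         # [B, S]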

peterjc123 commented 3 weeks ago

@wzhouad Thanks for your prompt answer. I have another question. For sampled alignment, do you sample a single new token at each position, or generate a whole new sequence? E.g., we have a sequence Q:[QQQQ]A:[AAAAA]; then during training we could get either of the following:

  1. If we sample one new token per position (keeping the original prefix), it should be something like

    Q:[QQQQ]A: -> Q:[QQQQ]A:[S]
    Q:[QQQQ]A:[A] -> Q:[QQQQ]A:[AT]
    Q:[QQQQ]A:[AA] -> Q:[QQQQ]A:[AAQ]
    Q:[QQQQ]A:[AAA] -> Q:[QQQQ]A:[AAAW]

  2. If we generate a sequence, then it should be something like

    Q:[QQQQ]A: -> Q:[QQQQ]A:[S]
    Q:[QQQQ]A:[S] -> Q:[QQQQ]A:[ST]
    Q:[QQQQ]A:[ST] -> Q:[QQQQ]A:[STQ]
    Q:[QQQQ]A:[STQ] -> Q:[QQQQ]A:[STQW]

  I suppose it should be the former, otherwise it would be very time-consuming I guess.
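In code, the difference between the two options would be roughly the following (a hypothetical sketch assuming a Hugging Face-style causal LM; `model`, `input_ids`, and `prompt_ids` are illustrative names, not from the repo):

import torch

# Option 1: one forward pass over the ground-truth sequence, then sample a
# new token independently at each position from the teacher-forced logits
logits = model(input_ids).logits     # [B, S, V]
B, S, V = logits.shape
sampled_tokens = torch.multinomial(logits.softmax(-1).view(-1, V), num_samples=1).view(B, S)

# Option 2: autoregressive generation of a whole new continuation
new_sequence = model.generate(prompt_ids, do_sample=True, max_new_tokens=32)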

wzhouad commented 3 weeks ago

No, we do not generate new sequences at training time.

peterjc123 commented 2 weeks ago

Below is my implementation. Do you think it is correct?

import torch

# Assume `logits` has shape [B, S, V] and is already aligned with `labels`.
# The batch interleaves pairs: for each b in B, if `idx(b) % 2 == 0` it's a
# positive sample, otherwise a negative sample.
# `loss_mask` and `labels` both have shape [B, S].

# per-token probability of the label token and sum of squared probabilities over the vocab
probs = logits.softmax(-1)                                                           # [B, S, V]
per_token_probs = torch.gather(probs, dim=2, index=labels.unsqueeze(2)).squeeze(2)   # [B, S]
sum_sq_probs = (probs ** 2).sum(-1)                                                  # [B, S]

# geometric mean over the masked tokens of pi(y_t) / sum_v pi(v)^2
policy_weights = torch.exp((loss_mask * torch.log(per_token_probs / sum_sq_probs)).sum(-1) / loss_mask.sum(-1))

# split the interleaved pairs and combine into the final (non-differentiable) WPO weight
policy_preferred_weight, policy_rejected_weight = policy_weights.view(-1, 2).unbind(1)
wpo_weight = (policy_preferred_weight * policy_rejected_weight).detach()
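FWIW, the same weight can be computed entirely in log space, which avoids taking the log of per-token probabilities that may underflow to zero. A sketch under the same assumptions (algebraically equivalent, using log sum_v pi(v)^2 = logsumexp(2 * log pi)):

logp = logits.log_softmax(-1)                                                        # [B, S, V]
per_token_logp = torch.gather(logp, dim=2, index=labels.unsqueeze(2)).squeeze(2)     # [B, S]
log_sum_sq = torch.logsumexp(2 * logp, dim=-1)                                       # log sum_v pi(v)^2, [B, S]
policy_weights = torch.exp((loss_mask * (per_token_logp - log_sum_sq)).sum(-1) / loss_mask.sum(-1))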

Also, when I tried to use WPO with KTO (beta=0.01), the KTO loss grew from 0.36 to 0.47 and only then gradually went down, and the rewards for both positive and negative samples showed a dramatic 5-10x decline. It works fine with KTO (beta=0.1), though.