peterjc123 opened this issue 5 months ago
My implementation is to first calculate p = pi_theta(v|x, y_<t) with `F.softmax`, and then calculate sum_v pi_theta(v|x, y_<t)^2 with `(p ** 2).sum(-1)`.
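A minimal sketch of that computation (the `logits` name and the shapes here are illustrative, not taken from the repo):

import torch
import torch.nn.functional as F

# Illustrative next-token logits from the policy model, shape [B, S, V]
logits = torch.randn(2, 8, 32000)

# p = pi_theta(v|x, y_<t) for every vocabulary entry v
p = F.softmax(logits, dim=-1)      # [B, S, V]

# sum_v pi_theta(v|x, y_<t)^2 at each position
sum_sq = (p ** 2).sum(-1)          # [B, S]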
@wzhouad Thanks for your prompt answer. I have another question. As for sampled alignment, do you sample just a single new token at each position, or generate a whole new sequence?
E.g., we have a sequence Q:[QQQQ]A:[AAAAA]. Then during training we could get something like:
Q:[QQQQ]A: -> Q:[QQQQ]A:[S]
Q:[QQQQ]A:[A] -> Q:[QQQQ]A:[AT]
Q:[QQQQ]A:[AA] -> Q:[QQQQ]A:[AAQ]
Q:[QQQQ]A:[AAA] -> Q:[QQQQ]A:[AAAW]
or
Q:[QQQQ]A: -> Q:[QQQQ]A:[S]
Q:[QQQQ]A:[S] -> Q:[QQQQ]A:[ST]
Q:[QQQQ]A:[ST] -> Q:[QQQQ]A:[STQ]
Q:[QQQQ]A:[STQ] -> Q:[QQQQ]A:[STQW]
I suppose it should be the former; otherwise it would be very time-consuming, I guess.
No, we do not generate new sequences at training time.
Below is my implementation. Do you think it is correct?
# Assume `logits` has shape [B, S, V]. For each b in B, if idx(b) % 2 == 0 it is a
# positive (chosen) sample, otherwise a negative (rejected) one. We also have
# `loss_mask` and `labels`, both of shape [B, S].
probs = logits.softmax(-1)
# pi_theta(y_t|x, y_<t) for the actual label token at each position
per_token_probs = torch.gather(probs, dim=2, index=labels.unsqueeze(2)).squeeze(2)
# Masked, length-normalized sum of log( pi_theta(y_t|x, y_<t) / sum_v pi_theta(v|x, y_<t)^2 )
policy_weights = torch.exp((loss_mask * torch.log(per_token_probs / (probs ** 2).sum(-1))).sum(-1) / loss_mask.sum(-1))
policy_preferred_weight, policy_rejected_weight = policy_weights.view(-1, 2).unbind(1)
wpo_weight = (policy_preferred_weight * policy_rejected_weight).detach()
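In the notation used earlier in this thread, the snippet computes a per-sequence weight of the form

w(x, y) = exp( (1/|y|) * sum_t log( pi_theta(y_t|x, y_<t) / sum_v pi_theta(v|x, y_<t)^2 ) )

where |y| is the number of tokens counted by `loss_mask`, and the final `wpo_weight` is the product of the chosen and rejected weights, detached from the graph so it only scales the loss.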
Also, when I tried to use WPO with KTO (beta=0.01), the KTO loss grew from 0.36 to 0.47 and then gradually went down. The rewards for both positive and negative samples showed a crazy 5-10x decline. It works fine with KTO (beta=0.1), though.
As can be seen in the paper, we need to calculate sum_v pi_theta(v|x, y_<t)^2, but we only have log(pi_theta(v|x, y_<t)).
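As a side note, if only the log-probabilities are at hand, that sum of squared probabilities can also be computed directly in log space; a sketch (the `log_probs` name is illustrative):

import torch

# log_probs: [B, S, V] = log pi_theta(v|x, y_<t), e.g. from logits.log_softmax(-1)
log_probs = torch.randn(2, 8, 32000).log_softmax(-1)

# sum_v pi^2 = exp( logsumexp_v(2 * log pi) ), staying in log space until the last step
sum_sq = torch.logsumexp(2.0 * log_probs, dim=-1).exp()   # [B, S]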