xfactlab / orpo

Official repository for ORPO
Apache License 2.0

compute_logps function: why does it also return a prob for the last token of the answer? #35

Open javismiles opened 6 days ago

javismiles commented 6 days ago

Good day. In the compute_logps function, if a probability is returned for the token just before the answer because prediction of the answer starts from that token, then why is a probability also returned for the last token of the answer? That position shouldn't matter, since the token it predicts falls outside the answer.

Thank you for any tips.

```python
def compute_logps(self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits):
    # Mask out the prompt (shifted by one) so only answer positions remain.
    mask = chosen_attention_mask[:, :-1] - prompt_attention_mask[:, 1:]
    # Gather the log-prob of each next token from the shifted logits.
    per_token_logps = torch.gather(
        logits[:, :-1, :].log_softmax(-1),
        dim=2,
        index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),
    ).squeeze(2)
    # Average the masked per-token log-probs over the active positions.
    return torch.mul(per_token_logps, mask.to(dtype=torch.bfloat16)).sum(dim=1).to(
        dtype=torch.float64
    ) / mask.sum(dim=1).to(dtype=torch.float64)
```
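To make the question concrete, here is a toy illustration (my own tensors, not from the repo) of what the mask expression selects when there is no trailing padding: a sequence of 2 prompt tokens followed by 3 answer tokens.

```python
import torch

# Hypothetical toy masks: 2 prompt tokens + 3 answer tokens, length 5, no padding.
chosen_attention_mask = torch.tensor([[1, 1, 1, 1, 1]])  # prompt + answer
prompt_attention_mask = torch.tensor([[1, 1, 0, 0, 0]])  # prompt only

# Same expression as compute_logps: drop the last position from the chosen
# mask, drop the first position from the prompt mask (the left shift).
mask = chosen_attention_mask[:, :-1] - prompt_attention_mask[:, 1:]
print(mask)  # tensor([[0, 1, 1, 1]])
# Position i of the mask lines up with logits[:, i] predicting token i+1:
# active positions 1, 2, 3 predict tokens 2, 3, 4 -- the 3 answer tokens.
```

In this unpadded case the mask has exactly as many active positions as the answer has tokens.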

javismiles commented 6 days ago

Basically, it makes sense that you shift the prompt mask left by one so that you capture the prediction made from the token just before the answer begins. But you are also including the very last token of the answer, which should not be counted: it has nothing left to predict, since the answer is already finished after it. That is my doubt.

javismiles commented 6 days ago

"mask.sum will only sum the active elements of mask so we normalize by the total tokens of answer" — theoretically mask.sum should add only the active elements of the mask, but when I debug with toy examples and this code, it also picks up the position after the answer, which is what confuses me.
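The observation above can be reproduced with a toy padded example (my own hypothetical tensors): 2 prompt tokens, 3 answer tokens, and 1 trailing padding token. Here the mask sums to one more than the number of answer tokens.

```python
import torch

# Hypothetical toy masks: 2 prompt + 3 answer + 1 padding token, length 6.
chosen_attention_mask = torch.tensor([[1, 1, 1, 1, 1, 0]])
prompt_attention_mask = torch.tensor([[1, 1, 0, 0, 0, 0]])

mask = chosen_attention_mask[:, :-1] - prompt_attention_mask[:, 1:]
print(mask)        # tensor([[0, 1, 1, 1, 1]])
print(mask.sum())  # tensor(4) -- one more than the 3 answer tokens
# The last active position (index 4) is the final answer token, whose logit
# predicts the padding token at index 5 -- the extra position the comment
# is asking about. Without trailing padding, this extra position does not
# appear, because slicing off the last column already removes it.
```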

javismiles commented 6 days ago

To summarize: you seem to be averaging the prediction log-probs from the token just before the answer plus all tokens of the answer, including the last one. And that is the doubt: why count the prediction made from the last answer token, when the first token after the answer is not relevant?
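If the reading above is right, one way to exclude that position would be to shift the chosen mask by one as well, instead of truncating its end, so the mask marks positions whose *target* token lies inside the answer. This is only a sketch on hypothetical toy tensors, not the repository's code:

```python
import torch

# Hypothetical toy masks: 2 prompt + 3 answer + 1 padding token, length 6.
chosen_attention_mask = torch.tensor([[1, 1, 1, 1, 1, 0]])
prompt_attention_mask = torch.tensor([[1, 1, 0, 0, 0, 0]])

# Variant: shift BOTH masks left by one, so each active position's target
# token (the one being predicted) is an answer token.
mask_variant = chosen_attention_mask[:, 1:] - prompt_attention_mask[:, 1:]
print(mask_variant)        # tensor([[0, 1, 1, 1, 0]])
print(mask_variant.sum())  # tensor(3) -- exactly the 3 answer tokens
```

Note the two variants only differ when a padding (or other) token follows the answer; with no trailing tokens, slicing off the last column and shifting give the same mask.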