xfactlab / orpo

Official repository for ORPO

attention mask in compute_logps function #17

Closed hjc3613 closed 7 months ago

hjc3613 commented 7 months ago

Why use mask = chosen_attention_mask[:, :-1] - prompt_attention_mask[:, 1:]? In my understanding, chosen_attention_mask looks like [0 0 0 0 1 1 1 1], where 0 corresponds to pad tokens and 1 to prompt+answer tokens, while prompt_attention_mask looks like [0 0 0 0 0 0 1 1], where 0 corresponds to pad tokens and 1 to prompt tokens. So mask = [0 0 0 0 1 1 0 0], which cannot mask out the prompt tokens.

jiwooya1000 commented 7 months ago

Hello @hjc3613, for the same reasons mentioned here, ORPOTrainer pads the inputs on the right side by default, as DPOTrainer does.

I think your examples are padding the inputs to the left, which is not the case in our training code.

By padding to the right side, it would look like:

chosen_attention_mask = [1, 1, 1, 1, 0, 0, 0, 0]
prompt_attention_mask = [1, 1, 0, 0, 0, 0, 0, 0]

and it will mask out the overlapping prompt tokens.
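For reference, here is a minimal sketch (assuming right-padded PyTorch tensors as in the example above; the toy values are illustrative, not taken from an actual batch) of how the shifted subtraction used in compute_logps zeroes out the prompt overlap and the padding:

```python
import torch

# Toy right-padded example: 2 prompt tokens, 2 answer tokens, 4 pad tokens.
chosen_attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0, 0]])  # prompt + answer
prompt_attention_mask = torch.tensor([[1, 1, 0, 0, 0, 0, 0, 0]])  # prompt only

# Same arithmetic as in compute_logps: drop the last column of the chosen mask
# (labels are shifted by one position) and the first column of the prompt mask,
# then subtract.
mask = chosen_attention_mask[:, :-1] - prompt_attention_mask[:, 1:]
print(mask)  # tensor([[0, 1, 1, 1, 0, 0, 0]]) -> prompt overlap and padding are zeroed out
```

With left padding, as in the question above, the same subtraction would leave prompt positions unmasked, which is why right padding matters here.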

jiwooya1000 commented 7 months ago

Sorry @hjc3613, I mistakenly closed it 😅 Let me know if you have further questions.

hjc3613 commented 7 months ago

got it, thank you~