allenai / RL4LMs

A modular RL library to fine-tune language models to human preferences
https://rl4lms.apps.allenai.org/

passing extra variable to the forward function #26

Open lovodkin93 opened 1 year ago

lovodkin93 commented 1 year ago

Hey, I am currently using your repo to fine-tune a Longformer model. The problem is that this model requires pre-defining a global attention mask (in addition to the regular attention mask), which specifies which of the tokens receive global attention. So my question is: is there an easy way to pass this variable that does not require combing through the code to locate every call to the forward functions? In other words, is there an easy way to pass extra model_kwargs? Thanks!
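
For context, this is what the extra argument looks like in plain Hugging Face usage, outside of RL4LMs (a minimal sketch using the standard `transformers` Longformer API; the checkpoint name and the choice of putting global attention on the first token are only illustrative):

```python
import torch
from transformers import LongformerModel, LongformerTokenizer

tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")

inputs = tokenizer("Some long document ...", return_tensors="pt")

# Longformer needs a second mask on top of attention_mask:
# 1 = global attention for that token, 0 = local (windowed) attention.
global_attention_mask = torch.zeros_like(inputs["input_ids"])
global_attention_mask[:, 0] = 1  # e.g. global attention on the first token

outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    global_attention_mask=global_attention_mask,
)
```

The difficulty is that inside RL4LMs the model's forward is invoked from several places in the policy, so this extra kwarg would have to reach all of them.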

rajcscw commented 1 year ago

Hey, there is no straightforward way to do this. You would have to adapt the policy implementation to pass these extra arguments.
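
A minimal sketch of what that adaptation could look like, assuming you subclass the policy you are already using. The attribute and method names below (e.g. `_model_forward`) are illustrative placeholders for wherever your policy actually calls the underlying Hugging Face model; they are not part of the RL4LMs API:

```python
import torch


class LongformerPolicyMixin:
    """Illustrative mixin: inject a global_attention_mask into every
    call the policy makes to the underlying Longformer model.

    Mix this into the RL4LMs policy class you are using and route all
    of the policy's model calls through _model_forward (a placeholder
    name for the call sites you identify in the policy code).
    """

    @staticmethod
    def _with_global_attention(model_inputs: dict) -> dict:
        # Global attention on the first token, local attention elsewhere;
        # adjust this rule to whatever your task needs.
        input_ids = model_inputs["input_ids"]
        global_attention_mask = torch.zeros_like(input_ids)
        global_attention_mask[:, 0] = 1
        return {**model_inputs, "global_attention_mask": global_attention_mask}

    def _model_forward(self, model, **model_inputs):
        # Wrap each spot in the policy where the HF model's forward is
        # invoked; every such call site needs this treatment.
        return model(**self._with_global_attention(model_inputs))
```

Note that `transformers`' `generate()` generally forwards unrecognized keyword arguments to the model, so for the generation step you may be able to add `global_attention_mask` to the generation kwargs; the rollout and training forward passes inside the policy, however, still need an explicit change along the lines above.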