mingkaid / rl-prompt

Accompanying repo for the RLPrompt paper

RL-prompt MLP loss #25

Closed · hv68 closed this issue 1 year ago

hv68 commented 1 year ago

Hey, I'm a student trying to replicate the sentiment analysis experiments of RLPrompt, and I was wondering how exactly the MLP in the LM is trained. I see that you're getting a ground truth from the `teacher_forcing` method, so is it basically taking the previous step's prompt that got the best reward, computing a loss between the current prompt's reward and the previous best prompt's reward, and then backpropagating on that?

Another clarification that would help: when formulating this as an RL problem, are the states essentially the time steps themselves, with the prompts being the actions that best reward our problem? In that case, are you updating time step t's Q value based on t+1's generated prompts and their rewards? I'm a little confused here, because then the set of actions you can take would grow every time step by a factor of how many prompts you generate per step?

mingkaid commented 1 year ago

Hi, thank you for your interest! Please find the answers below:

How is the MLP in the LM trained? What's the role of the `teacher_forcing` function?

We train the MLP using the Soft Q-Learning (SQL) algorithm based on this paper. The algorithm trains the MLP with a regression loss that predicts the reward you will receive for selecting each prompt token. The loss uses a target network to provide the regression target, and that target is computed in the `teacher_forcing` function. For more details, please refer to the paper linked above.
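To make the idea concrete, here is a minimal sketch of what such a soft Q-learning regression loss can look like. It is not the repo's exact implementation, and the tensor names (`q_values`, `target_q_values`, `actions`, `rewards`) are assumptions chosen for illustration:

```python
import torch
import torch.nn.functional as F

def soft_q_loss(q_values, target_q_values, actions, rewards, gamma=1.0):
    """Illustrative soft Q-learning regression loss (sketch, not the repo's code).

    q_values:        (batch, seq_len, vocab) Q logits from the trainable policy
                     (frozen LM plus the MLP head).
    target_q_values: (batch, seq_len, vocab) Q logits from a target network,
                     i.e. the quantity the teacher-forcing pass provides.
    actions:         (batch, seq_len) indices of the prompt tokens actually taken.
    rewards:         (batch, seq_len) per-step rewards (often zero until the end).
    """
    # Q(s_t, a_t) for the tokens that were actually selected
    q_taken = q_values.gather(-1, actions.unsqueeze(-1)).squeeze(-1)

    # Soft state value under the target network: V(s) = logsumexp_a Q_target(s, a),
    # shifted by one step so each position sees the value of its *next* state
    next_v = torch.logsumexp(target_q_values, dim=-1)
    next_v = torch.cat([next_v[:, 1:], torch.zeros_like(next_v[:, :1])], dim=1)

    # Bootstrapped regression target: r_t + gamma * V(s_{t+1}), no gradient through it
    target = (rewards + gamma * next_v).detach()

    # Regress the predicted Q values onto the target
    return F.mse_loss(q_taken, target)
```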

Definition of state and action

Following previous work on RL for text generation, we define the state as the sequence of previously selected tokens, and the action as the next token to select from the vocabulary. So the action space stays the (fixed-size) vocabulary at every step, rather than a set of prompts that grows over time.
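As a rough illustration of this formulation (a hypothetical sketch, not code from the repo; `policy.select_token` is an invented placeholder):

```python
def rollout_prompt(policy, prompt_length):
    """Hypothetical rollout illustrating the state/action definition above.

    `policy.select_token` is a placeholder, not an API from this repo.
    """
    prompt_tokens = []                        # s_0: the empty prefix
    for t in range(prompt_length):
        state = tuple(prompt_tokens)          # s_t = previously selected tokens
        action = policy.select_token(state)   # a_t = one token id from the vocabulary
        prompt_tokens.append(action)          # s_{t+1} = s_t + (a_t,)
    # the reward is computed only once the full prompt is assembled and evaluated
    return prompt_tokens
```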

I hope these answer your questions. Please let us know if you have any others. I am closing this issue now since it is a clarification question.