mingkaid / rl-prompt

Accompanying repo for the RLPrompt paper

Clarification on the RL problem #28

Closed hv68 closed 1 year ago

hv68 commented 1 year ago

Hey! Thanks for the previous clarification. I'm trying to implement this with regular Q-learning (I know the paper suggests soft Q-learning; I want to compare both approaches for a paper), but even after 1k steps the accuracy on a classification task shows no sign of learning, so I wanted to confirm whether my approach is right.

I've understood the problem as:

1. Initialize a single starting token for the prompt.
2. Generate a prompt Z_1 of length T.
3. Compute the reward for Z_1.
4. Using the last token generated for Z_1, generate a new prompt Z with the target network.
5. Backprop on the Q_1 values found for Z_1, using the reward and the target network's Q values for Z, with MSE between Q_1 and reward + discount * Q as the loss (sketched below). The reward is a scalar and broadcasts over the discounted Q values.
6. Using the last token in Z_1, generate Z_2 and repeat steps 3-6.
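For concreteness, here is a minimal sketch of the loss described in steps 4-5, assuming `q_values_z1` holds the online network's Q values for the T chosen tokens of Z_1 and `target_q_z` holds the target network's Q values for the tokens of Z. All names and shapes here are placeholders, not RLPrompt's actual API:

```python
import torch
import torch.nn.functional as F

gamma = 0.99  # discount factor (illustrative value)

def td_loss(q_values_z1: torch.Tensor,  # (T,) Q_1: online-network Q values for the tokens of Z_1
            target_q_z: torch.Tensor,   # (T,) target-network Q values for the tokens of Z
            reward: float) -> torch.Tensor:
    """MSE between Q_1 and reward + gamma * Q_target, with the scalar
    reward broadcast across all T positions, as in steps 4-5 above."""
    # Detach so no gradient flows through the target network's values
    target = reward + gamma * target_q_z.detach()
    return F.mse_loss(q_values_z1, target)
```

The backward pass then updates only the online Q network; the target network would be synced from it periodically (e.g. every fixed number of steps).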

Does this sound right? https://colab.research.google.com/drive/1fs9lILaBEqJs9ieF2lnH8fgwZSidJcw0?authuser=1#scrollTo=6mRz6CFFbVzV is the first version of the code I've implemented. Even with a decaying epsilon, the prompts tend to converge to a single word repeated T times, and the maximum accuracy I get is 0.675. I'm using DistilBERT as my masked language model instead of RoBERTa due to memory and hardware constraints. Any help would be greatly appreciated. Thank you!

mingkaid commented 1 year ago

Hi, sorry for the delayed response. I hope you were able to figure it out.

For the regular Q-learning, did you implement any exploration mechanism? From how it sounds, you were choosing the token with the maximum Q value every time. Typically when people do Q-learning, they implement an "epsilon-greedy" exploration mechanism: at each step there is an epsilon probability of choosing a random action instead of acting according to the Q values. In our case, the action is the token.
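A minimal sketch of epsilon-greedy token selection for this setting, assuming `q_values` is a 1-D tensor of shape (vocab_size,) from the Q network at the current step (names are illustrative, not from the RLPrompt codebase):

```python
import random
import torch

def select_token(q_values: torch.Tensor, epsilon: float) -> int:
    """Epsilon-greedy choice over the vocabulary: with probability epsilon
    pick a uniformly random token (explore), otherwise take the argmax
    of the Q values (exploit)."""
    if random.random() < epsilon:
        return random.randrange(q_values.size(0))
    return int(q_values.argmax().item())
```

A common companion is a schedule that decays epsilon from a high value toward a small floor over training, which it sounds like you already have.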

I am closing this issue now because it's a clarification question.