mobeets / q-rnn


Beron with KL penalty #20

Open mobeets opened 7 months ago

mobeets commented 7 months ago

todo: add KL penalty between current and marginal policy as an intrinsic reward/penalty

log π(a|s)/p(a)

the question is whether this will induce perseveration. the only thing left to figure out is how to keep track of p(a).
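A minimal sketch of one way to do that bookkeeping, assuming p(a) is tracked with an exponential moving average of the policy; the function names, the EMA rate, and the usage line are illustrative, not anything already in this repo:

```python
import numpy as np

def update_marginal(p_a, pi_s, alpha=0.01):
    """Exponential-moving-average estimate of the marginal policy p(a).

    p_a:   current estimate of p(a), shape (n_actions,)
    pi_s:  current policy pi(a|s) at this state, shape (n_actions,)
    alpha: EMA rate (hypothetical default)
    """
    p_a = (1 - alpha) * p_a + alpha * pi_s
    return p_a / p_a.sum()  # renormalize to guard against drift

def kl_penalty(pi_s, p_a, action):
    """Per-step intrinsic penalty log pi(a|s) / p(a) for the sampled action."""
    return np.log(pi_s[action]) - np.log(p_a[action])

# usage: subtract the penalty from the extrinsic reward, e.g.
#   r_total = r_env - kl_coeff * kl_penalty(pi_s, p_a, action)
```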

mobeets commented 7 months ago

this involves using a softmax policy, I believe. so the first step is to confirm that we can still get a good model using a softmax policy (without the penalty).
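For reference, the softmax policy here is just a temperature-scaled normalization of the Q values; a minimal version (the default temperature is arbitrary):

```python
import numpy as np

def softmax_policy(q, tau=1.0):
    """Softmax policy over action values q with temperature tau."""
    z = (q - q.max()) / tau   # subtract the max for numerical stability
    p = np.exp(z)
    return p / p.sum()
```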

updates (as of Feb 16, 1:34 PM):

a few observations/questions:

mobeets commented 7 months ago

relevant reading: Lucy/Sam's "Human decision making balances reward maximization and policy compression"

one thought is that Celia observes RTs increasing on switch trials. this could be consistent with policy compression, where RT is assumed to be linearly related to the current policy's entropy, which should be higher on switch trials.

also relevant: they have update rules for the tradeoff parameter, which i've currently fixed to kl_penalty=1. (note that their β = 1/kl_penalty.)
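For concreteness, a small sketch of the entropy in question, with the RT relation above written as a hypothetical linear model (not from their paper's code; the coefficients would have to be fit):

```python
import numpy as np

def policy_entropy(pi_s):
    """Entropy H(pi(.|s)) = -sum_a pi(a|s) * log pi(a|s), in nats."""
    pi_s = np.clip(pi_s, 1e-12, 1.0)
    return -np.sum(pi_s * np.log(pi_s))

# hypothetical linear RT model in the spirit of the comment above:
#   RT ≈ rt0 + slope * policy_entropy(pi_s)
# higher-entropy (switch) trials should then show longer RTs.
```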

mobeets commented 6 months ago

fwiw, the only policy-gradient algorithm in RLlib that also supports RNNs seems to be A3C (actor-critic). it includes an entropy regularization term, which is the same as the policy compression (mutual information) objective when the optimal marginal policy is uniform. (Update: "More generally, RLlib supports the use of recurrent/attention models for all its policy-gradient algorithms (A3C, PPO, PG, IMPALA)")

but i guess in this task, the idea is that we would want the marginal policy to be more local (e.g., the marginal policy within the last K trials, where K is small), so that there is perseveration. so in this case, we couldn't get away with a maximum entropy penalty alone.
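A sketch of what such a local marginal policy could look like, assuming we just count the choices made over a sliding window of the last K trials (the class name, window size, and pseudo-count smoothing are all assumptions):

```python
from collections import deque
import numpy as np

class RollingMarginal:
    """Marginal policy estimated from only the last K choices."""

    def __init__(self, n_actions, K=10):
        self.n_actions = n_actions
        self.recent = deque(maxlen=K)  # rolling buffer of recent actions

    def update(self, action):
        self.recent.append(action)

    def p(self):
        # one pseudo-count per action to avoid zero probabilities
        counts = np.ones(self.n_actions)
        for a in self.recent:
            counts[a] += 1
        return counts / counts.sum()
```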

mobeets commented 6 months ago

one idea: what if we didn't train with KL, but added it into the policy at test time? just to see whether we can even induce perseveration.

Softmax model, τ=0.001, no KL:

(Note that increasing τ is equivalent to adding a KL penalty against a fixed, uniform marginal policy.)
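A quick numerical check of that note: the policy that maximizes expected value minus a KL penalty to a marginal p, weighted by τ, is π(a) ∝ p(a)·exp(Q(a)/τ), which for uniform p is just an ordinary softmax at temperature τ. The Q values and τ below are arbitrary:

```python
import numpy as np

def softmax(x):
    z = np.exp(x - x.max())
    return z / z.sum()

Q = np.array([1.0, 0.3, -0.5])   # arbitrary action values
tau = 2.0                        # arbitrary temperature / KL weight

# KL-regularized optimum: pi(a) ∝ p(a) * exp(Q(a)/tau), with uniform p(a)
p_uniform = np.ones_like(Q) / len(Q)
pi_kl = p_uniform * np.exp((Q - Q.max()) / tau)
pi_kl /= pi_kl.sum()

# plain softmax with temperature tau
pi_tau = softmax(Q / tau)

assert np.allclose(pi_kl, pi_tau)  # uniform marginal: KL term only rescales temperature
```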

mobeets commented 6 months ago

To do next

I think the thing to do next is to take a series of Q from rolling out a standard RNN, and then play around with how both the kl_coeff (β) and kl_alpha (α) terms would change the policy.

So let (Q, τ) be the params used to do the rollout. Given those, we can get π. Then for a given (β, α), we can compute π', which is what the policy would have been under (β, α). Then plot this during a series of trials.
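A sketch of what that replay could look like, assuming π = softmax(Q/τ) and π' = softmax(Q/τ + β·P), where P is a marginal-policy estimate updated at rate α toward the actions taken in the rollout (this matches the "Q/τ + β P" preferences described below; the function name and update rule are assumptions):

```python
import numpy as np

def softmax(x):
    z = np.exp(x - x.max())
    return z / z.sum()

def replay_policies(Q_seq, actions, tau, beta, alpha):
    """Replay a rollout's Q values and compare pi (original) with pi'
    (what the policy would have been with the marginal-policy term).

    Q_seq:   (T, n_actions) Q values from the standard RNN rollout
    actions: (T,) actions actually taken during the rollout
    beta:    kl_coeff, weight on the marginal-policy term
    alpha:   kl_alpha, update rate of the marginal-policy estimate
    """
    n_actions = Q_seq.shape[1]
    P = np.ones(n_actions) / n_actions               # marginal policy estimate
    pis, pis_mod = [], []
    for Q, a in zip(Q_seq, actions):
        pis.append(softmax(Q / tau))                 # original prefs: Q/tau
        pis_mod.append(softmax(Q / tau + beta * P))  # modified prefs: Q/tau + beta*P
        P = (1 - alpha) * P                          # nudge P toward the taken action
        P[a] += alpha
    return np.array(pis), np.array(pis_mod)
```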

[image]

Okay so here's the best tuned example I could find, at a period where the actions switch. Solid lines show the late-session action prefs (Q/τ); dashed lines show the modified action prefs (Q/τ + β P, where P is the marginal policy). What I learned:

My take-away is that, as I suggested in the comment below, we would need to not treat "null" actions as actions at all from the policy compression perspective, so that the marginal policy applies only to the non-null actions. This is maybe consistent with positing a hierarchical policy that first decides act vs. no-act; within the "act" choice there is a sub-policy, and that is where the policy compression plays out.

mobeets commented 6 months ago

i think one reason the marginal policy approach might not work to create perseveration is that the null action is encoded like a real action, whereas what we want is for the last DECISION to be more likely to be taken. it would be cool if we had some formal way to have our policy choose between taking an action and taking no action at all, like in a semi-MDP, but i'm not sure whether that's a thing.

this reminds me of the fact that in a go/nogo task, nogo isn't really an action but the withholding of an action. (e.g., a nogo action has no RT, whereas any other action does.)

maybe the way to handle this is that the marginal policy is only calculated over non-null actions, and similarly the KL penalty isn't applied when there is no action taken.
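A sketch of that bookkeeping, assuming the null action has a dedicated index and the marginal policy is maintained only over the non-null actions (the index convention, names, and rate are illustrative):

```python
import numpy as np

NULL_ACTION = 0   # hypothetical convention: index 0 means "do nothing"

def update_marginal_nonnull(p_act, action, alpha=0.01):
    """EMA marginal policy over the non-null actions only (indices 1..N)."""
    if action == NULL_ACTION:
        return p_act                    # withholding a response is not a decision
    p_act = (1 - alpha) * p_act
    p_act[action - 1] += alpha
    return p_act

def kl_term_nonnull(pi_s, p_act, action):
    """Penalty log pi(a|s)/p(a); no penalty on steps where no action is taken."""
    if action == NULL_ACTION:
        return 0.0
    return np.log(pi_s[action]) - np.log(p_act[action - 1])
```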

but yeah, this problem is way more general. take a standard epsilon-greedy or softmax policy: increasing epsilon or τ is going to increase the probability of abort trials, because we're now equally likely to sample a choice action at any given time. it's as if there's some basic learned hierarchical policy whose only actions are "act" or "do not act", and the "act" policy then has sub-choices about which precise action to take; when we "explore," we want to explore only within the "act" sub-policy.

(edit: for this last point, yes, this is true, but in terms of what applying a KL penalty would do: the model should still be able to learn the times when it can afford to be more noisy (e.g., at decision points) while still maintaining low variability in the middle of the trial (e.g., to avoid aborting the trial).)

mobeets commented 6 months ago

To do

I think one useful first step could be to take a simple, fixed representation (like a belief representation), add in a KL penalty, and use Lucy's code to apply policy-compression-style learning. Then we can see in general whether this approach is going to achieve what we think it will.
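In the spirit of that plan (and not Lucy's actual code), a generic sketch of policy-compression-style learning on a fixed belief representation: action values are a linear function of the belief features, the policy is a softmax of β·Q plus the log marginal, and the marginal policy is updated incrementally. All names, learning rates, and the delta-rule value update are assumptions:

```python
import numpy as np

def softmax(x):
    z = np.exp(x - x.max())
    return z / z.sum()

def compressed_policy(W, b, p_a, beta):
    """pi(a|b) ∝ p(a) * exp(beta * Q(b,a)), i.e. softmax of beta*Q plus log p(a)."""
    q = b @ W                                   # Q(b, a) for all actions
    return softmax(beta * q + np.log(p_a))

def learn(W, b, a, r, pi, p_a, alpha_q=0.1, alpha_p=0.01):
    """After a trial: delta-rule update of the value weights, EMA update of p(a)."""
    q = b @ W
    W[:, a] += alpha_q * (r - q[a]) * b         # value update toward the reward
    p_a = (1 - alpha_p) * p_a + alpha_p * pi    # marginal policy tracks the policy
    return W, p_a / p_a.sum()

# per trial: pi = compressed_policy(W, b, p_a, beta); sample a from pi;
# observe reward r; then W, p_a = learn(W, b, a, r, pi, p_a)
```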