mobeets / q-rnn


Beron2022 reparameterization #11

Closed mobeets closed 1 year ago

mobeets commented 1 year ago

Hypothesis: if I change the input encoding, I can get two fixed points instead of four. The input encoding would be something like (a=1 and r=1), (a=0 and r=0), and then the other two. Basically A, a, B, b. Or even just (A or b), (B or a). My guess is that these will only have two fixed points.
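Roughly, the two encodings I have in mind (the side/reward conventions and function names below are just for illustration, not the repo's actual code):

```python
import numpy as np

def encode_onehot(a_prev, r_prev):
    """4-d one-hot over (A, a, B, b): previous action crossed with previous reward.
    Illustrative convention: A=(side 0, r=1), a=(side 0, r=0), B=(side 1, r=1), b=(side 1, r=0)."""
    x = np.zeros(4)
    x[2 * a_prev + (1 - r_prev)] = 1.0
    return x

def encode_collapsed(a_prev, r_prev):
    """2-d encoding that collapses (A or b) vs (B or a), i.e., evidence-for-L vs evidence-for-R."""
    x = np.zeros(2)
    x[0 if a_prev == 1 - r_prev else 1] = 1.0
    return x
```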

mobeets commented 1 year ago

Okay so training an H=2 (below) or H=3 RNN using one-hot inputs (A, a, B, b) still gives 4 fixed points in Z, but roughly 2 fixed points in Q:

(By the way, I realized that if I encoded the inputs as (A or b, B or a), then OF COURSE there could be at most two fixed points, since there are only two possible inputs.)

So my takeaway is that the input encoding is NOT the reason we are seeing four fixed points in this task.

mobeets commented 1 year ago

Elman RNN

I trained an H=2 vanilla RNN, and it also has 4 fixed points, but they're essentially immediate fixed points:

In the first and third plots, you can see the initial seeds I'm using for finding the fixed points (basically, the empirical data plus added Gaussian noise), and in the third plot you can see the first step of the RNN for a constant "A" input. So basically, as soon as a given input is received, we move very close to that input's fixed point.
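For reference, the fixed-point search I'm describing is roughly this (function names, shapes, and the `rnn_cell` call are placeholders, not the actual repo code):

```python
import torch

def find_input_conditioned_fps(cell, x_const, h_seeds, n_steps=200, tol=1e-5):
    """Iterate an RNN cell with a constant input until the hidden state stops moving.
    x_const: (1, input_dim) one-hot input (e.g., "A"); h_seeds: (n_seeds, hidden_dim)."""
    h = h_seeds.clone()
    for _ in range(n_steps):
        h_next = cell(x_const.expand(h.shape[0], -1), h)  # one recurrent step
        if torch.norm(h_next - h, dim=-1).max() < tol:
            return h_next
        h = h_next
    return h  # approximate input-conditioned fixed points

# seeds: empirical hidden states plus Gaussian noise, e.g.
# h_seeds = h_empirical + 0.1 * torch.randn_like(h_empirical)
# fps_A = find_input_conditioned_fps(rnn_cell, x_A, h_seeds)
```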

So from the Q values, it's as if an "A" input signals a more confident "left" action than a "b" input, and similarly for B/a. I'm now starting to wonder if this is somehow true in the actual experiment, as I've implemented it?

mobeets commented 1 year ago

Aha! I think it's possible that, if we just received no reward (an "a" or "b" input), our total future return is lower than if we're already receiving reward. Like, if we got a trial that looks like a switch (r=0), it could be either a true switch or an omission, so there's some chance that our next action is wrong (e.g., we switch when we shouldn't have, or vice versa).

(I at first thought gamma might matter here, but I trained a model with γ=0, and still got four fixed points.)

So I think it does actually make sense to have four fixed points, because an r=0 trial means there's now a chance that either (i) our policy will switch when the state didn't (i.e., the r=0 was an omission), or (ii) our policy will not switch when it should (i.e., the r=0 signaled a state change).
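To put numbers on that ambiguity, here's a quick back-of-the-envelope Bayes calculation (the p_switch and p_reward_max values are just illustrative guesses, and I'm assuming the wrong side never pays out, which may not match the actual task):

```python
# P(state switched | we chose the previously-correct side and got r=0)
p_switch = 0.02       # per-trial probability the rewarded side flips (illustrative)
p_reward_max = 0.8    # reward probability on the correct side (illustrative)

# r=0 after choosing the previously-correct side can mean:
#  (i) omission: state unchanged, reward just not delivered
# (ii) the state actually switched, so our choice was wrong
p_r0_and_stay = (1 - p_switch) * (1 - p_reward_max)
p_r0_and_switch = p_switch * 1.0  # assuming the wrong side never rewards
p_switch_given_r0 = p_r0_and_switch / (p_r0_and_stay + p_r0_and_switch)
print(p_switch_given_r0)  # ~0.09: a small but nonzero chance we should switch
# note: with p_reward_max = 1, this posterior becomes 1, so r=0 is unambiguous
```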

If this is true, then setting p_reward_max=1 should mean only 1 fixed point?

Ack. Nope. 3 fixed points, for some reason.

mobeets commented 1 year ago

Okay, so I think I finally understand why our RNN has more fixed points than the beliefs do: the Q network has to estimate value AND the correct action. And it turns out that, when p_reward_max=1, the optimal policy is win-stay/lose-switch. Given one of the below inputs, the accuracy (i.e., value) is:

So observing an A or b means you want to choose "L", but these have different values. Therefore, they need to be separate fixed points. So we need four fixed points and not two!

But wait, the above is not actually right. For the optimal policy, the accuracy of seeing "a" should still be 98%. But it may be that, in the training data, we never see two switches one right after the other.
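If I wanted to double-check that 98% claim, the simulation would be something like this (win-stay/lose-switch under guessed task parameters; the per-trial switch process here may not be exactly how the environment is implemented):

```python
import numpy as np

rng = np.random.default_rng(0)
p_switch, p_reward_max = 0.02, 1.0  # guesses at the task parameters
n_trials = 200_000

state, action = 0, 0
acc_by_obs = {o: [] for o in 'AaBb'}
for _ in range(n_trials):
    rewarded = (action == state) and (rng.random() < p_reward_max)
    obs = ('A' if action == 0 else 'B') if rewarded else ('a' if action == 0 else 'b')
    next_action = action if rewarded else 1 - action  # win-stay / lose-switch
    if rng.random() < p_switch:                       # state may flip before the next trial
        state = 1 - state
    acc_by_obs[obs].append(int(next_action == state))  # is the WSLS choice correct next trial?
    action = next_action

for o in 'AaBb':
    print(o, round(float(np.mean(acc_by_obs[o])), 3))  # each should come out near 1 - p_switch = 0.98
```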

mobeets commented 1 year ago

Aha, okay, I'm realizing now that it actually might be reasonable for the beliefs to appear to have four fixed points! I.e., there aren't really four distinct fixed points, but it can look that way. Suppose p_reward_max=1 for simplicity.

mobeets commented 1 year ago

Okay, one more thought. When we have p_reward_max=1, we do not need an RNN, because then (a, r) always signals the current underlying state. (The state can still switch before our next action, but the point is that there's nothing to integrate.) So here, the optimal policy can be defined purely in terms of (a_prev, r_prev).

So we can use a tabular method here. We have 4 different observations and 2 different actions, so Q is just 4x2. Imagine each row of Q as a point in 2D Q space, so that this point is our "representation." Now note that Q is a function only of the current input, so in this sense, each row of Q is a fixed point: if we show an input over and over again, we get that same representation over and over again. So we have 4 fixed points because we have 4 inputs. And due to randomness in our sampled trajectories, it is unlikely that any of the rows of Q will be identical. So we will likely always have 4 fixed points. The only way to avoid that would be to have some sort of penalty.
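In toy form (numbers and update details are illustrative, not the repo's code):

```python
import numpy as np

# 4 observations (A, a, B, b) x 2 actions (L, R): each row of Q is a point in 2-D "Q space"
Q = np.zeros((4, 2))
alpha, gamma = 0.1, 0.0  # gamma=0, as in the experiment above

def td_update(obs, action, reward, next_obs):
    """One tabular Q-learning step; with gamma=0 this just tracks E[r | obs, action]."""
    target = reward + gamma * Q[next_obs].max()
    Q[obs, action] += alpha * (target - Q[obs, action])

# The "representation" of input obs is the row Q[obs]. Feeding the same obs over
# and over returns the same row, so each of the 4 rows is its own input-conditioned
# fixed point; and sampling noise makes it unlikely that, say, Q[A] and Q[b] ever
# coincide exactly, so we should expect 4 distinct points rather than 2.
```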

You can imagine that even in the tabular case, if we factorize Q = ZW for some 4xK representation Z, the rows of Z have even more freedom to differ in K-dimensional space: the readout W can bring the representations closer together in Q space, but Z itself is even less constrained.
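For instance (random W and Z purely to illustrate the shapes and degrees of freedom):

```python
import numpy as np

rng = np.random.default_rng(0)
K = 3
Z = rng.normal(size=(4, K))  # one K-dim representation per input
W = rng.normal(size=(K, 2))  # readout from representation to Q values
Q = Z @ W                    # 4x2, same shape as the tabular case

# Two rows of Z can map to (nearly) the same row of Q whenever their difference
# lies (close to) in the null space of W, so matching Q values doesn't force the
# underlying representations to collapse onto shared fixed points.
```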

mobeets commented 1 year ago

To wrap this up, note that in the belief paper, I only ever looked at the number of fixed points conditioned on NO inputs. There I saw a correspondence between the number of fixed points and the task. But here, because we're working at the trial level, I can only look at input-conditioned fixed points. And I think what I'm seeing here is that the RNN has no way to realize it could share fixed points across different inputs. Those different inputs effectively create different dynamics, and so there's no reason for them to share fixed points. Whereas, when there are zero inputs, we have to share the same dynamical system.

So one prediction is that, if we inserted some zeros in between trials, we might see that there are two fixed points.
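Concretely, the test would be something like this (the per-trial input layout is a guess at how the sequences are built, not how the repo actually batches them):

```python
import torch

def pad_with_blank_steps(inputs, n_blank=3):
    """Insert n_blank all-zero input steps after every trial's one-hot input.
    inputs: (n_trials, input_dim) tensor; returns (n_trials * (1 + n_blank), input_dim)."""
    n_trials, input_dim = inputs.shape
    blanks = torch.zeros(n_trials, n_blank, input_dim)
    padded = torch.cat([inputs.unsqueeze(1), blanks], dim=1)
    return padded.reshape(n_trials * (1 + n_blank), input_dim)

# then look for fixed points conditioned on the all-zero input, rather than on A/a/B/b
```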