mobeets / q-rnn


Why does Beron2022 have 4 fixed points? #10

Closed by mobeets 1 year ago

mobeets commented 1 year ago

Beliefs have two fixed points. The RNNs consistently have four. For example, look at this H=10 network:

So the RNN has four fixed points, even though its readout, Q, maps these fixed points onto essentially the same outputs. Is it somehow easier for the RNN to keep these fixed points separate? Possibly because the fixed points involve inferring the absence of reward inputs?
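For reference, here's roughly how one can find these input-conditional fixed points numerically. This is a sketch, not the repo's actual code: it assumes the RNN is a `torch.nn.GRUCell` with two scalar inputs (previous reward and previous action), and it follows the usual "speed minimization" recipe (Sussillo & Barak 2013).

```python
import torch

def find_fixed_points(cell, x, n_inits=50, n_steps=2000, lr=0.01, tol=1e-6):
    """Find approximate fixed points h* with h* ~ cell(x, h*) for a fixed
    input x, by minimizing the speed q(h) = ||cell(x, h) - h||^2 from many
    random initial states."""
    H = cell.hidden_size
    h = torch.randn(n_inits, H, requires_grad=True)
    X = x.unsqueeze(0).expand(n_inits, -1)
    opt = torch.optim.Adam([h], lr=lr)
    for _ in range(n_steps):
        opt.zero_grad()
        speed = ((cell(X, h) - h) ** 2).sum(dim=1)
        speed.sum().backward()
        opt.step()
    with torch.no_grad():
        speed = ((cell(X, h) - h) ** 2).sum(dim=1)
        # note: converged points are not deduplicated here; near-duplicates
        # should be merged when counting distinct fixed points
        return h[speed < tol].detach()

# e.g., for a stand-in H=10 GRU (not the trained network from the repo),
# find fixed points conditioned on zero input:
cell = torch.nn.GRUCell(input_size=2, hidden_size=10)
print(find_fixed_points(cell, torch.zeros(2)).shape)
```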

mobeets commented 1 year ago

I trained an H=2 network with an L2 penalty (weight 0.01) and still see 3-4 fixed points.

If anything, this is even more unexpected to me: penalizing the L2 norm of Z seems to compress only the Q readout (even more than before), not the Z activity itself. Which makes me think the extra fixed points are somehow important, or at least strongly encouraged.
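To be concrete about where the penalty enters, here is a minimal sketch with a stand-in model and stand-in targets (not the repo's training loop): the loss is the usual squared error on Q plus an L2 term on the hidden activity Z.

```python
import torch

# Stand-in H=2 GRU with a linear Q readout, to show where the penalty enters.
cell = torch.nn.GRUCell(input_size=2, hidden_size=2)
readout = torch.nn.Linear(2, 2)          # Z -> Q for the two actions
l2_weight = 0.01

x = torch.randn(16, 2)                   # batch of (previous reward, previous action)
z = cell(x, torch.zeros(16, 2))          # hidden activity Z
q = readout(z)                           # Q values
q_target = torch.randn(16, 2)            # stand-in for the TD targets
loss = ((q - q_target) ** 2).mean() + l2_weight * (z ** 2).mean()
loss.backward()
```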

mobeets commented 1 year ago

Here's one line of thought: let A be a rewarded 'a=1' choice, b an unrewarded 'a=0' choice, etc. Then the inputs to the RNN are: A = (r=1, a=1), a = (r=0, a=1), B = (r=1, a=0), b = (r=0, a=0).

Then we want A and b to have the same fixed points, but their inputs are exactly opposite. We also want B and a to have the same fixed points, and again their inputs are exactly opposite. If the inputs were shared, the fixed point could be determined by that shared input. But since the inputs are totally unlinked, this might be harder?

This reminds me of something that came up in Nora's project: in the backward blocking task, you won't update the cue B representation when only cue A is shown. Similarly, the input representations for "A" and "b" are non-overlapping, so you won't update your representation for "b" when "A" is shown, and vice versa.

But the recurrent weights are what set the fixed points, and those weights are shared regardless of the inputs.
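Concretely, the four trial types correspond to four distinct input vectors, and we can search for fixed points conditioned on each one. A sketch, reusing `cell` and `find_fixed_points` from the first comment, and assuming the input vector is ordered (previous reward, previous action); the repo may order them differently.

```python
import torch

# The four trial types as input vectors, assuming the ordering (r, a):
inputs = {
    'A': torch.tensor([1., 1.]),  # rewarded,   a=1
    'a': torch.tensor([0., 1.]),  # unrewarded, a=1
    'B': torch.tensor([1., 0.]),  # rewarded,   a=0
    'b': torch.tensor([0., 0.]),  # unrewarded, a=0
}
# A and b are componentwise opposites, as are B and a, even though we want
# each pair to end up at (roughly) the same fixed point.

# Reusing `cell` and `find_fixed_points` from the sketch in the first comment:
for name, x in inputs.items():
    fps = find_fixed_points(cell, x)
    print(name, 'converged points:', fps.shape[0])
```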

mobeets commented 1 year ago

Here's another way of thinking about it. The model has two inputs: the previous reward and the previous action. Suppose we have an "A" trial, i.e., (r=1, a=1), and that we're at the A fixed point. Now consider what happens if the input changes. If r changes, that's because the environment switched state (we'll assume p_rew_max=1.0 for simplicity); we'd still take the same action, so we'd get an "a" trial. Whereas if a changes, that's because our action changed (e.g., due to ε exploration), so that's a "b" trial.

What I'm wondering is whether the relative frequency of these transitions during training influences the decay rates around our fixed points, e.g., if A -> b transitions are more frequent than A -> a transitions.
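One way to quantify those decay rates would be to linearize the RNN around each fixed point and look at the Jacobian's eigenvalue magnitudes (values near 1 mean slow decay along that direction). A sketch, again assuming a GRUCell and a fixed point `h_star` found as above:

```python
import torch

def local_decay(cell, x, h_star):
    """Eigenvalue magnitudes of the Jacobian of h -> cell(x, h) at h_star;
    magnitudes near 1 mean perturbations decay slowly."""
    f = lambda h: cell(x.unsqueeze(0), h.unsqueeze(0)).squeeze(0)
    J = torch.autograd.functional.jacobian(f, h_star)
    return torch.linalg.eigvals(J).abs()

# e.g., with a stand-in cell and a stand-in point (use a real fixed point):
cell = torch.nn.GRUCell(input_size=2, hidden_size=2)
print(local_decay(cell, torch.zeros(2), torch.zeros(2)))
```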

mobeets commented 1 year ago

FWIW, this task is simple enough that you can use a purely random behavior policy during training and still get the same results. Below is an H=2 RNN where, during training, the actions were selected uniformly at random (recall that Q-learning is off-policy). This tells us that the four fixed points are not a consequence of the behavioral policy used during training.
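To spell out the off-policy point with a toy example (this is not the Beron2022 task or the repo's code): the behavior policy below is uniformly random, but the Q-learning target still takes the max over next actions, so the learned values correspond to the greedy policy.

```python
import numpy as np

rng = np.random.default_rng(0)
n_states, n_actions = 2, 2
alpha, gamma = 0.1, 0.9
Q = np.zeros((n_states, n_actions))

# toy MDP: action 0 keeps you in your state, action 1 switches states;
# reward 1 only for taking action 0 in state 0
def step(s, a):
    r = 1.0 if (s == 0 and a == 0) else 0.0
    s_next = s if a == 0 else 1 - s
    return r, s_next

s = 0
for _ in range(20000):
    a = int(rng.integers(n_actions))          # purely random behavior policy
    r, s_next = step(s, a)
    # Q-learning target uses max over next actions (greedy), not the behavior
    Q[s, a] += alpha * (r + gamma * Q[s_next].max() - Q[s, a])
    s = s_next

print(np.round(Q, 2))   # greedy policy from Q: stay in / return to state 0
```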

mobeets commented 1 year ago

It could have to do with the statistics of the inputs, though. For example, transitioning from A->B seems like it must always be rarer than transitioning from A->a, or even A->b, basically no matter what?
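One way to check would be to simulate a session and count trial-type transitions. The parameters below (p_switch, p_rew_max) and the random behavior policy are assumptions for illustration, not the values from the paper or the repo.

```python
import numpy as np

rng = np.random.default_rng(0)
p_switch, p_rew_max = 0.02, 0.8      # assumed values, not the paper's
n_trials = 200_000
LABEL = {(1, 1): 'A', (1, 0): 'a', (0, 1): 'B', (0, 0): 'b'}  # (a, r) -> trial type

state = 1                            # which action is currently the "good" one
labels = []
for _ in range(n_trials):
    if rng.random() < p_switch:
        state = 1 - state
    a = int(rng.integers(2))         # random behavior policy, as above
    r = int(rng.random() < (p_rew_max if a == state else 1 - p_rew_max))
    labels.append(LABEL[(a, r)])

counts = {}
for prev, cur in zip(labels[:-1], labels[1:]):
    counts[prev + '->' + cur] = counts.get(prev + '->' + cur, 0) + 1
for k in ('A->a', 'A->b', 'A->B'):
    print(k, counts.get(k, 0))
```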

mobeets commented 1 year ago

By the way, I think the takeaway here is that the RNN defaults to having a fixed point for each unique input, even if fewer fixed points would do.

But this differs from my approach in the belief paper, where I only ever looked at fixed points conditional on NO inputs.
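To make that distinction concrete (this is just my notation, with $f$ for the RNN state update): the four fixed points above are input-conditional, whereas the belief-paper analysis looked for fixed points under zero input.

$$h^*_x = f(h^*_x,\, x) \ \ \text{for each } x \in \{A, a, B, b\} \qquad \text{vs.} \qquad h^*_0 = f(h^*_0,\, \mathbf{0})$$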