mobeets / sarsa-rnn


reversal learning #13

Open mobeets opened 2 years ago

mobeets commented 2 years ago

The idea is a task with two actions (L vs. R), where each action has a different probability of reward (p1=0.8 and p2=0.2), and those probabilities flip every 5 trials.
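For concreteness, here's a minimal sketch of that environment; the class name and interface are mine, not something already in this repo:

```python
import numpy as np

class ReversalBandit:
    """Two-armed bandit whose reward probabilities flip every `block_length` trials."""

    def __init__(self, p_high=0.8, p_low=0.2, block_length=5, seed=None):
        self.p_high = p_high              # reward prob. of the currently "good" action (p1)
        self.p_low = p_low                # reward prob. of the other action (p2)
        self.block_length = block_length  # probabilities flip every 5 trials
        self.rng = np.random.default_rng(seed)
        self.trial = 0
        self.good_action = 0              # 0 = L, 1 = R

    def step(self, action):
        """Sample a reward for the chosen action, then flip at block boundaries."""
        p = self.p_high if action == self.good_action else self.p_low
        reward = int(self.rng.random() < p)
        self.trial += 1
        if self.trial % self.block_length == 0:
            self.good_action = 1 - self.good_action  # reversal
        return reward
```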

This is almost the same as the task in Schaeffer et al. (2020) (see notes here), except that there you also have to do a perceptual discrimination on each trial. The point is, they train on that task not with RL but with supervised learning, which we could also do here. Either way, it seems clear that an RNN trained to do this will "discover" how to do belief estimation, for things like "did the reward probabilities flip?"
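For reference, here's what that belief estimation looks like when done explicitly, as a Bayesian update each trial on "which action is currently the good one." Treating reversals via a constant per-trial hazard rate is my approximation; the actual task flips deterministically every 5 trials:

```python
def update_belief(b, action, reward, p_high=0.8, p_low=0.2, hazard=0.2):
    """One Bayesian update of b = P(action 0 is currently the good one).

    Models reversals as occurring with probability `hazard` per trial
    (an approximation to the deterministic every-5-trials flips).
    """
    # likelihood of the observed reward if the chosen action were good vs. bad
    p_if_good = p_high if reward else 1 - p_high
    p_if_bad = p_low if reward else 1 - p_low
    if action == 0:
        like0, like1 = p_if_good, p_if_bad
    else:
        like0, like1 = p_if_bad, p_if_good
    b = b * like0 / (b * like0 + (1 - b) * like1)  # Bayes' rule on the latent state
    return (1 - hazard) * b + hazard * (1 - b)     # account for a possible flip
```

A trained RNN wouldn't compute this explicitly, of course; the claim is that its hidden state ends up encoding something equivalent to `b`.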

So again, one takeaway is: even though this is a "model-free" method, we can still get "hierarchical RL" or "hierarchical inference" in these tasks. Thus, we can't necessarily conclude that "model-based" behavior means having explicit probabilities: networks can still be trained in model-free ways, and the "model-based-like" representations will emerge.

The goal of all of this is to form some understanding of what OFC is doing, where the claim in Wilson2014 is that it does belief estimation, but only over the variables that matter for the task. I'll call this a "compressed belief representation." This model-free RNN approach would be one way to implement that compressed belief estimation.

mobeets commented 2 years ago

The issue, though, is to recapitulate the lesioning results. In a reversal learning task, if you lesion OFC, it looks like the network can no longer switch between the two modes; i.e., it loses the ability to do this slower-timescale estimation of "which block are we in?" In the Schaeffer2020 paper, their network does this using a line attractor. So lesioning OFC would be like somehow selectively lesioning the line attractor.
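One crude way to operationalize "selectively lesioning the line attractor" (my sketch, not a procedure from Schaeffer2020) would be to project the hidden state off the attractor's axis at every timestep, so the network can't integrate block evidence while the orthogonal fast dynamics stay intact:

```python
import numpy as np

def lesion_line_attractor(h, axis):
    """Remove the component of hidden state h along the putative line attractor.

    `axis` is a unit vector along the attractor (e.g., estimated as the top PC
    of slow-timescale activity). Applying this every timestep blocks the
    slow "which block are we in?" integration without touching the rest.
    This is a hypothetical lesion procedure, not one from the paper.
    """
    axis = axis / np.linalg.norm(axis)
    return h - np.dot(h, axis) * axis
```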

Similarly, in my Value RNN, making a Task 2 network look like a Task 1 network would be the same as causing an anti-pitchfork bifurcation, where you raise the α parameter to give it three fixed points rather than one.
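For reference, the fixed-point structure here is that of the pitchfork normal form dx/dt = αx − x³: raising α through zero takes the system from one stable fixed point to three. A toy sketch of that normal form (not the actual Value RNN dynamics):

```python
import numpy as np

def pitchfork_fixed_points(alpha):
    """Fixed points of the pitchfork normal form dx/dt = alpha*x - x**3.

    alpha <= 0: one fixed point, x = 0 (stable).
    alpha > 0:  three fixed points, x = 0 (now unstable) and
                x = +/- sqrt(alpha) (stable).
    """
    if alpha <= 0:
        return [0.0]
    return [0.0, np.sqrt(alpha), -np.sqrt(alpha)]
```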

How to actually do this to the RNN is an open question: you'd essentially have to figure out which network parameters to change to cause the appropriate bifurcation.
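One standard starting point (cf. the fixed-point-finding approach of Sussillo & Barak 2013) would be to numerically locate the RNN's fixed points before and after a candidate parameter change, and check whether the count changes. A sketch for a vanilla tanh RNN, h → tanh(Wh + b):

```python
import numpy as np
from scipy.optimize import minimize

def find_fixed_points(W, b, n_inits=50, tol=1e-6, seed=0):
    """Numerically locate fixed points of the update h -> tanh(W @ h + b).

    Minimizes the speed q(h) = ||tanh(W h + b) - h||^2 from many random
    initial conditions, then deduplicates nearby solutions. Re-running
    this after perturbing W tells you whether the perturbation changed
    the number of fixed points, i.e., caused a bifurcation.
    """
    rng = np.random.default_rng(seed)
    q = lambda h: np.sum((np.tanh(W @ h + b) - h) ** 2)
    found = []
    for _ in range(n_inits):
        res = minimize(q, rng.standard_normal(len(b)), method="BFGS")
        is_new = not any(np.allclose(res.x, f, atol=1e-3) for f in found)
        if res.fun < tol and is_new:
            found.append(res.x)
    return found
```

The harder part, which this doesn't solve, is searching over *which* parameter direction to perturb so that the fixed-point count changes in the desired way.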