infer-actively / pymdp

A Python implementation of active inference for Markov Decision Processes
MIT License

State inference: Considerable performance difference with MATLAB implementation? #82

Open SamGijsen opened 2 years ago

SamGijsen commented 2 years ago

Hi all,

First off, many thanks for your work; it looks very promising!

I was comparing some simple agents between the pymdp and MATLAB implementations and noticed that, in terms of reversal learning, the pymdp version appears to adapt considerably more slowly. I've played around with a variety of setups and hyperparameters, but the difference is quite significant.

Example setup: a slot machine task without hints. Two actions: 'button' 0 and 'button' 1. Two contexts: a 'button 0 is better' context and a 'button 1 is better' context.

40 trials, with the hidden context switching after 20 trials. Here I plot the state posterior (black) of 100 agents starting with flat context priors, compared to the true but unknown state/context (red). Below I'll include the pymdp code. I'm assuming I'm using the package wrong, and would love to know my misunderstanding.

[Figures: posterior over context (black, 100 agents) vs. true context (red), for the pymdp and MATLAB versions]

import pymdp
from pymdp import utils
from pymdp.agent import Agent
import numpy as np
import matplotlib.pyplot as plt

num_obs = [3, 3] # 3 Rewards, 3 choice observations
num_states = [3, 2] # 3 choice states, 2 hidden states
num_factors = len(num_states)

# Press one of two buttons
num_controls = [2, 1] 

A_shapes = [[o_dim] + num_states for o_dim in num_obs]

# initialize the A array to all 0's
A = utils.obj_array_zeros(A_shapes)

# reward probabilities
r1=0.9
r2=0.9

# Reward observations
# STATE 1     Start a0 a1
A[0][0,:,0] = [0, 1-r1, r2  ] # Positive reward
A[0][1,:,0] = [0, r1  , 1-r2] # Negative reward
A[0][2,:,0] = [1, 0   , 0   ] # Neutral (start state)
# STATE 2     Start a0 a1
A[0][0,:,1] = [0, r1  , 1-r2] # Positive
A[0][1,:,1] = [0, 1-r1, r2  ] # Negative
A[0][2,:,1] = [1, 0   , 0   ] # Neutral (start state)

# No uncertainty about choice observations
A[1][:,:,0] = np.eye(num_obs[1])
A[1][:,:,1] = np.eye(num_obs[1])

B_shapes = [[s_dim, s_dim, num_controls[f]] for f, s_dim in enumerate(num_states)]

B = utils.obj_array_zeros(B_shapes)

for i in range(2):
    B[0][0,:,i] = np.ones(3)

B[0][:,0,0] = [0, 1, 0] # action 0: Start  -> a0
B[0][:,0,1] = [0, 0, 1]  # action 1: Start  -> a1

B[1][:,:,0] = np.eye(num_states[1])

C = utils.obj_array_zeros(num_obs)
C[0] = np.array([1, -1, 0]) # Prefer rewards

D = utils.obj_array_uniform(num_states)
D[0] = np.array([1, 0, 0]) # Start in the 'start'-state

# ENVIRONMENT
class my_env():

    def __init__(self, A):

        self.A = A

    def step(self, action, state):

        obs = utils.sample(self.A[0][:, action[0].astype(int)+1, state])

        return [obs, action[0].astype(int)+1]

# SIMULATIONS
T = 40
alpha = 16
gamma = 16
AS = "deterministic"

for run in range(100):
    D2 = D.copy()

    model = Agent(A=A, B=B, C=C, D=D, policy_len=1, action_selection=AS)
    switches = [20,40,50]
    state = 0
    states = []
    pstate = []
    pact = []
    e = my_env(A)

    model.infer_states([2,0]) # 2 = neutral obs, 0 = start state

    for t in range(T):
#         if t > 0: 
#             D2[1] = model.qs[1]
#             model = Agent(A=A, B=B, C=C, D=D2, policy_len=1, action_selection=AS)

        if t in switches:
            state = 1 - state
        states.append(state)

        # Start position for the trial (I believe you don't use this in the tutorial, but it doesn't seem to matter much)
        model.infer_states([2,0]) # 2 = neutral reward, 0 = start state observation

        q_pi, neg_efe = model.infer_policies()

        action = model.sample_action()

        obs = e.step(action, state=state)
        model.infer_states(obs)

        # Save belief and output
        pstate.append(model.qs[1][0])
        pact.append(q_pi[0])

    plt.plot([1-s for s in pstate], label="p(s)", linewidth=3, alpha=0.1, color='k')

plt.plot([s*1.1-0.05 for s in states], label="s", color="r", linewidth=3)
plt.xlabel("trial")
plt.ylim([-0.05, 1.05])
plt.title("Python")
conorheins commented 2 years ago

Hi @SamGijsen,

Thanks for your interest and the nice code snippet you shared -- I'm glad to see people are using pymdp in new ways like this. I don't think I've seen a reversal learning implementation in pymdp so far.

There are a few issues with the code you've shared that immediately pop out to me, and are almost certainly the cause of the 'stubborn', suboptimal behaviour you're observing in your agents. I'll describe them in order of decreasing importance:

  1. "Learning across trials" in the sense of active inference as implemented in MATLAB/DEM, typically refers to the accumulation of Dirichlet hyperparameter counts across trials (which I believe is the inner loop over T in your code). This means that, when you instantiate your agent (model), you want to pass in a pD argument to your model constructor line. This pD is a prior vector of Dirichlet parameters. It is a conjugate prior over the initial state prior (the D vector). You can initialize this to be something small and relatively 'unconfident' in the flatness of the state prior, e.g. [0.1 0.1]. Then, within each trial (after the second call to model.infer_states()), you can use the function model.update_D(qs_t0 = model. qs) to update the pD hyperparameters and the initial state prior D using the posterior over hidden states -- the output of model.infer_states(). Finally, at the beginning of each next loop, you should use model.reset() to reset the trial-relative internal clock of the agent to t= 0, and to re-initialize the posterior to a flat distribution over hidden states. The model's D vector should however reflect the expected value of pD (model.D = utils.norm_dist_obj_arr(model.pD)), which is being accumulated across trials. Does that make sense?
  2. The line where you "reset" the agent's perception at the beginning of each trial (model.infer_states([2,0]) # 2 = neutral reward, 0 = start state observation) is also prone to issues if you don't reset() the agent. This is because variational inference in this simple, VANILLA inference scheme (which is the default unless you specify marginal message-passing) ends up being a "Bayesian average" of a forwards message from the past (the past posterior, integrated with the transition dynamics or B array and the previous action) and the likelihood (the observation and the sensory likelihood or A array). This means that the posterior over states may not be peaked around being in the "center" (unchosen) state, because it may be averaged with the past posterior. This may or may not be the case, depending on how your B and A matrices look. But I just flag it because, in theory, at the beginning of each trial the agent should be "wiped clean" in terms of its beliefs and internal clock. The only thing that should be accumulated/correlated over trials is the ever-accumulating pD vector.
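To make this concrete, here is a minimal sketch of that across-trial loop, re-using the A/B/C/D and environment from your snippet above (num_trials, env, and the initial pD counts are illustrative placeholders; factors_to_learn is my best recollection of the constructor argument that restricts Dirichlet learning to the context factor, so please double-check it):

# Minimal sketch of the across-trial pD-learning loop described above.
# num_trials, env and the initial pD counts are illustrative placeholders.
pD = D.copy()
pD[0] = np.array([1., 0., 0.])   # start-state factor: known with certainty
pD[1] = np.array([0.1, 0.1])     # weak, flat Dirichlet counts over the hidden context

model = Agent(A=A, B=B, C=C, D=D, pD=pD, policy_len=1,
              action_selection="deterministic",
              factors_to_learn=[1])              # only learn D for the context factor

for trial in range(num_trials):
    model.reset()                                # wipe beliefs and the trial-relative clock
    model.D = utils.norm_dist_obj_arr(model.pD)  # D = expected value of the accumulated pD

    model.infer_states([2, 0])                   # neutral reward, start-state observation
    model.infer_policies()
    action = model.sample_action()
    obs = env.step(action, state=state)          # task-specific environment step
    model.infer_states(obs)

    model.update_D(qs_t0=model.qs)               # accumulate Dirichlet counts in pD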

I hope this helps! Cheers, Conor

SamGijsen commented 2 years ago

Thank you for your quick and elaborate reply @conorheins - that's very much appreciated.

  1. The reason I initially left out 'prior state learning'/pD is that I didn't see a package-native way to allow for flexible reversal learning this way. Implementing my interpretation of your suggestion in pseudo-code as follows:
    
    pD = D.copy()
    pD[1] = pD[1]*0.1
    model = Agent(A=A, B=B, C=C, D=D, pD=pD, policy_len=1, action_selection=AS)

    for trial in range(number_of_trials):
        model.reset()
        model.D = utils.norm_dist_obj_arr(model.pD)
        ...
        model.update_D(qs_t0=model.qs)


ultimately yields a pD distribution of something akin to [α1, α2] ≈ [19, 2] after the initial 20 trials. However, after 20 more trials in the other state, we approach a [20,20] distribution. Thus, it takes many trials before the expectation of pD starts to align with the true but unknown hidden state. Please let me know if I misunderstood the suggestion in any way. Alternatively, I'm guessing some additional learning algorithm, or even simple forgetting on the pD parameters to keep precision relatively low (sketched after this list), should do reasonably well in combination with the setup described above.

2. Thank you for the explanation. The agent dynamics are indeed slightly more sensible when 'resetting' the agent at the start of each trial, but as per the above this unfortunately doesn't address the sluggish reversal learning. I've tried to run the MMP scheme instead, but this errors on `infer_policies()` (`AttributeError: 'NoneType' object has no attribute 'dtype'`); I haven't yet figured out what exactly happens here.
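Returning to the forgetting idea in point 1: a rough sketch of what such a decay on the pD counts could look like (the decay factor omega and the manual manipulation of model.pD are my own assumptions; this is not a built-in pymdp mechanism):

# Hypothetical decay/forgetting on the context-factor Dirichlet counts.
# omega is an assumed forgetting factor in (0, 1]; not a pymdp-native feature.
omega = 0.8

for trial in range(number_of_trials):
    model.reset()
    model.D = utils.norm_dist_obj_arr(model.pD)
    # ... run the trial as above ...
    model.pD[1] = omega * model.pD[1]   # discount old evidence about the context
    model.update_D(qs_t0=model.qs)      # then add the new trial's counts

With multiplicative decay the effective count scale saturates around 1/(1 - omega) (5 for omega = 0.8), so a handful of consistent post-reversal trials would be enough to flip the expectation of pD.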

Offtopic: running "stochastic" action selection errors for me when using the current version. I believe it's because actions are sampled even if num_controls=1, in which case utils.sample() attempts to squeeze an array of size 1. Reverting to deterministic selection if `num_controls==1` has fixed this for me. 
(control.sample_action() line 565 `if action_selection == 'deterministic' or num_controls[factor_i]==1:`) Happy to submit a PR if that's useful, otherwise I believe it's just this line of code.
conorheins commented 1 year ago

Hi @SamGijsen, sure thing -- happy to help out, especially if it leads to a possible improvement of pymdp :) Apologies in advance for my delay in responding; working on pymdp is not my official job, so I will only have time to respond to issues/discussion points when I find time here and there.

However, after 20 more trials in the other state, we approach a [20,20] distribution. Thus, it takes many trials before the expectation of pD starts to align with the true but unknown hidden state.

This is an astute observation -- basically, you need to "dig yourself" out of the [19 2] prior belief about D and pass through [20, 20] before you can reversal-learn and end up with an accurate initial state prior that reflects the reversed hidden state of the world (something like [2 19]). This makes me curious, though, how the reversal learning avoids this sort of sluggishness in the DEM/MATLAB code. Could you perhaps share your MATLAB code as well, and indicate the version of spm_MDP_VB_X.m that you're using? Last I checked, there was no equivalent of the additional learning algorithm / pD "forgetting" you propose in the original MATLAB spm_MDP_VB_X.m, but maybe my memory is not serving me right. In that case, I would expect the MATLAB version to exhibit the same sort of sluggishness. I'm assuming, as with the pymdp version of your code, policy-selection is 1-step ahead (size(V,1) == 1), each trial lasts one timestep (T = 1), and the learning rates/initial state priors over d are identical to how you've initialized them in pymdp?
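To put rough numbers on the "digging out", here is a crude back-of-the-envelope calculation (assuming, for illustration only, that each post-reversal trial adds about one full count to the newly correct context):

# Crude illustration of how slowly the expected context prior recovers after a reversal,
# assuming ~1 Dirichlet count per trial is added to the newly correct context.
pD_context = [19.0, 2.0]                  # counts after 20 pre-reversal trials
for trial in range(20):                   # 20 trials after the reversal
    pD_context[1] += 1.0
print(pD_context[1] / sum(pD_context))    # ~0.54: barely past 'flat' after 20 trials

So even under generous assumptions, the expected value of D only just crosses 0.5 by the end of the second block, which is consistent with the sluggish curves you're seeing.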

Side-note: I know there is this paper by @AnnaCSales et al. where they implement a sort of 'decay' rate on learning the Dirichlet hyperparameters (see the equation for updating d on page 8 of the paper, and Anna Sales' accordingly modified version of spm_MDP_VB_X.m here; line 281 looks like where that decay rate is implemented for d learning), but I'm pretty sure the 'canonical' version of spm_MDP_VB_X.m doesn't implement this sort of flexibility in learning hyperparameters.

I've tried to run the MMP scheme instead, but this errors on infer_policies() (AttributeError: 'NoneType' object has no attribute 'dtype') but I haven't figured out yet what exactly happens here.

Using MMP shouldn't really be necessary here, unless I'm making incorrect assumptions about the temporal structure of each trial in your task. If each trial lasts more than 1 timestep (T > 1, see above), then maybe the use of MMP (and the resulting "pre- and post-dictive beliefs" you get from that type of posterior) is actually the source of the discrepancy. If each trial does last multiple timesteps and the agent is using MMP for inference, then the posterior belief over hidden states used to update D will be different from the one used in the 1-timestep-per-trial case, where you'd use vanilla-style inference. Other than that, the error you're encountering with MMP is interesting; I don't know why that's happening off the top of my head, but it may be due to your policy_depth or backwards_horizon parameters being too short. I will add an issue to create better error messages for when you're trying to use MMP with an inappropriately-parameterized Agent() instance.
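For reference, switching the inference scheme looks roughly like this (inference_algo and inference_horizon are the constructor arguments as I remember them; please double-check against the current Agent signature):

# Sketch of constructing an Agent that uses marginal message passing (MMP)
# instead of the default VANILLA scheme. Argument names are my best recollection.
model = Agent(A=A, B=B, C=C, D=D,
              policy_len=1,
              inference_algo="MMP",
              inference_horizon=2)   # how many past timesteps the posterior is smoothed over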

Offtopic: running "stochastic" action selection errors for me when using the current version. I believe it's because actions are sampled even if num_controls=1, in which case utils.sample() attempts to squeeze an array of size 1. Reverting to deterministic selection if num_controls==1 has fixed this for me. (control.sample_action() line 565 if action_selection == 'deterministic' or num_controls[factor_i]==1:) Happy to submit a PR if that's useful, otherwise I believe it's just this line of code.

Thanks a lot for noticing this! Can you open an issue describing this (you can reference your comment in this thread as well)?

SamGijsen commented 1 year ago

Apologies in advance for my delay in responding; working on pymdp is not my official job, so I will only have time to respond to issues/discussion points when I find time here and there.

No worries, that's very understandable; I'm thankful for the comments. :)

I'm assuming, as with the pymdp version of your code, policy-selection is 1-step ahead (size(V,1) == 1), each trial lasts one timestep (T = 1), and the learning rates/initial state priors over d are identical to how you've initialized them in pymdp?

Ah, this might be the point of discrepancy, as you suggest. My understanding is that modeling the current task in the MATLAB toolbox requires T=2, because the policy needs to transition from one state into another, requiring a "start state" and an "ending state" (modeled as A[1] in the opening post). The policy depth would indeed be one. This should then indeed give rise to pre- and post-dictive beliefs. My approach should be similar to this repo from this paper: basically no learning, only state inference (I'm using 'learning' here in the sense of accruing Dirichlet parameters). I'll share some example MATLAB code illustrating my understanding below. However, after taking a look at the paper you linked (thank you for this!), I must say its method of flexibly adapting the 'd'-parameters seems a much more appropriate way to deal with reversal learning (rather than pure state inference), and is what I was hinting at previously. I'm going to have a look at how such a model behaves.

Thanks a lot for noticing this! Can you open an issue describing this (you can reference your comment in this thread as well)?

Sure, will do!

MATLAB code

Trial structure with a for-loop:

mdp = {};
num_trials = 40;
z = [1:20];
pstate = [];
true_state = [];

for i = 1:num_trials

    MDP = simple_model(0.9, 0.9);

    if i>1
        MDP.D{2} = X;
    end

    if sum(z==i) == 1
        [MDP.s] = [1 2]';
    else % switch state
        [MDP.s] = [1 1]';
    end

    MDP_p = spm_MDP_VB_X(MDP);
    mdp{i} = MDP_p;
    X = MDP_p.X{2}(:,end);

    true_state = [ true_state MDP.s(2) ];
    pstate = [ pstate mdp{i}.X{2}(2,2) ];
end

plot(pstate, 'k'); hold on
plot(true_state-1, 'r')

SPM model:

function mdp = simple_model(r1, r2)

D{1} = [1 0 0]'; % Choice State {start, choose B1, choose B2}
D{2} = [1 1]'; % Context (B1 better, B2 better)

% P(o|s)
Nf = numel(D);

for f = 1:Nf
    Ns(f) = numel(D{f});
end

No = [3 3]; % 3 rewards [Pos, Neg, Neutral], 3 choice observations

Ng = numel(No);

for g = 1:Ng
    A{g} = zeros([No(g),Ns]); 
end

               %S, b1, b2
A{1}(:,:,1) = [0, 1-r1, r2; % Pos reward
               0, r1, 1-r2; % neg reward
               1, 0, 0];% Neutral reward

A{1}(:,:,2) = [0, r1, 1-r2; % Pos reward
               0, 1-r1, r2; % neg reward
               1, 0, 0];% Neutral reward

A{2}(:,:,1) = eye(3);
A{2}(:,:,2) = eye(3);

a{1} = A{1};
a{1}(1:2,:,:) = a{1}(1:2,:,:)*2;
a{1}(3,:,:) = a{1}(3,:,:)*10000;
a{2} = A{2}*10000;

% B{state factor}(state at time tau+1, state at time tau, action number) 
for f = 1:Nf
        B{f} = zeros(Ns(f));
end

% action 1: move from start to button 1
% a2: start to b2
for c = 1:2
    B{1}(1,:,c) = ones(1,3);
end

B{1}(:,1,1) = [0 1 0]'; % start -> b1
B{1}(:,1,2) = [0 0 1]'; % start -> b2

B{2} = eye(2);

% Rewards
      %t1 2 
C{1} = [0 2.5; % pos
        0 -2.5; % neg
        0 0 ];% neu

C{2} = zeros(3);

% press b1 or b2
U(:,:,1) = [1 2]; 
U(:,:,2) = [1 1];

mdp.A = A;
%mdp.a = a; % likelihood learning on/off
mdp.B = B;
mdp.C = C;
mdp.D = D;
mdp.U = U;
mdp.T = 2;
mdp.Aname = ["RewardsA", "ChoicesA"];
mdp.Bname = ["ChoicesB", "Context"];
mdp.label.modality = {"Choice State"};
end