infer-actively / pymdp

A Python implementation of active inference for Markov Decision Processes

transition matrix B form #53

Closed: mklingebiel closed this issue 2 years ago

mklingebiel commented 2 years ago

Hey there! First of all, I'm grateful for your work and your effort to make active inference more accessible for everyone. I'm quite new to this approach, so I started playing around with the provided notebooks. I was wondering why the (controllable) transition matrix B (for the Agent class inside pymdp.agent) accepts direct state interaction only, i.e. it has to be in the form s^i_{t+1} = B(s^i_t, a) and not in a form with interactions between hidden state factors, like s^i_{t+1} = B(s^i_t, s^j_t, a), where s^i and s^j refer to different hidden states. Am I wrong here? In my opinion it shouldn't be a problem, since you can always define a joint world state \tilde{s}(s^i, s^j), but that makes the model more complicated. Best, Martin
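
(To illustrate what I mean by a joint world state, here is a rough NumPy sketch with made-up sizes; it is not pymdp code. The interaction between two factors gets absorbed into one bigger, standard B[s', s, a] tensor over the joint state.)

```python
import numpy as np

# Hypothetical sizes for two interacting hidden state factors and one control factor
n_i, n_j, n_a = 3, 2, 4

# Suppose factor i's transitions depend on both factors:
#   B_i_cond[s_i', s_i, s_j, a] = P(s_i' | s_i, s_j, a)
B_i_cond = np.random.rand(n_i, n_i, n_j, n_a)
B_i_cond /= B_i_cond.sum(axis=0, keepdims=True)

# Factor j transitions only on itself: B_j[s_j', s_j, a] = P(s_j' | s_j, a)
B_j = np.random.rand(n_j, n_j, n_a)
B_j /= B_j.sum(axis=0, keepdims=True)

# Collapse (s_i, s_j) into one joint "world state" factor of size n_i * n_j,
# so the interaction lives inside a single, standard B[s', s, a] tensor.
B_joint = np.zeros((n_i * n_j, n_i * n_j, n_a))
for a in range(n_a):
    for si in range(n_i):
        for sj in range(n_j):
            col = si * n_j + sj
            # joint transition = product of the two per-factor conditionals
            joint = np.outer(B_i_cond[:, si, sj, a], B_j[:, sj, a])  # shape (n_i, n_j)
            B_joint[:, col, a] = joint.ravel()

assert np.allclose(B_joint.sum(axis=0), 1.0)  # every column is a valid distribution
```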

conorheins commented 2 years ago

Dear Martin,

Thanks for your comment and your interest in the package.

> I was wondering why the (controllable) transition matrix B (for the Agent class inside pymdp.agent) accepts direct state interaction only, i.e. it has to be in the form s^i_{t+1} = B(s^i_t, a) and not in a form with interactions between hidden state factors, like s^i_{t+1} = B(s^i_t, s^j_t, a), where s^i and s^j refer to different hidden states. Am I wrong here? In my opinion it shouldn't be a problem, since you can always define a joint world state \tilde{s}(s^i, s^j), but that makes the model more complicated.

Yes, that's a good point -- as you noticed, the different hidden state factors (e.g. the factors indexed by s^i and s^j in your notation) do not interact in the generative model. The reason is that this would make the belief updating more complicated, because the hidden state factors would no longer be conditionally independent of each other. Right now, hidden state factors / control factors only 'interact' conditionally, via the observation likelihood. If we encoded interactions among hidden state factors directly in the transition likelihood, the approximate posterior (which assumes the full posterior factorizes across hidden state factors and timesteps, allowing us to store the posterior beliefs qs using the same object-array data structure that we use for e.g. the B matrices and D vectors) would lose accuracy.
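
To make that factorized storage concrete, here is a small sketch (made-up sizes; the shapes follow the per-factor conventions described above, but this is not code from pymdp itself):

```python
import numpy as np

num_states   = [3, 2]   # two hidden state factors
num_controls = [4, 1]   # one controllable factor, one uncontrollable

# One transition tensor per factor: B[f][s', s, a] = P(s'_f | s_f, a_f).
# Note each B[f] only ever indexes its own factor's states and controls.
B = np.empty(len(num_states), dtype=object)
for f, (ns, nc) in enumerate(zip(num_states, num_controls)):
    B_f = np.random.rand(ns, ns, nc)
    B[f] = B_f / B_f.sum(axis=0, keepdims=True)

# The posterior is stored the same way: one marginal per hidden state factor.
qs = np.empty(len(num_states), dtype=object)
for f, ns in enumerate(num_states):
    qs[f] = np.ones(ns) / ns
```

Because each B[f] and each qs[f] only refers to its own factor, letting B[f] condition on another factor's states would break this one-marginal-per-factor layout.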

However, now that I think about it, there actually may be a way to include these 'inter-state-factor' interactions while still maintaining a factorized posterior (with the caveat that the mean-field approximation would become even poorer than it already is). But we'd have to think about how to include these "B-matrix interactions" in the belief updating algorithms. Right now, the belief updating algorithms generally loop over the posterior marginals (one marginal per hidden state factor), incorporate the mean-field effect of all the other marginals (besides the one currently being updated) through the observation likelihood, and then combine this mean-field likelihood term with the prior for that marginal, which is a function of the transition likelihood. See for example run_fpi.py in the pymdp.inference.algos module, which explicitly implements this mean-field scheme.
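
As a rough illustration of that loop, here is a deliberately tiny two-factor, single-modality sketch in plain NumPy/SciPy (the names and simplifications are mine; it is not the code in run_fpi.py):

```python
import numpy as np
from scipy.special import softmax

def toy_fpi(log_likelihood, priors, n_iters=10):
    """Toy mean-field fixed-point iteration over two hidden state factors.

    log_likelihood : (n1, n2) array of ln P(o | s_1, s_2), already conditioned
        on the current observation o.
    priors : list of per-factor prior vectors [p1 (n1,), p2 (n2,)], e.g. each
        computed from that factor's own B matrix and previous posterior.
    """
    qs = [np.ones_like(p) / p.size for p in priors]
    for _ in range(n_iters):
        # Factor 1: average the log-likelihood under q(s_2), then add ln prior
        msg_1 = log_likelihood @ qs[1]                      # shape (n1,)
        qs[0] = softmax(msg_1 + np.log(priors[0] + 1e-16))
        # Factor 2: average the log-likelihood under q(s_1), then add ln prior
        msg_2 = log_likelihood.T @ qs[0]                    # shape (n2,)
        qs[1] = softmax(msg_2 + np.log(priors[1] + 1e-16))
    return qs
```

The key point is that the other factors enter each marginal's update only through the (observation) likelihood term, while the prior term depends on that factor's own B matrix alone.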

In the scheme you're proposing, the prior for a given marginal would also have to incorporate terms from the other marginals, since they now directly influence each other through the B matrix. Perhaps it reduces to just a few extra terms in the belief updating algorithm, corresponding to B-matrix-mediated messages from those other marginals, but my suspicion is that the violation of conditional independence in the generative model makes things a bit more complicated than that...
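
Just to sketch what one such extra term might look like (hypothetical shapes; for simplicity this averages the B tensor itself under the other factor's marginal, whereas a proper mean-field update would work with expected log-transitions):

```python
import numpy as np

# Hypothetical prior computation if factor i's transitions also depended on
# factor j, for a fixed action (a sketch of the extra contraction discussed above).
n_i, n_j = 3, 2
B_i_cond = np.random.rand(n_i, n_i, n_j)           # P(s_i' | s_i, s_j) for a fixed action
B_i_cond /= B_i_cond.sum(axis=0, keepdims=True)

qs_i_prev = np.ones(n_i) / n_i                      # q(s_i at time t)
qs_j_prev = np.ones(n_j) / n_j                      # q(s_j at time t)

# Standard case: prior_i = B_i[:, :, a] @ qs_i_prev (no dependence on factor j).
# With interactions, the prior for factor i contracts over BOTH previous marginals:
prior_i = np.einsum('kij,i,j->k', B_i_cond, qs_i_prev, qs_j_prev)
assert np.isclose(prior_i.sum(), 1.0)
```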

Anyway, the short answer is: you're correct, we don't support that right now -- and my suspicion is that it may be complicated to include. But I will think more about it and perhaps try to implement it as an additional feature.