This pull request introduces a `jax` re-implementation of most of the active inference-related functionality provided by the `pymdp` package. It also introduces new features, including three types of inference algorithms (variational message passing, marginal message passing, and online variational filtering) and the ability to simulate active inference agents in parallel using an additional batch dimension that is added as the leading axis of all parameter tensors, actions, posterior beliefs, and observations.
As an example, a typical `A` array in the `numpy` backend has an axis for observation outcomes followed by one axis per hidden state factor. With the new changes, a given `A[modality_m]` tensor will have an extra leading dimension of size `N`, an additional batch dimension that indicates the number of generative models / agents one is parallelizing active inference processes across.
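To make the shape change concrete, here is a small sketch with hypothetical modality and factor sizes (the variable names and dimensions are illustrative, not pymdp's actual configuration):

```python
import numpy as np

num_obs = [3, 5]     # observation outcomes per modality (hypothetical)
num_states = [4, 2]  # hidden states per factor (hypothetical)
N = 10               # batch dimension: number of parallel agents

# numpy backend: one likelihood tensor per modality m, with shape
# (num_obs[m], num_states[0], ..., num_states[F-1])
A_numpy = [np.zeros((no, *num_states)) for no in num_obs]

# jax backend: the same tensors with the batch dimension N prepended
A_batched = [np.zeros((N, no, *num_states)) for no in num_obs]

print(A_numpy[0].shape)    # (3, 4, 2)
print(A_batched[0].shape)  # (10, 3, 4, 2)
```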
Most importantly, the `Agent` API has been amended in the following ways:
- the `Agent` object is now an instance of an `equinox.Module`, which means agents can be treated as pytrees. Using the `Agent` class thus requires both `jax` and the `equinox` package to be included in the requirements.
- methods are now `vmap`-decorated, so that the methods of a single `Agent` can be used to simulate `N` agents in parallel. This also makes the methods much more functional, with fewer in-place operations on object properties than in the `numpy` version of `Agent`.
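The batched-method idea can be sketched with plain `jax.vmap` over a toy single-agent belief update. This is an illustrative stand-in, not pymdp's actual inference code, and it omits `equinox` for brevity:

```python
import jax
import jax.numpy as jnp

# Toy stand-in for an Agent method: a single-agent Bayesian belief
# update over one hidden-state factor (not pymdp's actual API).
def infer_states(A, prior, obs):
    # A: (num_obs, num_states) likelihood; prior: (num_states,); obs: index
    likelihood = A[obs]              # p(o | s) for the observed outcome
    posterior = likelihood * prior
    return posterior / posterior.sum()

N, num_obs, num_states = 8, 3, 4
key = jax.random.PRNGKey(0)
A = jax.random.uniform(key, (N, num_obs, num_states))  # batched likelihoods
A = A / A.sum(axis=1, keepdims=True)                   # normalize over outcomes
prior = jnp.ones((N, num_states)) / num_states
obs = jnp.zeros(N, dtype=jnp.int32)                    # one observation per agent

# vmap maps the single-agent function over the leading batch axis,
# running N agents in parallel with no in-place state mutation.
batched_infer = jax.vmap(infer_states)
qs = batched_infer(A, prior, obs)
print(qs.shape)  # (8, 4)
```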
Other features:
- allows subsets of hidden state factors to influence subsets of observation modalities in the `numpy` backend (thanks to @tverbele's fork; see this branch)
- interactions among hidden state factors in the transition dynamics in `jax`, based on the implementation introduced in @tverbele's fork (see this branch)
- preliminary ability to perform parameter estimation using `numpyro` and @dimarkov's `pybefit` package. This required adding `numpyro`, `optax`, and `arviz` to the package requirements. See the Model Inversion Notebook for a worked example of fitting the parameters of a T-Maze-navigating agent to simulated pairs of (action, observation) data. Warning: parameter estimation is still buggy and not thoroughly tested. It is currently error-prone when fitting active inference agents equipped with advanced features like learning of `A` and `B`, and we sometimes see `nan`-valued gradients when using `numpyro`'s `svi` routine, so this remains a WIP feature.
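The factorized dependency structure mentioned above can be sketched as follows; the `A_dependencies` list and all sizes here are hypothetical, chosen only to show how each modality's likelihood tensor shrinks to the factors it actually depends on:

```python
import numpy as np

num_states = [4, 2, 3]  # three hidden state factors (hypothetical sizes)
num_obs = [5, 6]        # two observation modalities (hypothetical sizes)

# Hypothetical dependency lists: modality 0 depends only on factors 0 and 2,
# modality 1 depends only on factor 1.
A_dependencies = [[0, 2], [1]]

# Each A[m] only carries axes for the factors it depends on, instead of
# all factors -- a much smaller tensor when dependencies are sparse.
A = [np.ones((num_obs[m], *[num_states[f] for f in deps]))
     for m, deps in enumerate(A_dependencies)]

print(A[0].shape)  # (5, 4, 3): obs axis + factors 0 and 2
print(A[1].shape)  # (6, 2):    obs axis + factor 1
```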