infer-actively / pymdp

A Python implementation of active inference for Markov Decision Processes
MIT License

JAX backend, sparse likelihood dependencies, sophisticated inference, inductive inference #132

Closed: conorheins closed this pull request 4 weeks ago

conorheins commented 4 weeks ago

This pull request introduces a JAX re-implementation of most of the active inference functionality provided by the pymdp package. It also adds new features, including three inference algorithms (variational message passing, marginal message passing, and online variational filtering), and the ability to simulate active inference agents in parallel. Parallelism comes from an additional batch dimension that is added as the leading axis of all parameter tensors, actions, posterior beliefs, and observations.

As an example, in the numpy backend a given A[modality_m] array has the following shape:

>>> A[modality_m].shape
(num_obs[modality_m], num_states[0], num_states[1], ..., num_states[-1])
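To make the unbatched convention concrete, here is a minimal NumPy sketch. The sizes `num_obs = [3, 2]` and `num_states = [4, 5]` and the helper `random_likelihood` are hypothetical and not from the PR. A plain Python list stands in for however the backend actually stores the per-modality tensors. Each A[m] is normalized over its first (outcome) axis, so every column is a categorical distribution over observations.

```python
import numpy as np

# Hypothetical model sizes (not taken from the PR): two observation
# modalities and two hidden-state factors.
num_obs = [3, 2]
num_states = [4, 5]

rng = np.random.default_rng(0)

def random_likelihood(n_obs, n_states, rng):
    """Hypothetical helper: random A[m] of shape (n_obs, *n_states)."""
    A_m = rng.random((n_obs, *n_states))
    # Normalize over the outcome axis (axis 0), so each column
    # A_m[:, s0, s1] is a categorical distribution over outcomes.
    return A_m / A_m.sum(axis=0, keepdims=True)

# One likelihood tensor per modality (plain list used for this sketch).
A = [random_likelihood(n_o, num_states, rng) for n_o in num_obs]

for m, A_m in enumerate(A):
    print(m, A_m.shape)
# 0 (3, 4, 5)
# 1 (2, 4, 5)
```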

With the new changes, a given A[modality_m] tensor now has shape:

>>> A[modality_m].shape
(N, num_obs[modality_m], num_states[0], num_states[1], ..., num_states[-1])

where N is the batch dimension: the number of generative models (agents) across which active inference is run in parallel.
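The following NumPy sketch shows what the leading batch axis buys. It is written under assumptions: the sizes, `N = 8`, and the helper names `normalize` and `predicted_obs` are all hypothetical. It is not the PR's JAX code. With every tensor carrying N first, one `einsum` computes each agent's predicted observation distribution (likelihood averaged under factorized beliefs) for all agents at once. The outcome axis of each A[m] moves from axis 0 to axis 1. In JAX the same per-agent semantics would typically come from `jax.vmap`. NumPy is used here only so the sketch runs without JAX.

```python
import numpy as np

N = 8                 # hypothetical batch size: number of agents
num_obs = [3, 2]      # hypothetical sizes, not from the PR
num_states = [4, 5]
rng = np.random.default_rng(0)

def normalize(x, axis):
    """Hypothetical helper: make slices along `axis` sum to 1."""
    return x / x.sum(axis=axis, keepdims=True)

# Batched likelihoods: A[m] has shape (N, num_obs[m], num_states[0], num_states[1]),
# normalized over the outcome axis, which is now axis 1.
A = [normalize(rng.random((N, n_o, *num_states)), axis=1) for n_o in num_obs]

# Batched, factorized posterior beliefs: qs[f] has shape (N, num_states[f]).
qs = [normalize(rng.random((N, n_s)), axis=1) for n_s in num_states]

def predicted_obs(A_m, qs):
    """For every agent n: sum_{i,j} A_m[n, o, i, j] * qs[0][n, i] * qs[1][n, j]."""
    return np.einsum("noij,ni,nj->no", A_m, qs[0], qs[1])

qo = predicted_obs(A[0], qs)
print(qo.shape)  # (8, 3)

# Same result as looping over the agents one at a time.
looped = np.stack(
    [np.einsum("oij,i,j->o", A[0][n], qs[0][n], qs[1][n]) for n in range(N)]
)
print(np.allclose(qo, looped))  # True

# A single unbatched model can be shared by all agents without copying.
A_single = normalize(rng.random((3, *num_states)), axis=0)
A_shared = np.broadcast_to(A_single, (N, *A_single.shape))
print(A_shared.shape)  # (8, 3, 4, 5)
```

Putting the batch axis first keeps the per-agent slice `A[m][n]` identical in layout to the unbatched numpy-backend tensor. That is why the vectorized result matches the explicit loop.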

Most importantly, the Agent API has been amended in the following ways:

Other features: