infer-actively / pymdp

A Python implementation of active inference for Markov Decision Processes
MIT License

Feature/complex action dependency for Jax agent (copy of #139) #143

Closed: conorheins closed this 3 weeks ago

conorheins commented 3 weeks ago

A copy of PR #139 that I am able to push to (@ran-weii's PR was opened from his private fork). Ran's original PR description is quoted below:

Implementation of complex action dependency for the Jax agent, with tests and notebook demos.

Complex action dependency arises when a state factor depends on multiple actions or on no action at all. We handle this by letting the Agent class accept state factors whose transition tensors have multiple trailing (lagging) action dimensions. On init, we flatten these action dimensions into new, flat actions and construct the corresponding policies. All subsequent computations are performed on the flattened transition tensors and actions. Two functions are added: one encodes complex actions into flat actions, and the other decodes flat actions back into complex actions.
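A minimal sketch of the flattening step, assuming a single factor's B tensor has shape `(num_states_next, num_states_prev, *action_dims)`, where `action_dims` is empty when the factor depends on no action. The helper name `flatten_B_factor` is hypothetical and is not the PR's or pymdp's API; it only shows how a C-order reshape turns several action axes into one flat action axis.

```python
import numpy as np

def flatten_B_factor(B):
    """Collapse all trailing action dimensions of one B tensor into a single axis.

    Hypothetical helper. Assumes B has shape
    (num_states_next, num_states_prev, *action_dims); action_dims may be empty
    for a state factor that depends on no action.
    """
    ns_next, ns_prev = B.shape[:2]
    action_dims = tuple(B.shape[2:])
    # np.prod(()) == 1.0, so a no-action factor gets exactly one flat action.
    num_flat = int(np.prod(action_dims))
    # C-order reshape: flat index == np.ravel_multi_index(action_tuple, action_dims)
    B_flat = B.reshape(ns_next, ns_prev, num_flat)
    return B_flat, action_dims
```

For a factor driven by two actions with 2 and 4 options, the result has 8 flat actions, and slice `B_flat[:, :, 6]` equals `B[:, :, 1, 2]` because `1 * 4 + 2 == 6`. Keeping `action_dims` alongside the flat tensor is what lets flat actions be decoded later.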

NOTE: this method currently does not work with marginal action sampling. If multiple state factors depend on multiple different actions, the conversions between their flat actions and the underlying action combinations must be coupled. Marginal action sampling decouples them, so actions are encoded and decoded incorrectly.
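To make the coupling problem concrete, here is a toy sketch, assuming two state factors that share an action: factor 0 depends on actions `(u0, u1)` and factor 1 on `(u1, u2)`. The sizes and the helper names `encode_joint` and `shared_action_agrees` are hypothetical illustrations, not code from the PR. Encoding a single joint action always yields consistent flat actions, while choosing each factor's flat action independently (as marginal sampling would) can pair flat actions that imply different values of the shared `u1`.

```python
import itertools
import numpy as np

# Hypothetical toy setup: factor 0 depends on (u0, u1), factor 1 on (u1, u2).
DIMS_0 = (2, 3)  # (|u0|, |u1|) -> 6 flat actions for factor 0
DIMS_1 = (3, 2)  # (|u1|, |u2|) -> 6 flat actions for factor 1

def encode_joint(u0, u1, u2):
    """Map one joint action to each factor's flat action (always consistent)."""
    f0 = int(np.ravel_multi_index((u0, u1), DIMS_0))
    f1 = int(np.ravel_multi_index((u1, u2), DIMS_1))
    return f0, f1

def shared_action_agrees(f0, f1):
    """True if the two flat actions decode to the same value of the shared u1."""
    _, u1_from_f0 = np.unravel_index(f0, DIMS_0)
    u1_from_f1, _ = np.unravel_index(f1, DIMS_1)
    return int(u1_from_f0) == int(u1_from_f1)

# Encoding a joint action never produces a conflict ...
joint_ok = all(
    shared_action_agrees(*encode_joint(u0, u1, u2))
    for u0, u1, u2 in itertools.product(range(2), range(3), range(2))
)
# ... whereas picking f0 and f1 independently can mix different u1 values.
independent_pairs = list(itertools.product(range(6), range(6)))
num_conflicts = sum(not shared_action_agrees(f0, f1) for f0, f1 in independent_pairs)
```

In this toy case `joint_ok` is `True`, while 24 of the 36 independent `(f0, f1)` pairs disagree on `u1` and therefore have no valid decoding into a single complex action.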

Main feature list:

  • B tensor flattening and construction of the new policies
  • Encoding and decoding between complex (multi) actions and flattened actions
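The two listed features can be sketched as follows, assuming NumPy's C-order index conventions for the encoding and a policy array of shape `(num_policies, policy_len, num_factors)` over flat actions. The names `encode_action`, `decode_action` and `construct_flat_policies` are hypothetical and may differ from the functions the PR actually adds.

```python
import itertools
import numpy as np

def encode_action(action_tuple, action_dims):
    """Complex (multi) action -> flat action index (hypothetical helper)."""
    return int(np.ravel_multi_index(action_tuple, action_dims))

def decode_action(flat_action, action_dims):
    """Flat action index -> complex (multi) action tuple (hypothetical helper)."""
    return tuple(int(i) for i in np.unravel_index(flat_action, action_dims))

def construct_flat_policies(num_flat_actions, policy_len):
    """Enumerate every sequence of flat actions, one per factor per timestep.

    num_flat_actions: number of flat actions for each control factor.
    Returns an int array of shape (num_policies, policy_len, num_factors).
    """
    per_step = list(itertools.product(*[range(n) for n in num_flat_actions]))
    sequences = itertools.product(per_step, repeat=policy_len)
    return np.array(list(sequences), dtype=int)
```

For example, with `action_dims = (2, 3)`, `decode_action(4, (2, 3))` gives `(1, 1)` and `encode_action((1, 2), (2, 3))` gives `5`. Enumerating policies over flat actions keeps the rest of the agent unchanged, at the cost of a policy count that grows with the product of all action sizes.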