This branch started from `pymdp.jax.learning.py` and `test/test_learning_jax.py` on the master branch. The updates are in lines 70-107 of `pymdp.jax.learning.py` and in lines 117 to EOF of `test_learning_jax.py`.
Added unit tests that compare these JAX methods against their NumPy equivalents.
Added the `factors_to_update` option, which restricts the update to selected hidden-state factors of the transition matrix, and tested it against its NumPy equivalent.
Added a docstring to `update_state_transition_dirichlet`.
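To make the `factors_to_update` behaviour concrete for reviewers, here is a minimal NumPy sketch of a Dirichlet update on the transition parameters restricted to a subset of state factors. The function `update_transition_dirichlet_sketch` and its arguments are hypothetical and do not reproduce the signature in `pymdp.jax.learning.py`. The sketch assumes the usual form of the update: the outer product of the current and previous state marginals, scaled by a learning rate, is added to the slice of `pB` for the action taken. It omits any masking by the sparsity of `B` that the real implementations may apply.

```python
import numpy as np

def update_transition_dirichlet_sketch(pB, qs, qs_prev, actions, lr=1.0, factors_to_update="all"):
    """Hypothetical sketch: add lr * outer(qs[f], qs_prev[f]) to pB[f][:, :, actions[f]]
    for each factor f in factors_to_update. Other factors are copied unchanged."""
    if isinstance(factors_to_update, str) and factors_to_update == "all":
        factors_to_update = range(len(pB))
    # Copy so the caller's Dirichlet parameters are not mutated
    pB_new = [p.copy() for p in pB]
    for f in factors_to_update:
        # Expected transition counts: (next state) x (previous state)
        dfdb = np.outer(qs[f], qs_prev[f])
        pB_new[f][:, :, actions[f]] += lr * dfdb
    return pB_new
```

With two factors and `factors_to_update=[0]`, only `pB[0][:, :, actions[0]]` changes, which is the property the NumPy-comparison tests check for the JAX version.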