infer-actively / pymdp

A Python implementation of active inference for Markov Decision Processes
MIT License
450 stars, 89 forks

Feature/distribution api #136

Closed: OzanCatalVerses closed this pull request 3 months ago

OzanCatalVerses commented 3 months ago

Introduces a new Distribution object with optionally named axes and indices, along with changes to the Agent object so that it can start supporting the new distributions.
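To illustrate what "named axes and indices" can mean, here is a minimal sketch of such a container. It is not the Distribution class from this PR: the class name `NamedDistribution`, its constructor, and the `get` method are hypothetical and only show the idea of selecting entries by axis name and index label instead of by position.

```python
import numpy as np


class NamedDistribution:
    """Hypothetical sketch (not pymdp's API): an array whose axes and indices carry names."""

    def __init__(self, data, axes):
        # axes: ordered mapping from axis name to the list of labels along that axis
        self.data = np.asarray(data, dtype=float)
        self.axes = dict(axes)
        expected = tuple(len(labels) for labels in self.axes.values())
        if self.data.shape != expected:
            raise ValueError(f"data shape {self.data.shape} does not match axes {expected}")

    def get(self, **selection):
        """Select by axis name and label; unselected axes are kept whole."""
        unknown = set(selection) - set(self.axes)
        if unknown:
            raise KeyError(f"unknown axes: {sorted(unknown)}")
        index = []
        for name, labels in self.axes.items():
            if name in selection:
                index.append(labels.index(selection[name]))  # label -> position
            else:
                index.append(slice(None))
        return self.data[tuple(index)]


# A likelihood over one observation modality and one hidden-state factor
A = NamedDistribution(
    [[0.9, 0.1], [0.1, 0.9]],
    {"observation": ["cue_left", "cue_right"], "location": ["left", "right"]},
)
p_obs_given_left = A.get(location="left")  # column for location == "left"
```

Naming axes this way lets model-specification code refer to `location="left"` rather than remembering that "left" is index 0 of the last axis.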

tverbele commented 3 months ago

Hard for me to track and assess all the changes to Agent, but the Distribution and model specification look great!

dimarkov commented 3 months ago

None of this will work if the input matrices already have a batch dimension:

```python
# setup pytree leaves A, B, C, D, E, pA, pB, H, I
A = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), A)
B = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), B)

if pA is not None:
    pA = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), pA)

if pB is not None:
    pB = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), pB)

if C is not None:
    C = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), C)
else:
    C = [jnp.ones((batch_size, self.num_obs[m])) / self.num_obs[m] for m in range(self.num_modalities)]

if D is not None:
    D = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), D)
else:
    D = [jnp.ones((batch_size, self.num_states[f])) / self.num_states[f] for f in range(self.num_factors)]

if E is not None:
    E = jnp.broadcast_to(E, (batch_size,) + E.shape)
else:
    E = jnp.ones((batch_size, len(self.policies))) / len(self.policies)
```

I suggest specifying the broadcasting logic inside the Distribution class. Some users might not want to depend on it and should be able to provide all the parameters directly in the correct shape.
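One way to avoid double-broadcasting is to add the batch axis only when it is missing, by comparing each array against its expected unbatched shape. The sketch below is a hypothetical helper, `broadcast_to_batch`, not part of pymdp; it uses NumPy so it runs standalone, and `jax.numpy.broadcast_to` can be substituted directly.

```python
import numpy as np


def broadcast_to_batch(x, batch_size, event_shape):
    """Hypothetical helper: add a leading batch axis to x unless it already has one.

    event_shape is the expected per-agent (unbatched) shape of x.
    """
    x = np.asarray(x)
    event_shape = tuple(event_shape)
    batched_shape = (batch_size,) + event_shape
    if x.shape == event_shape:
        # No batch axis yet: share the same parameters across all agents
        # (np.broadcast_to returns a read-only view, not a copy).
        return np.broadcast_to(x, batched_shape)
    if x.shape == batched_shape:
        # Already batched: keep the user's per-agent parameters as given.
        return x
    raise ValueError(f"expected shape {event_shape} or {batched_shape}, got {x.shape}")


# Usage: priors over two state factors, one given unbatched and one already batched
num_states = [3, 2]
batch_size = 4
D = [np.ones(3) / 3, np.full((4, 2), 0.5)]
D = [broadcast_to_batch(d, batch_size, (ns,)) for d, ns in zip(D, num_states)]
```

Because the full shape is compared (not just `ndim`), the check stays unambiguous even when `batch_size` happens to equal one of the event dimensions.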

dimarkov commented 3 months ago

Also note that all the fields of the agent that are not defined as static need to have a batch dimension. So the I and H lists should also be checked for consistency.
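A consistency check along these lines could walk every non-static field, including nested lists such as H and I, and verify the leading axis. This is a sketch under assumptions: `check_batch_dims` and the `fields` dictionary are hypothetical, and the hand-written flattening stands in for what `jax.tree_util.tree_leaves` would do on a real pytree (which also skips `None`).

```python
import numpy as np


def check_batch_dims(fields, batch_size):
    """Hypothetical check: every array leaf must have a leading axis of size batch_size.

    fields maps a field name to an array, None, or a (nested) list of arrays.
    """

    def leaves(obj):
        # Flatten nested lists/tuples; skip None (e.g. pA when learning is off).
        if isinstance(obj, (list, tuple)):
            for item in obj:
                yield from leaves(item)
        elif obj is not None:
            yield np.asarray(obj)

    for name, value in fields.items():
        for i, leaf in enumerate(leaves(value)):
            if leaf.ndim == 0 or leaf.shape[0] != batch_size:
                raise ValueError(
                    f"{name}[{i}] has shape {leaf.shape}; "
                    f"expected leading batch dimension {batch_size}"
                )


batch_size = 2
fields = {
    "A": [np.ones((2, 3, 3))],
    "E": np.ones((2, 4)),
    "pA": None,
    "H": [np.ones((2, 3))],
    "I": [np.ones((2, 3))],
}
check_batch_dims(fields, batch_size)  # passes: every leaf starts with axis size 2
```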

alec-tschantz commented 3 months ago

@dimarkov I've removed this as the default behavior; it is now behind an optional flag that defaults to false. But I've just seen your comment that you'd prefer it to be in Distribution, which I can do instead.