Closed OzanCatalVerses closed 3 months ago
It's hard for me to track and assess all the changes to Agent, but the Distribution and model specification look great!
None of this will work if the input matrices already have a batch dimension:
`
# setup pytree leaves A, B, C, D, E, pA, pB, H, I
A = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), A)
B = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), B)
if pA is not None:
    pA = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), pA)
if pB is not None:
    pB = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), pB)
if C is not None:
    C = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), C)
else:
    C = [jnp.ones((batch_size, self.num_obs[m])) / self.num_obs[m] for m in range(self.num_modalities)]
if D is not None:
    D = jtu.tree_map(lambda x: jnp.broadcast_to(x, (batch_size,) + x.shape), D)
else:
    D = [jnp.ones((batch_size, self.num_states[f])) / self.num_states[f] for f in range(self.num_factors)]
if E is not None:
    E = jnp.broadcast_to(E, (batch_size,) + E.shape)
else:
    E = jnp.ones((batch_size, len(self.policies))) / len(self.policies)
`
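To make the failure mode concrete, here is a minimal sketch of a broadcast that tolerates inputs that are already batched. `ensure_batched` and `event_shapes` are hypothetical names, not part of the PR; the sketch assumes the unbatched shape of each array is known up front (for example from `num_obs` and `num_states`).

```python
import jax.numpy as jnp


def ensure_batched(x, event_shape, batch_size):
    """Return x with a leading batch dimension of size batch_size.

    event_shape is the shape of a single, unbatched array (assumed known).
    """
    event_shape = tuple(event_shape)
    if x.shape == event_shape:
        # No batch dimension yet: add one by broadcasting.
        return jnp.broadcast_to(x, (batch_size,) + event_shape)
    if x.shape == (batch_size,) + event_shape:
        # Already batched: return unchanged instead of adding a second batch axis.
        return x
    raise ValueError(
        f"expected shape {event_shape} or {(batch_size,) + event_shape}, got {x.shape}"
    )


batch_size = 5
# First array is unbatched, second already carries the batch dimension.
A = [jnp.ones((3, 2)) / 3, jnp.ones((5, 4, 2)) / 4]
event_shapes = [(3, 2), (4, 2)]  # hypothetical: unbatched shape per modality

A_batched = [ensure_batched(a, s, batch_size) for a, s in zip(A, event_shapes)]
```

With the unconditional `jnp.broadcast_to(x, (batch_size,) + x.shape)` from the quoted code, the second array would end up with shape `(5, 5, 4, 2)`; here both results have a single leading batch axis.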
I suggest moving the broadcasting logic into the Distribution class. Some users might not want to depend on it and should be able to provide all the parameters directly in the correct shape.
Also note that all fields of the agent that are not defined as static need to have a batch dimension, so the I and H lists should also be checked for consistency.
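One way such a consistency check could look, as a rough sketch: walk every non-static field and report leaves whose leading dimension does not match the batch size. `check_batch_dim` and the `fields` dict are hypothetical and only illustrate the idea; how the agent actually stores its fields is not shown here.

```python
import jax.numpy as jnp
import jax.tree_util as jtu


def check_batch_dim(fields, batch_size):
    """Return (name, leaf_index, shape) for every leaf lacking the batch dimension.

    fields maps a field name to a pytree of arrays; None fields have no leaves
    and are skipped automatically.
    """
    bad = []
    for name, tree in fields.items():
        for i, leaf in enumerate(jtu.tree_leaves(tree)):
            shape = tuple(jnp.shape(leaf))
            if len(shape) == 0 or shape[0] != batch_size:
                bad.append((name, i, shape))
    return bad


batch_size = 5
fields = {
    "A": [jnp.ones((5, 3, 2))],  # batched
    "H": [jnp.ones((2,))],       # missing the batch dimension
    "I": None,                   # not provided
}
problems = check_batch_dim(fields, batch_size)
```

Note that this only inspects the leading axis, so an unbatched array whose first dimension happens to equal `batch_size` would pass; a stricter check would also need the expected unbatched shapes.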
@dimarkov I've removed this as the default behavior; it is now behind an optional flag that defaults to false. I've only just seen your comment that you'd prefer it to live in Distribution, which I can do instead.
Introduces a new Distribution object with (optionally) named axes and indices as well as some changes to the Agent object to start supporting the new distributions.