RobertTLange / gymnax

RL Environments in JAX 🌍
Apache License 2.0
577 stars, 54 forks

Pendulum-v1, MountainCarContinuous-v0 and Reacher-misc return non-squeezed reward #43

keraJLi closed this issue 1 year ago

keraJLi commented 1 year ago

Of all the environments, only Pendulum-v1, MountainCarContinuous-v0 and Reacher-misc return the reward as a JAX array with shape (1,). This is inconsistent with every other environment, which returns an array with shape (). The mismatch can cause unexpected shape bugs through silent broadcasting. For example, consider a case like this:

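```python
num_envs = 3
weights = jnp.arange(num_envs)  # shape (num_envs,)
# "..." elides the other step outputs and the vmap arguments
rewards, ... = jax.vmap(env.step, ...)(actions)
weighted_rewards = weights * rewards  # expected shape: (num_envs,)
```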

If the reward returned by the environment has shape (1,), the vmapped rewards will have shape (3, 1) instead of (3,). Broadcasting (3,) against (3, 1) then gives weighted_rewards a shape of (3, 3) instead of (3,), with no error raised.
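The broadcasting problem can be reproduced without gymnax. The sketch below is a minimal stand-in, not gymnax code. It assumes two hypothetical step functions, `step_unsqueezed` and `step_scalar`, that return only a reward: the first mimics the reported shape (1,) reward, and the second returns a scalar like the other environments. The caller-side workaround uses `jnp.squeeze` to drop the trailing axis.

```python
import jax
import jax.numpy as jnp

NUM_ENVS = 3

def step_unsqueezed(action):
    # Hypothetical stand-in: reward keeps shape (1,), as reported for
    # Pendulum-v1, MountainCarContinuous-v0 and Reacher-misc.
    return -jnp.sum(action ** 2, keepdims=True)

def step_scalar(action):
    # Hypothetical stand-in: reward is a scalar with shape (), like the other envs.
    return -jnp.sum(action ** 2)

# One 1-D action per environment: [[0.], [1.], [2.]]
actions = jnp.arange(NUM_ENVS, dtype=jnp.float32).reshape(NUM_ENVS, 1)
weights = jnp.arange(NUM_ENVS)  # shape (3,)

# Buggy case: vmapped rewards have shape (3, 1), so the product broadcasts to (3, 3).
rewards_bad = jax.vmap(step_unsqueezed)(actions)
weighted_bad = weights * rewards_bad

# Expected case: scalar rewards vmap to shape (3,), and the product stays (3,).
rewards_ok = jax.vmap(step_scalar)(actions)
weighted_ok = weights * rewards_ok

# Caller-side workaround: squeeze the trailing singleton axis before weighting.
rewards_fixed = jnp.squeeze(rewards_bad, axis=-1)
weighted_fixed = weights * rewards_fixed

print(rewards_bad.shape, weighted_bad.shape)
print(rewards_ok.shape, weighted_ok.shape)
```

The squeeze on the caller side is only a workaround. Returning a shape () reward from the affected environments' `step` (for example, by squeezing the reward before it is returned) would make every environment consistent, so downstream code would not need per-environment special cases.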