RobertTLange / gymnax

RL Environments in JAX 🌍
Apache License 2.0

Issue: vmapped CartPole input shape does not match #37

Closed: DriesSmit closed this issue 1 year ago

DriesSmit commented 1 year ago

Hello there. I am trying to run a vmapped CartPole step function. My environment state fields have the following shapes:

env_state:
[executor/0] x:  Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=2/0)>
[executor/0] x_dot:  Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=2/0)>
[executor/0] theta:  Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=2/0)>
[executor/0] theta_dot:  Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=2/0)>

When I run jnp.array([env_state.x, env_state.x_dot, env_state.theta, env_state.theta_dot]) on the state before the environment step, I get: Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=2/0)>
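For context, here is a small pure-JAX sketch (not gymnax code) of why the fields show up as float32[] scalars. Inside jax.vmap, each per-environment field is traced without its batch axis, so stacking four of them gives a (4,) observation, and vmap adds the batch axis back on the output. The get_obs function and the batch size of 8 are hypothetical stand-ins.

```python
import jax
import jax.numpy as jnp


def get_obs(x, x_dot, theta, theta_dot):
    # Hypothetical helper: stack four scalar state fields into one (4,) vector.
    return jnp.array([x, x_dot, theta, theta_dot])


# A hypothetical batch of 8 environment states, one value per field per environment.
batch = 8
fields = [jnp.zeros(batch) for _ in range(4)]

# Inside vmap each field is traced as a float32[] scalar, so get_obs builds a
# float32[4] array; vmap then restores the batch axis on the output.
batched_obs = jax.vmap(get_obs)(*fields)
print(batched_obs.shape)  # (8, 4)
```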

However, when I run the step function, I get:

obs, env_state, rewards, done, _ = self.env.step(key_step, env_state, action, self.env_params)
[executor/0]   File "/mava/lib/python3.8/site-packages/gymnax/environments/environment.py", line 38, in step
[executor/0]     obs_st, state_st, reward, done, info = self.step_env(
[executor/0]   File "/mava/lib/python3.8/site-packages/gymnax/environments/classic_control/cartpole.py", line 83, in step_env
[executor/0]     lax.stop_gradient(self.get_obs(state)),
[executor/0]   File "/mava/lib/python3.8/site-packages/gymnax/environments/classic_control/cartpole.py", line 108, in get_obs
[executor/0]     return jnp.array([state.x, state.x_dot, state.theta, state.theta_dot])
[executor/0]   File "/mava/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 1889, in array
[executor/0]     out = stack([asarray(elt, dtype=dtype) for elt in object])
[executor/0]   File "/mava/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 1634, in stack
[executor/0]     raise ValueError("All input arrays must have the same shape.")
[executor/0] ValueError: All input arrays must have the same shape.
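The same ValueError can be reproduced outside the environment with a minimal sketch. It assumes only that one state field has picked up an extra axis of size 2 while another field is still a scalar. The values are hypothetical.

```python
import jax.numpy as jnp

x = jnp.float32(0.0)  # scalar field, shape ()
x_dot = jnp.zeros(2)  # field that has picked up an extra axis, shape (2,)

# jnp.array stacks the list elements, which requires identical shapes.
try:
    jnp.array([x, x_dot])
    raised = False
except ValueError as err:
    raised = True
    print(err)
```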

Do you have any idea what might be causing this issue? Are the shapes somehow changing inside the step function? Thanks.

DriesSmit commented 1 year ago

Apologies, the problem was not with the environment. I was passing an array of action logits of size 2 to the environment instead of a single discrete action.
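The following sketch shows how a size-2 action can produce this error. It uses a hypothetical toy_step with simplified CartPole-style Euler updates. The dynamics and the constants FORCE_MAG and TAU are placeholders, not gymnax's implementation. Assuming the action enters the force linearly, a (2,) logit array turns x_dot and theta_dot into (2,) arrays while x and theta stay scalars, so stacking the observation fails. Converting the logits to one action first keeps every field scalar, and this also holds under vmap. The sketch uses jnp.argmax for a greedy choice; jax.random.categorical would sample instead.

```python
import jax
import jax.numpy as jnp

FORCE_MAG = 10.0  # hypothetical push magnitude
TAU = 0.02        # hypothetical Euler time step


def toy_step(state, action):
    # Hypothetical, simplified CartPole-style update; not gymnax's actual code.
    x, x_dot, theta, theta_dot = state
    # A discrete action in {0, 1} maps to a push of +FORCE_MAG or -FORCE_MAG.
    force = FORCE_MAG * action - FORCE_MAG * (1 - action)
    x_acc = force         # placeholder dynamics
    theta_acc = -force    # placeholder dynamics
    new_state = (
        x + TAU * x_dot,
        x_dot + TAU * x_acc,          # takes on the shape of `action`
        theta + TAU * theta_dot,
        theta_dot + TAU * theta_acc,  # takes on the shape of `action`
    )
    obs = jnp.array(new_state)  # fails if the fields no longer share one shape
    return obs, new_state


state = tuple(jnp.float32(0.0) for _ in range(4))
logits = jnp.array([0.3, 1.2])

# Passing raw logits of shape (2,) reproduces the shape mismatch.
try:
    toy_step(state, logits)
    logits_failed = False
except ValueError:
    logits_failed = True

# Reducing the logits to a single discrete action keeps every field a scalar.
action = jnp.argmax(logits)
obs, _ = toy_step(state, action)
print(obs.shape)  # (4,)

# Batched: logits of shape (B, 2) become actions of shape (B,) before the vmapped step.
batch_logits = jnp.zeros((8, 2))
batch_state = tuple(jnp.zeros(8) for _ in range(4))
batch_actions = jnp.argmax(batch_logits, axis=-1)
batch_obs, _ = jax.vmap(toy_step)(batch_state, batch_actions)
print(batch_obs.shape)  # (8, 4)
```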