Hi @remmyzen!
The issue seems to be with the flattening of observations in the policy network. There should be an easy way to achieve what you want using something like `vmap` on `make_act`; I will address this in the future. For now, I believe an easy workaround is to vmap the policy along the first axis of the train state as well, like this:
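```python
import jax

# Map over the leading axis of the train state, the observations and the keys,
# so each trained agent acts on its own observation with its own key.
@jax.jit
@jax.vmap
def policy(train_state, obs, rng):
    p = algo.make_act(train_state)  # `algo` is the rejax algorithm used for training
    return p(obs, rng)
```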
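The same pattern can be tried without rejax. The sketch below is illustrative only: `make_act` is a hypothetical stand-in for `algo.make_act`, and the train state is a plain dict whose leaves all carry a leading axis of size 50 (one slice per trained agent). `jax.vmap` maps over every leaf of that pytree together with the observations and keys.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for algo.make_act: closes over ONE agent's parameters
# and returns act(obs, rng) for a single, unbatched observation.
def make_act(train_state):
    def act(obs, rng):
        logits = obs @ train_state["w"] + train_state["b"]  # (4,) -> (2,)
        return jax.random.categorical(rng, logits)
    return act

@jax.jit
@jax.vmap
def policy(train_state, obs, rng):
    return make_act(train_state)(obs, rng)

num_members, obs_dim, num_actions = 50, 4, 2
k_w, k_obs, k_act = jax.random.split(jax.random.PRNGKey(0), 3)
# Every leaf of the train-state pytree has a leading axis of size num_members.
train_state = {
    "w": jax.random.normal(k_w, (num_members, obs_dim, num_actions)),
    "b": jnp.zeros((num_members, num_actions)),
}
obs = jax.random.normal(k_obs, (num_members, obs_dim))
rngs = jax.random.split(k_act, num_members)  # one key per agent
actions = policy(train_state, obs, rngs)
print(actions.shape)  # (50,)
```

Inside the vmapped function each call only sees one agent's parameters, one `(4,)` observation and one key, which is exactly what a single-agent `act` function expects.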
Since you want to run all policies, you have to vmap step and reset as well, and take some care with the loop termination:
```python
import jax
import jax.numpy as jnp
import gymnax

# `keys`, `train_state` (batched over len(keys)) and `env_str` come from the
# vmapped training run.
rng = jax.random.PRNGKey(0)
env, params = gymnax.make(env_str)
step = jax.vmap(env.step, in_axes=(0, 0, 0, None))
step = jax.jit(step)

rng, rng_init = jax.random.split(rng)
rng_init = jax.random.split(rng_init, len(keys))
obs, state = jax.vmap(env.reset, in_axes=(0, None))(rng_init, params)
episode_return = 0
done = jnp.zeros((len(keys),), dtype=bool)

while not jnp.all(done):
    rng, rng_action, rng_step = jax.random.split(rng, 3)
    rng_action = jax.random.split(rng_action, len(keys))
    rng_step = jax.random.split(rng_step, len(keys))
    action = policy(train_state, obs, rng_action)
    obs, state, reward, new_done, info = step(rng_step, state, action, params)
    # Only count rewards of episodes that had not finished before this step;
    # gymnax environments reset automatically when an episode ends.
    episode_return += reward * (1 - done)
    done = done | new_done

print(f"Return achieved in one episode of {env_str}: {episode_return}")
```
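To see why the return is masked with the *previous* `done` before it is updated, here is a small worked example with made-up rewards and termination flags for 3 parallel episodes over 4 steps (the numbers are hypothetical, not from a real environment). The reward of the step on which an episode terminates is still counted; everything after it is ignored.

```python
import jax.numpy as jnp

# Hypothetical rewards and termination flags: rows are steps, columns are episodes.
rewards = jnp.ones((4, 3))
new_dones = jnp.array([[False, True,  False],
                       [True,  False, False],
                       [False, False, False],
                       [False, False, True]])

episode_return = jnp.zeros(3)
done = jnp.zeros(3, dtype=bool)
for reward, new_done in zip(rewards, new_dones):
    episode_return += reward * (1 - done)  # mask with done from BEFORE this step
    done = done | new_done                 # then record newly finished episodes
print(episode_return)  # [2. 1. 4.]
```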
Please note that the rollout using a Python `while` loop is only meant for didactic purposes; using `jax.jit` in combination with `jax.lax.while_loop` is much faster.
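A rough sketch of the jitted `jax.lax.while_loop` version, under loud assumptions: to keep it self-contained it uses a hypothetical toy environment (`toy_step`, not gymnax) whose state is a counter and whose episodes end once the counter reaches a per-environment `limit`, plus a placeholder policy that always picks action 0. With gymnax and rejax you would swap in the vmapped `step` and `policy` from above and carry `obs` and the RNG key in the loop state.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy environment: reward 1.0 per step, episode ends at `limit`.
def toy_step(state, action, limit):  # `action` is ignored by this toy env
    new_state = state + 1
    return new_state, jnp.float32(1.0), new_state >= limit

batched_step = jax.vmap(toy_step)  # all arguments batched along axis 0

@jax.jit
def rollout(limits):
    num_envs = limits.shape[0]
    init = (
        jnp.zeros(num_envs, dtype=jnp.int32),    # env states
        jnp.zeros(num_envs, dtype=jnp.float32),  # episode returns
        jnp.zeros(num_envs, dtype=bool),         # done flags
    )

    def cond_fn(carry):
        _, _, done = carry
        return ~jnp.all(done)  # keep looping while any episode is running

    def body_fn(carry):
        state, episode_return, done = carry
        action = jnp.zeros_like(state)  # placeholder policy: always action 0
        state, reward, new_done = batched_step(state, action, limits)
        episode_return = episode_return + reward * (1 - done)
        return state, episode_return, done | new_done

    _, episode_return, _ = jax.lax.while_loop(cond_fn, body_fn, init)
    return episode_return

print(rollout(jnp.array([3, 5, 1])))  # [3. 5. 1.]
```

The whole loop is traced once and compiled, so there is no Python overhead or host round-trip per step; the price is that the carry must keep fixed shapes and dtypes across iterations.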
Let me know if this works for you!
Hi @keraJLi,
Thank you for your reply. It works now!
Hi,
Thanks for the nice repo! I am trying to run the example that trains vmapped policies in the `README.md`, but I always get an error. Could you give an example? Basically, I am trying to combine the examples in the `README.md` and the `examples/rejax_tour.ipynb` notebook. It gives an error on the `action = policy(obs, rng_action)` line with this message:
I also tried to vmap `reset` so that `obs.shape` is `(50, 4)` and to split `rng_action` into 50 keys, but then the error reported `got shape (200, 64)`, possibly due to the squeeze in the `policy` function. Could you advise on how to run a vmapped policy properly?
Thank you.
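One plausible way such shape mismatches arise, given that the reply points at the flattening of observations: a policy written for a single observation flattens its input, so a batch of 50 four-dimensional observations gets merged into one vector of 200 features. The sketch below uses a hypothetical `flatten_obs` helper (not rejax code) to show the effect and how `jax.vmap` restores per-observation behaviour; it does not claim to reproduce the exact `(200, 64)` error.

```python
import jax
import jax.numpy as jnp

# Hypothetical single-observation preprocessing: flatten whatever comes in.
def flatten_obs(obs):
    return obs.reshape(-1)

batch = jnp.ones((50, 4))
print(flatten_obs(batch).shape)            # (200,)  batch and feature axes merged
print(jax.vmap(flatten_obs)(batch).shape)  # (50, 4) each observation flattened on its own
```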