RobertTLange / gymnax

RL Environments in JAX 🌍
Apache License 2.0

num_env_steps does not work in RolloutWrapper #79

Open jinPrelude opened 2 weeks ago

jinPrelude commented 2 weeks ago

Issue

The environment created by RolloutWrapper ignores the num_env_steps argument passed to RolloutWrapper. With num_env_steps=100, the rollouts still run for 500 steps instead of the requested 100:

Code for reproduction

```python
from gymnax.experimental import RolloutWrapper
import jax

ENV_NUM = 3  # number of parallel rollouts

# No policy network (first argument None); request 100-step rollouts.
manager = RolloutWrapper(None, env_name='CartPole-v1', num_env_steps=100)

rng, rollout_rng = jax.random.split(jax.random.key(0))
rollout_rng = jax.random.split(rollout_rng, ENV_NUM)  # one key per rollout
# Second argument (policy parameters) is None because there is no policy network.
obs, action, reward, next_obs, done, cum_ret = manager.batch_rollout(rollout_rng, None)
print(done.shape)  # expected (3, 100), but prints (3, 500)
```
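For reference, the expected behavior can be illustrated without gymnax. In a rollout built on `jax.lax.scan`, the number of scan steps alone fixes the time axis, so a batch of 3 rollouts of 100 steps yields arrays of shape `(3, 100)`. This is a minimal sketch under stated assumptions: `toy_step` is a hypothetical toy environment (a step counter with a random reward, not CartPole), `make_batch_rollout` is an illustrative name, and neither reflects RolloutWrapper's actual implementation.

```python
import jax
import jax.numpy as jnp


def toy_step(state, rng):
    """Hypothetical toy environment step: counter state, random reward,
    and a 'done' flag every 10 steps (illustration only, not CartPole)."""
    reward = jax.random.uniform(rng)
    next_state = state + 1
    done = (next_state % 10) == 0
    return next_state, (reward, done)


def make_batch_rollout(num_env_steps):
    """Build a vmapped rollout whose time axis has exactly num_env_steps entries.

    num_env_steps is captured as a Python int (via closure) so it stays static.
    """

    def single_rollout(rng):
        # One key per step; scanning over them runs num_env_steps iterations.
        step_rngs = jax.random.split(rng, num_env_steps)
        _, (reward, done) = jax.lax.scan(toy_step, jnp.int32(0), step_rngs)
        return reward, done

    # Map over a leading batch axis of keys (one rollout per key).
    return jax.vmap(single_rollout)


rollout_rngs = jax.random.split(jax.random.key(0), 3)
reward, done = make_batch_rollout(num_env_steps=100)(rollout_rngs)
print(done.shape)  # (3, 100)
```

Capturing `num_env_steps` in a closure, rather than passing it through `jax.vmap`, keeps it a static Python integer, which `jax.random.split` and `jax.lax.scan` need in order to determine array shapes.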