The environment created by RolloutWrapper does not respect the num_env_steps argument passed to RolloutWrapper:
Code for reproduction
from gymnax.experimental import RolloutWrapper
import jax
ENV_NUM = 3
manager = RolloutWrapper(None, env_name='CartPole-v1', num_env_steps=100)
rng, rollout_rng = jax.random.split(jax.random.key(0))
rollout_rng = jax.random.split(rollout_rng, ENV_NUM)
obs, action, reward, next_obs, done, cum_ret = manager.batch_rollout(rollout_rng, None)
print(done.shape) # it should print (3, 100), but the result is (3, 500)
Why did this bug happen?
RolloutWrapper.single_rollout() passes self.env_params.max_steps_in_episode to jax.lax.scan() instead of self.num_env_steps (gymnax/experimental/rollout.py#L94).
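The effect of the scan length can be illustrated without gymnax. The snippet below is a minimal sketch, not the gymnax implementation: toy_rollout, MAX_STEPS_IN_EPISODE, and NUM_ENV_STEPS are hypothetical names standing in for single_rollout(), env_params.max_steps_in_episode, and num_env_steps. It only shows that the length given to jax.lax.scan() fixes the leading dimension of the stacked outputs, which is why the batched result has 500 columns instead of 100.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for env_params.max_steps_in_episode and num_env_steps.
MAX_STEPS_IN_EPISODE = 500
NUM_ENV_STEPS = 100
ENV_NUM = 3


def toy_rollout(rng, length):
    """Toy rollout (assumption, not gymnax code): scan a dummy step `length` times."""

    def step(carry, _):
        rng, t = carry
        # Split the key each step, as a real rollout would before sampling.
        rng, _unused = jax.random.split(rng)
        done = (t + 1) >= length  # dummy termination flag on the last step
        return (rng, t + 1), done

    # `length` alone decides how many steps are stacked in the output.
    _, done = jax.lax.scan(step, (rng, jnp.int32(0)), None, length=length)
    return done


keys = jax.random.split(jax.random.PRNGKey(0), ENV_NUM)

# Scanning with the episode limit reproduces the reported shape.
buggy_done = jax.vmap(lambda k: toy_rollout(k, MAX_STEPS_IN_EPISODE))(keys)
# Scanning with the requested step count gives the expected shape.
fixed_done = jax.vmap(lambda k: toy_rollout(k, NUM_ENV_STEPS))(keys)

print(buggy_done.shape)  # (3, 500)
print(fixed_done.shape)  # (3, 100)
```

The length must be a static Python integer, so it is closed over in the lambda rather than passed through jax.vmap.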
Issue #79
num_env_steps is set to the same value as the environment's max_steps_in_episode for testing this feature (gymnax/tests/wrappers/test_evaluator.py).
What is fixed in this PR?
Changed self.env_params.max_steps_in_episode to self.num_env_steps in the jax.lax.scan() call inside RolloutWrapper.single_rollout().