DramaCow closed this issue 7 months ago
Hi @DramaCow, I'm not able to reproduce your issue. Here's what I ran:
import jax
from jax import numpy as jp
from brax import envs

env_name = 'ant'  # the environment discussed in this issue
env = envs.create(env_name=env_name, backend='positional')
jit_env_reset = jax.jit(jax.vmap(env.reset))
jit_env_step = jax.jit(jax.vmap(env.step))

# Roll out 1024 environments for 1000 steps with all-zero (nop) actions.
rollout = []
rng = jax.random.PRNGKey(seed=1)
rng = jax.random.split(rng, 1024)
state = jit_env_reset(rng=rng)
for _ in range(1000):
    rollout.append(state.pipeline_state)
    state = jit_env_step(state, jp.zeros((1024, env.sys.act_size())))


def unhealthy(state):
    # True when the torso height (body 0, z coordinate) is outside the healthy range.
    pipeline_state = state.pipeline_state
    min_z, max_z = env._healthy_z_range
    is_healthy = jp.where(pipeline_state.x.pos[0, 2] < min_z, x=0.0, y=1.0)
    is_healthy = jp.where(
        pipeline_state.x.pos[0, 2] > max_z, x=0.0, y=is_healthy
    )
    return (1 - is_healthy).astype(jp.bool_)


# Count unhealthy environments in the final state only.
jax.vmap(unhealthy)(state).sum()
I get 0 at the end, i.e. no environment is flagged as unhealthy in the final state.
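Note that the count above only inspects the final `state`. Since `rollout` keeps every intermediate `pipeline_state`, the same range test could also be applied per step. Below is a minimal, framework-free sketch of that idea, not Brax code: `count_unhealthy_per_step` and the nested-list layout `z_by_step[t][i]` (torso height of environment `i` at step `t`) are hypothetical, the range `(0.2, 1.0)` is only an assumed stand-in for `env._healthy_z_range`, and the heights are made-up illustrative numbers.

```python
from typing import List, Sequence, Tuple


def count_unhealthy_per_step(
    z_by_step: Sequence[Sequence[float]],
    healthy_z_range: Tuple[float, float],
) -> List[int]:
    """For each step, count environments whose torso z is outside the range."""
    min_z, max_z = healthy_z_range
    counts = []
    for heights in z_by_step:
        # Same rule as unhealthy(): below min_z or above max_z is unhealthy.
        counts.append(sum(1 for z in heights if z < min_z or z > max_z))
    return counts


# Illustrative numbers only (not measured from Brax): 3 steps, 4 environments.
z_by_step = [
    [0.55, 0.56, 0.54, 0.55],
    [0.90, 1.05, 0.15, 0.60],
    [0.55, 0.55, 0.55, 0.55],
]
counts = count_unhealthy_per_step(z_by_step, healthy_z_range=(0.2, 1.0))
print(counts)
```

In this toy data the last step reports no unhealthy environments even though two were out of range one step earlier, so a final-state check and a per-step check can disagree.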
I'm struggling to get good performance on Ant (positional). It seems like many episodes terminate early due to failing the is_healthy check.
Test case: I reset 256000 Ant environments and step a few times with nop actions (all zero).
After:
This problem doesn't seem to occur for other environments (Hopper, Humanoid).
Is this a bug? I don't think this problem occurred for Ant in older versions. Looking at the deltas, it seems like Ant is initialized with a z-velocity that pushes the agent outside the healthy range, which is difficult to overcome.
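One way to make the "deltas" observation concrete is to estimate the vertical velocity by finite differences of the torso height and to record the first step at which the height leaves the healthy range. The sketch below is plain Python under stated assumptions: both helper names are hypothetical, `dt = 0.05` and the range `(0.2, 1.0)` are assumed values rather than ones read from the environment, and the trajectory is made up for illustration.

```python
from typing import List, Optional, Sequence, Tuple


def z_velocity_estimates(z_trajectory: Sequence[float], dt: float) -> List[float]:
    """Finite-difference estimate of vertical velocity between consecutive steps."""
    return [(b - a) / dt for a, b in zip(z_trajectory, z_trajectory[1:])]


def first_unhealthy_step(
    z_trajectory: Sequence[float],
    healthy_z_range: Tuple[float, float],
) -> Optional[int]:
    """Index of the first step with z outside the range, or None if always inside."""
    min_z, max_z = healthy_z_range
    for step, z in enumerate(z_trajectory):
        if z < min_z or z > max_z:
            return step
    return None


# Illustrative torso heights for one environment (not measured from Brax).
z_trajectory = [0.5, 0.75, 1.0, 1.25]
velocities = z_velocity_estimates(z_trajectory, dt=0.05)  # dt is an assumption
first_bad = first_unhealthy_step(z_trajectory, healthy_z_range=(0.2, 1.0))
print(velocities, first_bad)
```

A consistently positive velocity estimate right after reset, followed by an early `first_bad` index, would be in line with the hypothesis above; it would not by itself show where the velocity comes from.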
Ran using Brax version 0.9.3.