google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Ant (positional) initialization often leads to an unhealthy state / early termination within a few timesteps #429

Closed: DramaCow closed this issue 7 months ago.

DramaCow commented 7 months ago

I'm struggling to get good performance on Ant (positional). Many episodes seem to terminate early because they fail the is_healthy check.

Test case: I reset 256,000 Ant environments (1000 batches of 256) and take four steps with no-op (all-zero) actions.

```python
import jax
import jax.numpy as jnp
from brax import envs

env = envs.get_environment("ant", backend="positional")

print(env._healthy_z_range)

def unhealthy(state):
    """Health check copied verbatim from brax.envs.ant.Ant."""
    pipeline_state = state.pipeline_state
    min_z, max_z = env._healthy_z_range
    is_healthy = jnp.where(pipeline_state.x.pos[0, 2] < min_z, x=0.0, y=1.0)
    is_healthy = jnp.where(
        pipeline_state.x.pos[0, 2] > max_z, x=0.0, y=is_healthy
    )
    return (1 - is_healthy).astype(jnp.bool_)

num_envs = 256
zero_actions = jnp.zeros((num_envs, env.action_size))  # no-op actions

def f(rng, _):
    rng, _rng = jax.random.split(rng)

    # Reset a batch of environments, then take four no-op steps.
    init_state = jax.vmap(env.reset)(jax.random.split(_rng, num_envs))
    next_state = init_state
    for _ in range(4):
        next_state = jax.vmap(env.step)(next_state, zero_actions)

    # Torso (link 0) height at reset, and its change (initial minus final).
    init_z = init_state.pipeline_state.x.pos[:, 0, 2]
    delta_z = init_z - next_state.pipeline_state.x.pos[:, 0, 2]

    return rng, (jax.vmap(unhealthy)(next_state), init_z, delta_z)

# 1000 scan iterations x 256 environments = 256,000 resets in total.
rng = jax.random.PRNGKey(0)
_, (is_unhealthy, init_z, delta_z) = jax.lax.scan(f, rng, None, length=1000)

# Keep only the environments that ended up unhealthy.
init_z = init_z[is_unhealthy]
delta_z = delta_z[is_unhealthy]

print(is_unhealthy.sum(), init_z, delta_z)
```
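The nested `jnp.where` in `unhealthy` amounts to a single boolean test on the torso height: unhealthy means below `min_z` or above `max_z`, with the boundary values counted as healthy. A minimal sketch of that test, written to work on a whole batch at once so no `jax.vmap` is needed (the name `unhealthy_mask` is hypothetical, not part of brax):

```python
import jax.numpy as jnp

def unhealthy_mask(z, min_z, max_z):
    """True where torso height z is outside [min_z, max_z].

    Same logic as the two nested jnp.where calls in `unhealthy`,
    expressed as one boolean expression; z may have any shape.
    """
    return jnp.logical_or(z < min_z, z > max_z)
```

With the script above, `unhealthy_mask(next_state.pipeline_state.x.pos[:, 0, 2], *env._healthy_z_range)` should give the same `(256,)` mask as `jax.vmap(unhealthy)(next_state)`.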


This problem doesn't seem to occur in other environments (Hopper, Humanoid).

Is this a bug? I don't think this problem occurred for Ant in older versions. Judging from the deltas, Ant seems to be initialized with a z-velocity that pushes the agent outside the healthy range, which is difficult to overcome.

Ran using Brax version 0.9.3.

btaba commented 7 months ago

Hi @DramaCow, I'm not able to reproduce your issue. Here's what I ran:

```python
import jax
from jax import numpy as jp
from brax import envs

env_name = 'ant'
env = envs.create(env_name=env_name, backend='positional')

# Batched, jitted reset and step.
jit_env_reset = jax.jit(jax.vmap(env.reset))
jit_env_step = jax.jit(jax.vmap(env.step))

# Roll out 1024 environments for 1000 steps with zero actions.
rollout = []
rng = jax.random.PRNGKey(seed=1)
rng = jax.random.split(rng, 1024)
state = jit_env_reset(rng=rng)
for _ in range(1000):
    rollout.append(state.pipeline_state)
    state = jit_env_step(state, jp.zeros((1024, env.sys.act_size())))

# Same health check as in the report.
def unhealthy(state):
    pipeline_state = state.pipeline_state
    min_z, max_z = env._healthy_z_range
    is_healthy = jp.where(pipeline_state.x.pos[0, 2] < min_z, x=0.0, y=1.0)
    is_healthy = jp.where(
        pipeline_state.x.pos[0, 2] > max_z, x=0.0, y=is_healthy
    )
    return (1 - is_healthy).astype(jp.bool_)

# Number of unhealthy environments in the final state.
print(jax.vmap(unhealthy)(state).sum())
```

I get 0 at the end.
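The check above inspects only the final state after 1000 steps, while the report concerns terminations within the first few steps. A minimal sketch that reports, for each environment, the first step at which the torso height leaves the healthy range, given a `(T, B)` array of torso heights over `T` steps for `B` environments (the name `first_unhealthy_step` is hypothetical; in practice pass `env._healthy_z_range` as the range):

```python
import jax.numpy as jnp

def first_unhealthy_step(z_traj, min_z, max_z):
    """Index of the first unhealthy step per environment.

    z_traj: (T, B) torso heights. Returns a (B,) int array holding the
    first step at which z is outside [min_z, max_z], or T if the
    environment stays healthy for the whole trajectory.
    """
    bad = jnp.logical_or(z_traj < min_z, z_traj > max_z)  # (T, B)
    num_steps = z_traj.shape[0]
    # argmax returns the first index of the maximum (first True), but also
    # returns 0 when no step is bad, so those environments are set to T.
    first = jnp.argmax(bad.astype(jnp.int32), axis=0)
    return jnp.where(bad.any(axis=0), first, num_steps)
```

The heights can be collected from a list of pipeline states as `jnp.stack([ps.x.pos[:, 0, 2] for ps in rollout])`. This assumes the environment does not reset automatically (for example one built with `envs.get_environment`, as in the original report): an auto-reset wrapper, which `envs.create` applies by default via its `auto_reset` argument in brax 0.9.x as far as I know, replaces terminated episodes with freshly reset states, so unhealthy heights would not appear in the recorded `pipeline_state`.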