google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Leaked Trace Error for generalized backend humanoid env. #395

Closed Wedthree closed 8 months ago

Wedthree commented 9 months ago

I hit a random leaked trace error when using the humanoid env in my project. Because the error is thrown randomly, I enabled jax.config.update("jax_check_tracer_leaks", True) to find where the leak happens. After some work, I found that it originates from the Brax env. The following code reproduces the error.

```python
import jax
import jax.numpy as jnp
from brax.envs.humanoid import Humanoid

jax.config.update("jax_check_tracer_leaks", True)
humanoid_env = Humanoid(backend="generalized")
state = humanoid_env.reset(jax.random.PRNGKey(0))
for i in range(10):
    print("+++++++++++++++++++++++")
    print(i)
    action = jnp.zeros((humanoid_env.action_size,))
    state = humanoid_env.step(state, action)
    print(state.obs)
```

When I use the "positional" backend, the leaked trace error disappears. From the error log, the problem seems related to the ""; see the log below.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "", line 16, in
    state = humanoid_env.step(state, action)
  File "/venv/lib/python3.10/site-packages/brax/envs/", line 250, in step
    pipeline_state = self.pipeline_step(pipeline_state0, action)
  File "/venv/lib/python3.10/site-packages/brax/envs/", line 127, in pipeline_step
    return jax.lax.scan(f, pipeline_state, (), self._n_frames)[0]
  File "/venv/lib/python3.10/site-packages/brax/envs/", line 123, in f
    self._pipeline.step(self.sys, state, action, self._debug),
  File "/venv/lib/python3.10/site-packages/brax/generalized/", line 72, in step
    state = state.replace(qf_constraint=constraint.force(sys, state))
  File "/venv/lib/python3.10/site-packages/brax/generalized/", line 238, in force
    qf_constraint = state.con_jac.T @
  File "/venv/lib/python3.10/site-packages/jaxopt/_src/", line 136, in run
    return, hyperparams_proj, *args, kwargs)
  File "/venv/lib/python3.10/site-packages/jaxopt/_src/", line 355, in run
    return run(init_params, *args, *kwargs)
  File "/venv/lib/python3.10/site-packages/jaxopt/_src/", line 317, in _run
    opt_step = self.update(init_params, state, args, kwargs)
  File "/venv/lib/python3.10/site-packages/jaxopt/_src/", line 296, in update
    return f(params, state, hyperparams_prox, args, kwargs)
  File "/venv/lib/python3.10/site-packages/jaxopt/_src/", line 266, in _update_accel
    next_x, next_stepsize = self._iter(iter_num, y, y_fun_val, y_fun_grad,
  File "/venv/lib/python3.10/site-packages/jaxopt/_src/", line 224, in _iter
    next_x, next_stepsize = self._fista_line_search(self.maxls, x, x_fun_val,
  File "/venv/lib/python3.10/", line 142, in exit
    next(self.gen)
Exception: Leaked trace MainTrace(2,DynamicJaxprTrace). Leaked tracer(s):

Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=2/0)>
This DynamicJaxprTracer was created on line ()
<DynamicJaxprTracer 140108750238752> is referred to by <list 140109619717248>[5]
<list 140109619717248> is referred to by <tuple 140109619392448>[1]
<tuple 140109619392448> is referred to by <tuple 140109619386048>[0]
<tuple 140109619386048> is referred to by <dict 140108750487680>[(((<function flatten_fun_nokwargs at 0x7f6e754bb520>, (PyTreeDef(((, ),)),)), (<function _argnums_partial at 0x7f6e754bbbe0>, ((2,), (<jax._src.api_util._HashableWithStrictTypeEquality object at 0x7f6d9c4d11b0>, <jax._src.api_util._HashableWithStrictTypeEquality object at 0x7f6d9c4d17b0>, <jax._src.api_util._HashableWithStrictTypeEquality object at 0x7f6d9c4d1180>))), (<function result_paths at 0x7f6e75514940>, ())), (), None, ((ShapedArray(float32[25]), ShapedArray(float32[])), TracingDebugInfo(traced_for='jit', func_src_info='_while_loop_scan at /venv/lib/python3.10/site-packages/jaxopt/_src/', arg_names=('init_val[0]', 'init_val[1]'), result_paths=None), <hashable with closure=()>), False, None, ((), (), False, 'allow', None, False, 'standard', None, True, False, False, None))]
<dict 140108750487680> is referred to by <dict 140112385575488>[<weakref at 0x7f6d9c49a160; to 'function' at 0x7f6e71d9e3b0 (_while_loop_scan)>]
<dict 140112385575488> is referred to by <WeakKeyDictionary 140112385358528>.data
<WeakKeyDictionary 140112385358528> is referred to by <method 140112385576768>
<method 140112385576768> is referred to by <function 140112385515600>.cache_clear
<function 140112385515600> is referred to by jax._src.pjit._create_pjit_jaxpr

Do you have any suggestions on how to avoid this problem?

btaba commented 8 months ago

Hi @Wedthree, this is likely because jaxopt is jitting something internally, but you aren't running jitted functions. I would do:

jit_reset = jax.jit(humanoid_env.reset)
jit_step = jax.jit(humanoid_env.step)
state = jit_reset(jax.random.PRNGKey(0))
action = jnp.zeros((humanoid_env.action_size,))
for i in range(10):
  state = jit_step(state, action)
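
For readers without Brax installed, here is a minimal sketch of the same jit-once pattern in plain JAX. It is only an illustration under stated assumptions: `toy_step` and `N_FRAMES` are hypothetical stand-ins for `humanoid_env.step` and the env's substep count, not Brax APIs, and the update rule is not real physics. `jax.checking_leaks()` is used as a scoped alternative to setting `jax_check_tracer_leaks` globally.

```python
import jax
import jax.numpy as jnp

N_FRAMES = 4  # hypothetical number of substeps per env step (assumption)


def toy_step(state, action):
    """Toy stand-in for env.step: applies N_FRAMES substeps with lax.scan."""

    def substep(carry, _):
        # Simple relaxation toward the action; not Brax physics.
        return 0.9 * carry + 0.1 * action, None

    new_state, _ = jax.lax.scan(substep, state, xs=None, length=N_FRAMES)
    return new_state


# Jit once, outside the loop, and reuse the compiled function on every call.
jit_step = jax.jit(toy_step)

state = jnp.zeros((3,))
action = jnp.ones((3,))

# Scoped leak checking, limited to this block instead of the whole process.
with jax.checking_leaks():
    for _ in range(10):
        state = jit_step(state, action)

print(state)
```

Whether this pattern removes the error in the generalized-backend case depends on the explanation above being correct; the sketch only shows that a scan-based step wrapped in a single outer `jax.jit` runs cleanly under the leak checker.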