google / brax

Massively parallel rigidbody physics simulation on accelerator hardware.
Apache License 2.0

Leaked Trace Error for generalized backend humanoid env. #395

Closed: Wedthree closed this issue 8 months ago

Wedthree commented 9 months ago

I encountered an intermittent leaked-tracer error when using the humanoid env in my project; the error is thrown seemingly at random. To find where the leak happens, I enabled `jax.config.update("jax_check_tracer_leaks", True)`. After some investigation, I found that it originates in the brax env. The following code reproduces the error:

```python
import jax
import jax.numpy as jnp
from brax.envs.humanoid import Humanoid

jax.config.update("jax_check_tracer_leaks", True)
humanoid_env = Humanoid(backend="generalized")
state = humanoid_env.reset(jax.random.PRNGKey(0))
for i in range(10):
    print("+++++++++++++++++++++++")
    print(i)
    action = jnp.zeros((humanoid_env.action_size,))
    state = humanoid_env.step(state, action)
    print(state.obs)
```

When I use the "positional" backend, the leaked-tracer error disappears. Judging from the error log, the problem seems related to jaxopt's `proximal_gradient.py`; see the log below:

```
The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "test.py", line 16, in <module>
    state = humanoid_env.step(state, action)
  File "/venv/lib/python3.10/site-packages/brax/envs/humanoid.py", line 250, in step
    pipeline_state = self.pipeline_step(pipeline_state0, action)
  File "/venv/lib/python3.10/site-packages/brax/envs/base.py", line 127, in pipeline_step
    return jax.lax.scan(f, pipeline_state, (), self._n_frames)[0]
  File "/venv/lib/python3.10/site-packages/brax/envs/base.py", line 123, in f
    self._pipeline.step(self.sys, state, action, self._debug),
  File "/venv/lib/python3.10/site-packages/brax/generalized/pipeline.py", line 72, in step
    state = state.replace(qf_constraint=constraint.force(sys, state))
  File "/venv/lib/python3.10/site-packages/brax/generalized/constraint.py", line 238, in force
    qf_constraint = state.con_jac.T @ pg.run(jp.zeros_like(b)).params
  File "/venv/lib/python3.10/site-packages/jaxopt/_src/projected_gradient.py", line 136, in run
    return self._pg.run(init_params, hyperparams_proj, *args, **kwargs)
  File "/venv/lib/python3.10/site-packages/jaxopt/_src/base.py", line 355, in run
    return run(init_params, *args, **kwargs)
  File "/venv/lib/python3.10/site-packages/jaxopt/_src/base.py", line 317, in _run
    opt_step = self.update(init_params, state, *args, **kwargs)
  File "/venv/lib/python3.10/site-packages/jaxopt/_src/proximal_gradient.py", line 296, in update
    return f(params, state, hyperparams_prox, *args, **kwargs)
  File "/venv/lib/python3.10/site-packages/jaxopt/_src/proximal_gradient.py", line 266, in _update_accel
    next_x, next_stepsize = self._iter(iter_num, y, y_fun_val, y_fun_grad,
  File "/venv/lib/python3.10/site-packages/jaxopt/_src/proximal_gradient.py", line 224, in _iter
    next_x, next_stepsize = self._fista_line_search(self.maxls, x, x_fun_val,
  File "/venv/lib/python3.10/contextlib.py", line 142, in __exit__
    next(self.gen)
Exception: Leaked trace MainTrace(2,DynamicJaxprTrace). Leaked tracer(s):

Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=2/0)>
This DynamicJaxprTracer was created on line test.py:16 (<module>)
<DynamicJaxprTracer 140108750238752> is referred to by <list 140109619717248>[5]
<list 140109619717248> is referred to by <tuple 140109619392448>[1]
<tuple 140109619392448> is referred to by <tuple 140109619386048>[0]
<tuple 140109619386048> is referred to by <dict 140108750487680>[(((<function flatten_fun_nokwargs at 0x7f6e754bb520>, (PyTreeDef(((, ),)),)), (<function _argnums_partial at 0x7f6e754bbbe0>, ((2,), (<jax._src.api_util._HashableWithStrictTypeEquality object at 0x7f6d9c4d11b0>, <jax._src.api_util._HashableWithStrictTypeEquality object at 0x7f6d9c4d17b0>, <jax._src.api_util._HashableWithStrictTypeEquality object at 0x7f6d9c4d1180>))), (<function result_paths at 0x7f6e75514940>, ())), (), None, ((ShapedArray(float32[25]), ShapedArray(float32[])), TracingDebugInfo(traced_for='jit', func_src_info='_while_loop_scan at /venv/lib/python3.10/site-packages/jaxopt/_src/loop.py:21', arg_names=('init_val[0]', 'init_val[1]'), result_paths=None), <hashable with closure=()>), False, None, ((), (), False, 'allow', None, False, 'standard', None, True, False, False, None))]
<dict 140108750487680> is referred to by <dict 140112385575488>[<weakref at 0x7f6d9c49a160; to 'function' at 0x7f6e71d9e3b0 (_while_loop_scan)>]
<dict 140112385575488> is referred to by <WeakKeyDictionary 140112385358528>.data
<WeakKeyDictionary 140112385358528> is referred to by <method 140112385576768>
<method 140112385576768> is referred to by <function 140112385515600>.cache_clear
<function 140112385515600> is referred to by jax._src.pjit._create_pjit_jaxpr
```

Do you have any suggestions on how to avoid this problem?
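For reference, the kind of problem that `jax_check_tracer_leaks` reports can be shown without Brax or jaxopt. In this minimal sketch, `leaky` is a hypothetical function (not code from either library) that stores its traced argument in a Python list that outlives the trace. With the flag enabled, JAX should raise an exception that mentions a leaked trace, as in the log above; the exact message wording may differ between JAX versions.

```python
import jax
import jax.numpy as jnp

# Enable JAX's tracer-leak checker (the same flag used in the issue).
jax.config.update("jax_check_tracer_leaks", True)

leaked = []  # hypothetical Python-side container that outlives the trace

def leaky(x):
    # Storing the tracer `x` outside the traced function is a tracer leak.
    leaked.append(x)
    return x * 2

try:
    jax.jit(leaky)(jnp.ones(3))  # the first call traces `leaky`, where the check runs
    leak_reported = False
except Exception as err:  # the leak checker raises a plain Exception
    leak_reported = "Leaked" in str(err)

print("leak reported:", leak_reported)
```

Removing the `leaked.append(x)` line makes the exception go away, since no tracer then survives past the end of tracing.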

btaba commented 8 months ago

Hi @Wedthree, this is likely because jaxopt jit-compiles part of its solver internally, but your reset and step calls aren't jitted. I would do:

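```python
import jax
import jax.numpy as jnp
from brax.envs.humanoid import Humanoid

humanoid_env = Humanoid(backend="generalized")
# Jit reset and step so the whole pipeline, including jaxopt's solver, runs under jit
jit_reset = jax.jit(humanoid_env.reset)
jit_step = jax.jit(humanoid_env.step)
state = jit_reset(jax.random.PRNGKey(0))
action = jnp.zeros((humanoid_env.action_size,))
for i in range(10):
  state = jit_step(state, action)
```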
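The suggested fix follows a general JAX pattern: jit the outer step once and reuse the compiled function inside the Python loop, so any jitted loop nested inside it is traced as part of one outer computation. Below is a minimal, Brax-free sketch of that pattern. The toy `inner_solve` and `step` functions are hypothetical stand-ins, not Brax or jaxopt code: `inner_solve` plays the role of jaxopt's internally jitted solver loop (here a simple fixed-point iteration), and `step` plays the role of `humanoid_env.step`.

```python
import jax
import jax.numpy as jnp

@jax.jit
def inner_solve(b):
    # Hypothetical stand-in for a jitted solver loop:
    # fixed-point iteration x <- 0.5 * (x + b), run for 20 iterations.
    def cond(carry):
        i, _ = carry
        return i < 20

    def body(carry):
        i, x = carry
        return i + 1, 0.5 * (x + b)

    _, x = jax.lax.while_loop(cond, body, (0, jnp.zeros_like(b)))
    return x

def step(state, action):
    # Hypothetical env step: move the state 10% of the way toward the solver output.
    return state + 0.1 * (inner_solve(action) - state)

jit_step = jax.jit(step)  # compiled on the first call, then reused from the cache
state = jnp.zeros(3)
action = jnp.ones(3)
for i in range(10):
    state = jit_step(state, action)
```

After 10 steps each entry of `state` is close to `(1 - 0.9**10) * (1 - 0.5**20)`, about 0.651, and calling `step` eagerly gives the same result up to float32 rounding. Wrapping the step in `jax.jit` mainly changes how the computation is traced and dispatched, not what it computes.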