train.py terminates with the following error message:
jax._src.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with shape (400,) and dtype float32 to escape.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was update at /dreamax/dreamer/dreamer.py:140 traced for jit.
------------------------------
The leaked intermediate value was created on line /usr/local/lib/python3.8/dist-packages/jmp/_src/loss_scale.py:185 (select_tree).
Full stack trace:
Traceback (most recent call last):
  File "train.py", line 336, in <module>
    main()
  File "train.py", line 316, in main
    master_params = jax.tree_map(lambda x: jax.device_put(x, device=jax.devices('cpu')[0]), train_agent.params)
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/tree_util.py", line 184, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/tree_util.py", line 184, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "train.py", line 316, in <lambda>
    master_params = jax.tree_map(lambda x: jax.device_put(x, device=jax.devices('cpu')[0]), train_agent.params)
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/api.py", line 2667, in device_put
    return tree_map(lambda y: dispatch.device_put_p.bind(y, device=device), x)
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/tree_util.py", line 184, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/tree_util.py", line 184, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/api.py", line 2667, in <lambda>
    return tree_map(lambda y: dispatch.device_put_p.bind(y, device=device), x)
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 323, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 981, in find_top_trace
    top_tracer._assert_live()
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/partial_eval.py", line 1302, in _assert_live
    raise core.escaped_tracer_error(self, None)
jax._src.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with shape (400,) and dtype float32 to escape.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was update at /dreamax/dreamer/dreamer.py:140 traced for jit.
------------------------------
The leaked intermediate value was created on line /usr/local/lib/python3.8/dist-packages/jmp/_src/loss_scale.py:185 (select_tree).
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
train.py:310 (main)
/dreamax/dreamer/dreamer.py:153 (update)
/dreamax/dreamer/dreamer.py:242 (update_actor)
/dreamax/dreamer/dreamer.py:288 (grad_step)
/usr/local/lib/python3.8/dist-packages/jmp/_src/loss_scale.py:185 (select_tree)
------------------------------
To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
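For context, the trace suggests (this is an assumption, since dreamer.py is not shown) that the value produced by `select_tree` inside the jitted `update` was stored on `train_agent` as a side effect, for example via an assignment to `self.params` inside the traced code, instead of being returned. The later `jax.device_put` at train.py:316 then touches a tracer whose trace has already ended. Below is a minimal sketch of the pattern the error message asks for: the jitted function is pure, returns the new params, and the caller assigns the concrete result outside `jit`. The names `Agent`, `loss_fn`, and `select_tree` are hypothetical; `select_tree` is a `jnp.where` stand-in for `jmp.select_tree`, not the jmp implementation.

```python
import jax
import jax.numpy as jnp


def select_tree(pred, on_true, on_false):
    # Hypothetical stand-in for jmp.select_tree: take every leaf from
    # `on_true` when `pred` is True, otherwise from `on_false`.
    return jax.tree_util.tree_map(lambda a, b: jnp.where(pred, a, b), on_true, on_false)


def loss_fn(params, x):
    # Toy quadratic loss; not the Dreamer actor loss.
    return jnp.sum((params["w"] * x) ** 2)


class Agent:
    """Hypothetical stand-in for the `train_agent` object in train.py."""

    def __init__(self):
        self.params = {"w": jnp.ones((400,), jnp.float32)}
        # jit a pure function: `self` is never read or written inside the trace.
        self._update = jax.jit(self._grad_step)

    @staticmethod
    def _grad_step(params, x, lr=0.1):
        grads = jax.grad(loss_fn)(params, x)
        new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
        leaves_finite = [jnp.all(jnp.isfinite(g)) for g in jax.tree_util.tree_leaves(grads)]
        finite = jnp.all(jnp.stack(leaves_finite))
        # Loss-scaling style guard: keep the old params if any gradient is inf/NaN.
        # The selected tree is RETURNED, never stored on `self` here.
        return select_tree(finite, new_params, params)

    def update(self, x):
        # Assign the returned concrete arrays outside the jitted function.
        # Writing `self.params = ...` inside `_grad_step` would leak a tracer.
        self.params = self._update(self.params, x)


agent = Agent()
agent.update(jnp.ones((400,), jnp.float32))

# Same transfer as train.py:316, now applied to concrete arrays.
cpu = jax.devices("cpu")[0]
master_params = jax.tree_util.tree_map(lambda a: jax.device_put(a, cpu), agent.params)
```

With concrete arrays on `agent.params`, the CPU copy succeeds. Running the failing script with `JAX_CHECK_TRACER_LEAKS=1`, as the error message suggests, should point at the exact line where the tracer is stored.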