world-modelz / dreamax

A scalable Dreamer implementation in JAX
MIT License

Bug: A function transformed by JAX had a side effect ... #14

Closed: andreaskoepf closed this issue 2 years ago

andreaskoepf commented 2 years ago

train.py terminates with the following error message:

jax._src.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with shape (400,) and dtype float32 to escape.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was update at /dreamax/dreamer/dreamer.py:140 traced for jit.
------------------------------
The leaked intermediate value was created on line /usr/local/lib/python3.8/dist-packages/jmp/_src/loss_scale.py:185 (select_tree). 

Full stack trace:

Traceback (most recent call last):
  File "train.py", line 336, in <module>
    main()
  File "train.py", line 316, in main
    master_params = jax.tree_map(lambda x: jax.device_put(x, device=jax.devices('cpu')[0]), train_agent.params)
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/tree_util.py", line 184, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/tree_util.py", line 184, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "train.py", line 316, in <lambda>
    master_params = jax.tree_map(lambda x: jax.device_put(x, device=jax.devices('cpu')[0]), train_agent.params)
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/api.py", line 2667, in device_put
    return tree_map(lambda y: dispatch.device_put_p.bind(y, device=device), x)
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/tree_util.py", line 184, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/tree_util.py", line 184, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/usr/local/lib/python3.8/dist-packages/jax/_src/api.py", line 2667, in <lambda>
    return tree_map(lambda y: dispatch.device_put_p.bind(y, device=device), x)
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 323, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 981, in find_top_trace
    top_tracer._assert_live()
  File "/usr/local/lib/python3.8/dist-packages/jax/interpreters/partial_eval.py", line 1302, in _assert_live
    raise core.escaped_tracer_error(self, None)
jax._src.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with shape (400,) and dtype float32 to escape.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was update at /dreamax/dreamer/dreamer.py:140 traced for jit.
------------------------------
The leaked intermediate value was created on line /usr/local/lib/python3.8/dist-packages/jmp/_src/loss_scale.py:185 (select_tree). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
train.py:310 (main)
/dreamax/dreamer/dreamer.py:153 (update)
/dreamax/dreamer/dreamer.py:242 (update_actor)
/dreamax/dreamer/dreamer.py:288 (grad_step)
/usr/local/lib/python3.8/dist-packages/jmp/_src/loss_scale.py:185 (select_tree)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError
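
The traceback above fails when `train_agent.params` is read after a jitted `update`, which suggests that a traced value was stored on the agent from inside the jitted function. The following is a minimal sketch of that failure pattern and one common way to avoid it. It is not the dreamax code: `LeakyAgent`, `pure_update`, the learning rate, and the array shapes are hypothetical names and values chosen for illustration, and the sketch uses a plain arithmetic operation rather than `jax.device_put` to touch the leaked value.

```python
import jax
import jax.numpy as jnp


class LeakyAgent:
    """Hypothetical agent whose jitted update writes to self (the bug pattern)."""

    def __init__(self):
        self.params = jnp.ones(4)
        # jit the bound method; `self` is closed over, not traced.
        self._update = jax.jit(self._update_impl)

    def _update_impl(self, grads):
        # Side effect inside a traced function: after tracing,
        # self.params refers to an intermediate (traced) value.
        self.params = self.params - 0.1 * grads
        return self.params

    def update(self, grads):
        return self._update(grads)


agent = LeakyAgent()
agent.update(jnp.ones(4))

try:
    # Using the leaked value outside of the jitted function fails.
    agent.params + 1.0
    leak_detected = False
except jax.errors.UnexpectedTracerError:
    leak_detected = True


@jax.jit
def pure_update(params, grads):
    """Pure variant: state goes in as an argument and comes out as a result."""
    return params - 0.1 * grads


params = jnp.ones(4)
# The caller, not the jitted function, rebinds the state.
params = pure_update(params, jnp.ones(4))
```

In the pure variant all state is passed in and returned explicitly, so the value assigned to `params` is a concrete array that can be moved between devices or reused in later steps. Whether the actual fix in this repository took this form is not stated in the thread.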

andreaskoepf commented 2 years ago

see https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

andreaskoepf commented 2 years ago

I could reproduce the error on two independent Ubuntu 20.04 machines with NVIDIA driver version 510.60.02, CUDA 11.6, and JAX version 0.3.7.

XMaster96 commented 2 years ago

Closed with 1ff8933662a414dce6e42d2d1267e98b0f0f92ec / https://github.com/world-modelz/dreamax/commit/40c3966f90fc749a7b0ca0b94458905530997f83