google-deepmind / distrax

Apache License 2.0

Leaked tracers when jitting `MultivariateNormalDiag` distribution #230

Status: Open. keraJLi opened this issue 1 year ago.

keraJLi commented 1 year ago

I would like to return a `MultivariateNormalDiag` distribution from a jitted function, but doing so leaks tracers. Here is a minimal example:

```python
import jax
import distrax
import jax.numpy as jnp

jax.config.update("jax_check_tracer_leaks", True)

jitted_cat = jax.jit(lambda: distrax.Categorical(logits=jnp.zeros(1)))
jitted_mvn = jax.jit(lambda: distrax.MultivariateNormalDiag(loc=jnp.zeros(1)))

print("Creating jitted categorical distribution")
jitted_cat()
print("Creating jitted multivariate normal distribution")
jitted_mvn()
```

which outputs:

Creating jitted categorical distribution
Creating jitted multivariate normal distribution
Traceback (most recent call last):
  File "test_mvn.py", line 13, in <module>
    jitted_mvn()
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 163, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 235, in cache_miss
    outs, out_flat, out_tree, args_flat = _python_pjit_helper(
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 179, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/api.py", line 440, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 513, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 965, in _pjit_jaxpr
    jaxpr, final_consts, global_out_avals = _create_pjit_jaxpr(
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/linear_util.py", line 301, in memoized_fun
    ans = call(fun, *args)
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 923, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2031, in trace_to_jaxpr_dynamic
    with core.new_main(DynamicJaxprTrace, dynamic=True) as main:  # type: ignore
  File "/home/me/anaconda3/envs/jax/lib/python3.10/contextlib.py", line 142, in __exit__
    next(self.gen)
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/core.py", line 1083, in new_main
    if leaked_tracers: raise leaked_tracer_error("trace", t(), leaked_tracers)
jax._src.traceback_util.UnfilteredStackTrace: Exception: Leaked trace MainTrace(1,DynamicJaxprTrace). Leaked tracer(s):

Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/me/test_mvn.py:8 (<lambda>)
<DynamicJaxprTracer 140510728702848> is referred to by <ScalarAffine 140511566864992>._scale
<ScalarAffine 140511566864992> is referred to by <Block 140511566865664>._bijector
<Block 140511566865664> is referred to by <method 140511570285120>
<method 140511570285120> is referred to by <list 140510728761984>[14]
<list 140510728761984> is referred to by <tuple 140510728754816>[0]

Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/me/test_mvn.py:8 (<lambda>)
<DynamicJaxprTracer 140510728705008> is referred to by <ScalarAffine 140511566864992>._inv_scale
<ScalarAffine 140511566864992> is referred to by <Block 140511566865664>._bijector
<Block 140511566865664> is referred to by <method 140511570285120>
<method 140511570285120> is referred to by <list 140510728761984>[14]
<list 140510728761984> is referred to by <tuple 140510728754816>[0]

Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/me/test_mvn.py:7 (<lambda>)
<DynamicJaxprTracer 140510728705328> is referred to by <ScalarAffine 140511566864992>._log_scale
<ScalarAffine 140511566864992> is referred to by <Block 140511566865664>._bijector
<Block 140511566865664> is referred to by <method 140511570285120>
<method 140511570285120> is referred to by <list 140510728761984>[14]
<list 140510728761984> is referred to by <tuple 140510728754816>[0]

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/me/test_mvn.py", line 13, in <module>
    jitted_mvn()
  File "/home/me/anaconda3/envs/jax/lib/python3.10/contextlib.py", line 142, in __exit__
    next(self.gen)
Exception: Leaked trace MainTrace(1,DynamicJaxprTrace). Leaked tracer(s):

Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/me/test_mvn.py:8 (<lambda>)
<DynamicJaxprTracer 140510728702848> is referred to by <ScalarAffine 140511566864992>._scale
<ScalarAffine 140511566864992> is referred to by <Block 140511566865664>._bijector
<Block 140511566865664> is referred to by <method 140511570285120>
<method 140511570285120> is referred to by <list 140510728761984>[14]
<list 140510728761984> is referred to by <tuple 140510728754816>[0]

Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/me/test_mvn.py:8 (<lambda>)
<DynamicJaxprTracer 140510728705008> is referred to by <ScalarAffine 140511566864992>._inv_scale
<ScalarAffine 140511566864992> is referred to by <Block 140511566865664>._bijector
<Block 140511566865664> is referred to by <method 140511570285120>
<method 140511570285120> is referred to by <list 140510728761984>[14]
<list 140510728761984> is referred to by <tuple 140510728754816>[0]

Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/me/test_mvn.py:7 (<lambda>)
<DynamicJaxprTracer 140510728705328> is referred to by <ScalarAffine 140511566864992>._log_scale
<ScalarAffine 140511566864992> is referred to by <Block 140511566865664>._bijector
<Block 140511566865664> is referred to by <method 140511570285120>
<method 140511570285120> is referred to by <list 140510728761984>[14]
<list 140510728761984> is referred to by <tuple 140510728754816>[0]

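As the output shows, returning a `Categorical` works fine, but returning a `MultivariateNormalDiag` leaks tracers, which are still referenced by the `_scale`, `_inv_scale`, and `_log_scale` attributes of its `ScalarAffine` bijector. This looks like a bug. I'm using: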

distrax==0.1.3
jax==0.4.5
jaxlib==0.4.4+cuda11.cudnn82