I would like to return a MultivariateNormalDiag distribution from a jitted function. However, I'm getting a leaked-tracer error. I've created the following minimal example, which outputs:
```
Creating jitted categorical distribution
Creating jitted multivariate normal distribution
Traceback (most recent call last):
  File "test_mvn.py", line 13, in <module>
    jitted_mvn()
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 163, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 235, in cache_miss
    outs, out_flat, out_tree, args_flat = _python_pjit_helper(
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 179, in _python_pjit_helper
    args_flat, _, params, in_tree, out_tree, _ = infer_params_fn(
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/api.py", line 440, in infer_params
    return pjit.common_infer_params(pjit_info_args, *args, **kwargs)
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 513, in common_infer_params
    jaxpr, consts, canonicalized_out_shardings_flat = _pjit_jaxpr(
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 965, in _pjit_jaxpr
    jaxpr, final_consts, global_out_avals = _create_pjit_jaxpr(
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/linear_util.py", line 301, in memoized_fun
    ans = call(fun, *args)
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/pjit.py", line 923, in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2031, in trace_to_jaxpr_dynamic
    with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore
  File "/home/me/anaconda3/envs/jax/lib/python3.10/contextlib.py", line 142, in __exit__
    next(self.gen)
  File "/home/me/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/core.py", line 1083, in new_main
    if leaked_tracers: raise leaked_tracer_error("trace", t(), leaked_tracers)
jax._src.traceback_util.UnfilteredStackTrace: Exception: Leaked trace MainTrace(1,DynamicJaxprTrace). Leaked tracer(s):
Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/me/test_mvn.py:8 (<lambda>)
<DynamicJaxprTracer 140510728702848> is referred to by <ScalarAffine 140511566864992>._scale
<ScalarAffine 140511566864992> is referred to by <Block 140511566865664>._bijector
<Block 140511566865664> is referred to by <method 140511570285120>
<method 140511570285120> is referred to by <list 140510728761984>[14]
<list 140510728761984> is referred to by <tuple 140510728754816>[0]
Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/me/test_mvn.py:8 (<lambda>)
<DynamicJaxprTracer 140510728705008> is referred to by <ScalarAffine 140511566864992>._inv_scale
<ScalarAffine 140511566864992> is referred to by <Block 140511566865664>._bijector
<Block 140511566865664> is referred to by <method 140511570285120>
<method 140511570285120> is referred to by <list 140510728761984>[14]
<list 140510728761984> is referred to by <tuple 140510728754816>[0]
Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/me/test_mvn.py:7 (<lambda>)
<DynamicJaxprTracer 140510728705328> is referred to by <ScalarAffine 140511566864992>._log_scale
<ScalarAffine 140511566864992> is referred to by <Block 140511566865664>._bijector
<Block 140511566865664> is referred to by <method 140511570285120>
<method 140511570285120> is referred to by <list 140510728761984>[14]
<list 140510728761984> is referred to by <tuple 140510728754816>[0]

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/me/test_mvn.py", line 13, in <module>
    jitted_mvn()
  File "/home/me/anaconda3/envs/jax/lib/python3.10/contextlib.py", line 142, in __exit__
    next(self.gen)
Exception: Leaked trace MainTrace(1,DynamicJaxprTrace). Leaked tracer(s):
Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/me/test_mvn.py:8 (<lambda>)
<DynamicJaxprTracer 140510728702848> is referred to by <ScalarAffine 140511566864992>._scale
<ScalarAffine 140511566864992> is referred to by <Block 140511566865664>._bijector
<Block 140511566865664> is referred to by <method 140511570285120>
<method 140511570285120> is referred to by <list 140510728761984>[14]
<list 140510728761984> is referred to by <tuple 140510728754816>[0]
Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/me/test_mvn.py:8 (<lambda>)
<DynamicJaxprTracer 140510728705008> is referred to by <ScalarAffine 140511566864992>._inv_scale
<ScalarAffine 140511566864992> is referred to by <Block 140511566865664>._bijector
<Block 140511566865664> is referred to by <method 140511570285120>
<method 140511570285120> is referred to by <list 140510728761984>[14]
<list 140510728761984> is referred to by <tuple 140510728754816>[0]
Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=1/0)>
This DynamicJaxprTracer was created on line /home/me/test_mvn.py:7 (<lambda>)
<DynamicJaxprTracer 140510728705328> is referred to by <ScalarAffine 140511566864992>._log_scale
<ScalarAffine 140511566864992> is referred to by <Block 140511566865664>._bijector
<Block 140511566865664> is referred to by <method 140511570285120>
<method 140511570285120> is referred to by <list 140510728761984>[14]
<list 140510728761984> is referred to by <tuple 140510728754816>[0]
```
As seen from the output, returning a Categorical works fine, but returning a MultivariateNormalDiag results in leaked tracers. This seems like a bug. I'm using