ordabayevy closed this issue 1 year ago.
Hi, I am getting the same error when running Stochastic Variational Inference (SVI) on a model that uses TraceGraphELBO as the loss.
The stack trace is here: https://gist.github.com/pankajb64/16b0ff30c380071b0f5613577deaff61
Looking forward to PR #1543
Thanks, @pankajb64! For now, I would recommend setting os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'
at the top of your program. The fix in the PR still has some odd issues.
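A minimal sketch of where the workaround goes, assuming a standalone script. The variable should be set before jax (or numpyro) is imported, because JAX reads its flag defaults from environment variables when it is first imported; setting it afterwards is unlikely to take effect.

```python
import os

# Workaround: disable the jit/pjit API merge introduced in recent jax releases.
# This must run before jax or numpyro is imported anywhere in the process.
os.environ["JAX_JIT_PJIT_API_MERGE"] = "0"

# Only after this point: import jax / numpyro and set up the model and SVI.
```

If you run the tests through pytest instead, exporting the same variable in the shell before invoking pytest has the same effect.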
I'm getting the error message below when running
pytest test/contrib/test_enum_elbo.py
with the latest version of jax (0.4.4):