pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

`NameError: unbound axis name: _provenance` with the latest jax version #1542

Closed ordabayevy closed 1 year ago

ordabayevy commented 1 year ago

I'm getting the error message below when running `pytest test/contrib/test_enum_elbo.py` with the latest version of jax (0.4.4):

self = Traced<ShapedArray(int32[1;_provenance:frozenset({'w'})])>with<DynamicJaxprTrace(level=4/0)>, other = 0

    def deferring_binary_op(self, other):
      if hasattr(other, '__jax_array__'):
        other = other.__jax_array__()
      args = (other, self) if swap else (self, other)
      if isinstance(other, _accepted_binop_types):
>       return binary_op(*args)
E       NameError: unbound axis name: _provenance. The following axis names (e.g. defined by pmap) are available to collective operations: []

/home/yordabay/anaconda3/envs/numpyro/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:5071: NameError

pankajb64 commented 1 year ago

Hi, I am getting the same error when running Stochastic Variational Inference on a model with `TraceGraph_ELBO` as the loss.

Stack Trace is here -- https://gist.github.com/pankajb64/16b0ff30c380071b0f5613577deaff61

Looking forward to PR #1543

fehiepsi commented 1 year ago

Thanks, @pankajb64! For now, I would recommend setting `os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'` at the top of your program. There are some weird issues with the solution in the PR.
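
For anyone applying the workaround above, here is a minimal sketch of where the line would go. It assumes that JAX reads this environment variable when it is first imported, so the assignment has to come before any `import jax` or `import numpyro`. The `jax_already_imported` check is only an illustrative guard, not part of NumPyro or JAX.

```python
import os
import sys

# Illustrative guard (an assumption, not a library feature): if jax is already
# loaded, setting the variable now is likely too late to have any effect.
jax_already_imported = "jax" in sys.modules
if jax_already_imported:
    print("Warning: jax was imported before JAX_JIT_PJIT_API_MERGE was set")

# Workaround suggested in this thread: disable the jit/pjit API merge.
os.environ["JAX_JIT_PJIT_API_MERGE"] = "0"

# Only after this point import jax and numpyro, for example:
# import jax
# import numpyro
```

In a test run, the same effect can be had by exporting `JAX_JIT_PJIT_API_MERGE=0` in the shell before invoking pytest.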