pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Using deprecated `jax.core.safe_map` in `ops/provenance.py` #1733

Closed: GaetanLepage closed this issue 7 months ago.

GaetanLepage commented 7 months ago

`numpyro/ops/provenance.py` references `jax.core.safe_map`, which has been deprecated since jax 0.4.24. In that version, accessing it raises an `AttributeError`.

As a result, `test_beta_bernoulli` fails when run with jax==0.4.24:

copying numpyro/version.py -> build/lib/numpyro
...skipping...
    svi_state, loss = jit(body_fn)(svi_state, None)
numpyro/infer/svi.py:353: in body_fn
    svi_state, loss = self.update(svi_state, *args, **kwargs)
numpyro/infer/svi.py:266: in update
    (loss_val, mutable_state), optim_state = self.optim.eval_and_update(
numpyro/optim.py:80: in eval_and_update
    (out, aux), grads = value_and_grad(fn, has_aux=True)(params)
numpyro/infer/svi.py:61: in loss_fn
    elbo.loss(
numpyro/infer/elbo.py:1154: in loss
    return -single_particle_elbo(rng_key)
numpyro/infer/elbo.py:1003: in single_particle_elbo
    model_deps, guide_deps = get_nonreparam_deps(
numpyro/infer/elbo.py:689: in get_nonreparam_deps
    model_deps, guide_deps = eval_provenance(fn, **latents)
numpyro/ops/provenance.py:43: in eval_provenance
    avals = core.safe_map(shaped_abstractify, args)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

name = 'safe_map'

    def getattr(name):
      if name in deprecations:
        message, fn = deprecations[name]
        if fn is None:
          raise AttributeError(message)
        warnings.warn(message, DeprecationWarning, stacklevel=2)
        return fn
>     raise AttributeError(f"module {module!r} has no attribute {name!r}")
E     AttributeError: module 'jax.core' has no attribute 'safe_map'
E     --------------------
E     For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include t>

/nix/store/yypc2kkpr8i4ri98k9rwdlx9v39ap9n3-python3.11-jax-0.4.24/lib/python3.11/site-packages/jax/_src/deprecations.py:53: AttributeError
----------------------------- Captured stderr call -----------------------------
  0%|          | 0/3 [00:00<?, ?it/s]
=========================== short test summary info ============================
FAILED test/test_pickle.py::test_beta_bernoulli - AttributeError: module 'jax.core' has no attribute 'safe_map'
===== 1 failed, 421 passed, 16 skipped, 56 deselected in 158.64s (0:02:38) =====
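The last frame of the traceback comes from JAX's module-level deprecation hook (`jax/_src/deprecations.py`). As a rough illustration of why the failure is an `AttributeError` rather than a warning, the sketch below reproduces the same three-branch pattern with a PEP 562 module `__getattr__`. The helper `make_deprecation_getattr` and the module `demo_core` are hypothetical and are not JAX API. A name that is still listed with a function triggers a warning and returns that function. A name listed with `None` raises its message. An unlisted name, such as `safe_map` here, falls through to the generic error seen in the log.

```python
import types
import warnings


def make_deprecation_getattr(module_name, deprecations):
    """Build a PEP 562 module-level __getattr__ (hypothetical helper).

    `deprecations` maps name -> (message, fn); fn=None means the name was
    removed and only its message is kept.
    """
    def getattr_(name):
        if name in deprecations:
            message, fn = deprecations[name]
            if fn is None:
                # Removed, but still registered: raise with the custom message.
                raise AttributeError(message)
            # Deprecated but still available: warn and return the object.
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            return fn
        # Not registered at all: the branch hit by `core.safe_map` in the log.
        raise AttributeError(f"module {module_name!r} has no attribute {name!r}")
    return getattr_


# Hypothetical stand-in for a module that is phasing out some names.
demo = types.ModuleType("demo_core")
demo.__getattr__ = make_deprecation_getattr(
    "demo_core",
    {
        "old_map": ("demo_core.old_map is deprecated; use map()", map),
        "gone_fn": ("demo_core.gone_fn was removed", None),
    },
)
```

Accessing `demo.safe_map` then fails with `module 'demo_core' has no attribute 'safe_map'`, which has the same shape as the error in the test output.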
fehiepsi commented 7 months ago

It seems that we need to use jax.util.safe_map there https://github.com/google/jax/blob/main/jax/util.py#L22C3-L22C11. Do you want to submit the fix?
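Below is a minimal sketch of the change suggested above. It assumes `jax.util.safe_map` is importable in the installed JAX version and is not necessarily what #1664 does. The pure-Python fallback is a hypothetical stand-in that mirrors the helper's behavior: it applies `f` element-wise across equal-length sequences, returns a list, and rejects mismatched lengths.

```python
# Prefer JAX's helper; fall back to a local equivalent if this JAX version
# (or environment) does not provide it. The fallback is an assumption, not
# numpyro or JAX code.
try:
    from jax.util import safe_map
except ImportError:
    def safe_map(f, *args):
        """Like map(), but requires equal-length inputs and returns a list."""
        args = [list(a) for a in args]
        n = len(args[0])
        for a in args[1:]:
            if len(a) != n:
                raise ValueError(f"length mismatch: {[len(x) for x in args]}")
        return list(map(f, *args))

# The failing call site in provenance.py would then read:
#     avals = safe_map(shaped_abstractify, args)
```

JAX's own implementation may signal a length mismatch with a different exception type than the fallback. Callers should therefore not depend on the specific error class.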

GaetanLepage commented 7 months ago

I think it has already been fixed in #1664, so I guess we just have to wait for the next release!

fehiepsi commented 7 months ago

Oh, I will make a release soon to unblock the issue.

GaetanLepage commented 7 months ago

Thanks! That would definitely help :)

GaetanLepage commented 7 months ago

Fixed by #1664

GaetanLepage commented 7 months ago

Hi! Is a release that includes this fix still planned?

fehiepsi commented 7 months ago

Yes, we are going to make a release this week.