pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

Fix faulty interaction between `jax.vmap` and `validate_args=True` #1686

Closed: pierreglaser closed this 10 months ago

pierreglaser commented 10 months ago

Closes #1684

@fehiepsi I changed the reproducer to a pattern that seems more likely to occur in practice than the one in https://github.com/pyro-ppl/numpyro/issues/1684#issuecomment-1823149622:

    import jax
    import jax.numpy as jnp
    import numpyro.distributions as dist

    v_dist = jax.vmap(
        lambda loc, scale: dist.Normal(loc=loc, scale=scale, validate_args=True),
        in_axes=(0, 0),
    )(jnp.zeros((2,)), jnp.zeros((2,)))

fehiepsi commented 10 months ago

Nice fix! Thanks, @pierreglaser!