pierreglaser closed 10 months ago
Closes #1684
@fehiepsi I changed the reproducer to a pattern that seems more likely to occur in practice than the one in https://github.com/pyro-ppl/numpyro/issues/1684#issuecomment-1823149622:
v_dist = jax.vmap(
    lambda loc, scale: dist.Normal(loc=loc, scale=scale, validate_args=True),
    in_axes=(0, 0),
)(jnp.zeros((2,)), jnp.zeros((2,)))
Nice fix! Thanks, @pierreglaser!