Open IlayMenahem opened 23 hours ago
Thanks for the report! The issue here occurs when loc > x in the gamma logpdf. This is checked here:
so this will return -inf in that case as expected, but the NaN pops up here:
It's safe to set jax_debug_nans to False in this case (since it's checked later), or make the following workaround change to your code:
- theta = data[:, :, 2]
+ theta = jax.lax.clamp(-jnp.inf, data[:, :, 2], target[:, None])
But I think we should definitely fix these leaking NaNs in JAX itself! If you're keen to submit a PR, I'd be happy to help/point you in the right direction, otherwise I can probably fix it myself soon.
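To illustrate the mechanism described above, here is a minimal sketch. It assumes a toy function, toy_log_term, that is a hypothetical stand-in and not JAX's actual gamma.logpdf code; the input values are made up as well. It shows how a NaN can appear in an intermediate value even though the final result is masked to -inf, and how capping loc at x with jax.lax.clamp avoids the NaN.

```python
import jax
import jax.numpy as jnp


def toy_log_term(x, loc):
    # Hypothetical stand-in, NOT JAX's gamma.logpdf: the log is evaluated
    # for every element, including those where x < loc.
    raw = jnp.log(x - loc)  # NaN wherever x < loc
    # The support check happens afterwards and masks those entries to -inf.
    masked = jnp.where(x >= loc, raw, -jnp.inf)
    return raw, masked


# Made-up inputs: the first point lies below loc, the second above it.
x = jnp.array([0.5, 2.0])
loc = jnp.array([1.0, 1.0])

# raw contains a NaN for the first element; masked replaces it with -inf.
raw, masked = toy_log_term(x, loc)

# Workaround from the reply, applied to the toy inputs: cap loc at x so the
# argument of the log is never negative.
safe_loc = jax.lax.clamp(-jnp.inf, loc, x)
safe_raw, safe_masked = toy_log_term(x, safe_loc)
```

The final masked value is fine in both cases. Only the intermediate differs, and the intermediate is what jax_debug_nans inspects, so with that flag enabled the unclamped call would be expected to raise FloatingPointError.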
Description
I got a FloatingPointError when using jax.scipy.stats.gamma.pdf. I tried jax.config.update("jax_enable_x64", True), to no avail.
code for reproduction
the error
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.35
jaxlib: 0.4.35
numpy: 2.1.3
python: 3.12.1 (main, Sep 30 2024, 17:05:21) [GCC 9.4.0]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='codespaces-5a7c09', release='6.5.0-1025-azure', version='#26~22.04.1-Ubuntu SMP Thu Jul 11 22:33:04 UTC 2024', machine='x86_64')