jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

FloatingPointError in jax.scipy.stats #24939

Open IlayMenahem opened 23 hours ago

IlayMenahem commented 23 hours ago

Description

I got a `FloatingPointError` when using `jax.scipy.stats.gamma.pdf`. I've tried enabling 64-bit mode with `jax.config.update("jax_enable_x64", True)`, to no avail.

Code for reproduction

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)

def mle_loss(data, target):
    '''
    data: (batch_size, history_length, bar_length)
    target: (batch_size,)
    '''
    weights = data[:, :, 0]
    k = data[:, :, 1]
    theta = data[:, :, 2]

    probs = jax.vmap(jax.scipy.stats.gamma.pdf)(target, k, theta)
    prob = jnp.sum(weights * probs, axis=1)

    loss = -jax.lax.log(prob).mean()
    loss = jax.lax.clamp(-1e6, loss, 1e6)

    return loss

data = jnp.array([[[1.        , 0.7096346 , 0.7514472 ]],
                  [[1.        , 0.7194072 , 0.735364  ]],
                  [[1.        , 0.7475644 , 0.7523259 ]],
                  [[1.        , 0.7042354 , 0.7264852 ]],
                  [[1.        , 0.47818542, 1.3681346 ]],
                  [[1.        , 0.6943242 , 0.7313199 ]],
                  [[1.        , 0.687601  , 0.81750506]],
                  [[1.        , 0.72067565, 0.7784166 ]]])
target = jnp.array([1.0, 1.0004145, 0.99964607, 1.0004367, 1.000424, 0.99967706,
                    1.0009085, 1.000621])

mle_loss(data, target)
```

The error

```text
---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
    [... skipping hidden 1 frame]

File /workspaces/options/venv/lib/python3.12/site-packages/jax/_src/profiler.py:333, in annotate_function.<locals>.wrapper(*args, **kwargs)
    332 with TraceAnnotation(name, **decorator_kwargs):
--> 333   return func(*args, **kwargs)
    334 return wrapper

File /workspaces/options/venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py:1292, in ExecuteReplicated.__call__(self, *args)
   1291 for arrays in out_arrays:
-> 1292   dispatch.check_special(self.name, arrays)
   1293 out = self.out_handler(out_arrays)

File /workspaces/options/venv/lib/python3.12/site-packages/jax/_src/dispatch.py:327, in check_special(name, bufs)
    326 for buf in bufs:
--> 327   _check_special(name, buf.dtype, buf)

File /workspaces/options/venv/lib/python3.12/site-packages/jax/_src/dispatch.py:332, in _check_special(name, dtype, buf)
    331 if config.debug_nans.value and np.any(np.isnan(np.asarray(buf))):
--> 332   raise FloatingPointError(f"invalid value (nan) encountered in {name}")
    333 if config.debug_infs.value and np.any(np.isinf(np.asarray(buf))):

FloatingPointError: invalid value (nan) encountered in jit(log)

During handling of the above exception, another exception occurred:

FloatingPointError                        Traceback (most recent call last)
Cell In[1], line 41
     23 data = jnp.array([[[1.        , 0.7096346 , 0.7514472 ]],
     24 
     25        [[1.        , 0.7194072 , 0.735364  ]],
   (...)
     36 
     37        [[1.        , 0.72067565, 0.7784166 ]]])
     38 target = jnp.array([1.0, 1.0004145,  0.99964607, 1.0004367,  1.000424,   0.99967706,
     39  1.0009085  ,1.000621])
---> 41 mle_loss(data, target)

Cell In[1], line 15, in mle_loss(data, target)
     12 k = data[:, :, 1]
     13 theta = data[:, :, 2]
---> 15 probs = jax.vmap(jax.scipy.stats.gamma.pdf)(target, k, theta)
     16 prob = jnp.sum(weights * probs, axis=1)
     18 loss = -jax.lax.log(prob).mean()

    [... skipping hidden 3 frame]

File /workspaces/options/venv/lib/python3.12/site-packages/jax/_src/scipy/stats/gamma.py:92, in pdf(x, a, loc, scale)
     62 def pdf(x: ArrayLike, a: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
     63   r"""Gamma probability distribution function.
     64 
     65   JAX implementation of :obj:`scipy.stats.gamma` ``pdf``.
   (...)
     90     - :func:`jax.scipy.stats.gamma.logsf`
     91   """
---> 92   return lax.exp(logpdf(x, a, loc, scale))

File /workspaces/options/venv/lib/python3.12/site-packages/jax/_src/scipy/stats/gamma.py:56, in logpdf(x, a, loc, scale)
     54 one = _lax_const(x, 1)
     55 y = lax.div(lax.sub(x, loc), scale)
---> 56 log_linear_term = lax.sub(xlogy(lax.sub(a, one), y), y)
     57 shape_terms = lax.add(gammaln(a), lax.log(scale))
     58 log_probs = lax.sub(log_linear_term, shape_terms)

    [... skipping hidden 7 frame]

File /workspaces/options/venv/lib/python3.12/site-packages/jax/_src/scipy/special.py:498, in xlogy(x, y)
    496 x, y = promote_args_inexact("xlogy", x, y)
    497 x_ok = x != 0.
--> 498 return jnp.where(x_ok, lax.mul(x, lax.log(y)), jnp.zeros_like(x))

    [... skipping hidden 17 frame]

File /workspaces/options/venv/lib/python3.12/site-packages/jax/_src/pjit.py:1692, in _pjit_call_impl_python(jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, *args)
   1675 # If control reaches this line, we got a NaN on the output of `compiled`
   1676 # but not `fun.call_wrapped` on the same arguments. Let's tell the user.
   1677 msg = (f"{str(e)}. Because "
   1678        "jax_config.debug_nans.value and/or config.jax_debug_infs is set, the "
   1679        "de-optimized function (i.e., the function as if the `jit` "
   (...)
   1690        "If you see this error, consider opening a bug report at "
   1691        "https://github.com/jax-ml/jax.")
-> 1692 raise FloatingPointError(msg)

FloatingPointError: invalid value (nan) encountered in jit(log). Because jax_config.debug_nans.value and/or config.jax_debug_infs is set, the de-optimized function (i.e., the function as if the `jit` decorator were removed) was called in an attempt to get a more precise error message. However, the de-optimized function did not produce invalid values during its execution. This behavior can result from `jit` optimizations causing the invalid value to be produced. It may also arise from having nan/inf constants as outputs, like `jax.jit(lambda ...: jax.numpy.nan)(...)`. 

It may be possible to avoid the invalid value by removing the `jit` decorator, at the cost of losing optimizations. 

If you see this error, consider opening a bug report at https://github.com/jax-ml/jax.
```

System info (python version, jaxlib version, accelerator, etc.)

```text
jax: 0.4.35
jaxlib: 0.4.35
numpy: 2.1.3
python: 3.12.1 (main, Sep 30 2024, 17:05:21) [GCC 9.4.0]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='codespaces-5a7c09', release='6.5.0-1025-azure', version='#26~22.04.1-Ubuntu SMP Thu Jul 11 22:33:04 UTC 2024', machine='x86_64')
```

dfm commented 5 hours ago

Thanks for the report! The issue occurs when `loc > x` in the gamma `logpdf`. That case is checked here:

https://github.com/jax-ml/jax/blob/afdc79271cec44cd86f83419342014d484609ca4/jax/_src/scipy/stats/gamma.py#L59

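so `logpdf` returns `-inf` in that case, as expected. The NaN appears earlier, though, here, where `xlogy` takes the log of `y = (x - loc) / scale`, which is negative whenever `x < loc`: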

https://github.com/jax-ml/jax/blob/afdc79271cec44cd86f83419342014d484609ca4/jax/_src/scipy/stats/gamma.py#L56
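A minimal sketch of the mechanism described above, evaluated by hand for the fifth row of the reproduction data. There the third column (`theta`, which is passed positionally and therefore lands in `loc`) is `1.3681346`, while `x = 1.000424`. The sketch turns `jax_debug_nans` off so the NaN can be inspected instead of raised. Since the argument of the log is genuinely negative, extra precision cannot help, which fits the report that `jax_enable_x64` made no difference.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import gamma

# Keep debug_nans off so the intermediate NaN can be inspected instead of raised.
jax.config.update("jax_debug_nans", False)

# Fifth row of the reproduction data. The third column (`theta`) is passed
# positionally to gamma.pdf, so it ends up as `loc`; scale keeps its default of 1.
x = jnp.float32(1.000424)
a = jnp.float32(0.47818542)
loc = jnp.float32(1.3681346)
scale = jnp.float32(1.0)

y = (x - loc) / scale                  # negative, because x < loc
log_y = jnp.log(y)                     # NaN: mirrors the lax.log(y) inside xlogy
logp = gamma.logpdf(x, a, loc, scale)  # the later x < loc check masks this to -inf

print(y, log_y, logp)
```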

It's safe to set `jax_debug_nans` to `False` in this case (since the `x < loc` check afterwards masks the NaN out of the result), or you can make the following workaround change to your code:

```diff
-     theta = data[:, :, 2]
+     theta = jax.lax.clamp(-jnp.inf, data[:, :, 2], target[:, None])
```
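A sketch of the reproduction with the workaround applied, written out in full so it runs on its own (the docstring is dropped; otherwise only the `theta` line changes). With the cap, `x - loc` is never negative inside `gamma.pdf`, so nothing there should produce a NaN, even with `jax_debug_nans` enabled.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)

def mle_loss(data, target):
    weights = data[:, :, 0]
    k = data[:, :, 1]
    # Workaround: cap the third column at the target. It is passed
    # positionally to gamma.pdf as `loc`, so this keeps x - loc >= 0.
    theta = jax.lax.clamp(-jnp.inf, data[:, :, 2], target[:, None])

    probs = jax.vmap(jax.scipy.stats.gamma.pdf)(target, k, theta)
    prob = jnp.sum(weights * probs, axis=1)

    loss = -jax.lax.log(prob).mean()
    loss = jax.lax.clamp(-1e6, loss, 1e6)
    return loss

data = jnp.array([[[1.0, 0.7096346, 0.7514472]],
                  [[1.0, 0.7194072, 0.735364]],
                  [[1.0, 0.7475644, 0.7523259]],
                  [[1.0, 0.7042354, 0.7264852]],
                  [[1.0, 0.47818542, 1.3681346]],
                  [[1.0, 0.6943242, 0.7313199]],
                  [[1.0, 0.687601, 0.81750506]],
                  [[1.0, 0.72067565, 0.7784166]]])
target = jnp.array([1.0, 1.0004145, 0.99964607, 1.0004367,
                    1.000424, 0.99967706, 1.0009085, 1.000621])

loss = mle_loss(data, target)  # no FloatingPointError with the cap in place
print(loss)
```

Two things are worth checking, both assumptions about the intended model rather than anything stated in the report. For the capped row, `x == loc` exactly, and a gamma density with shape `k < 1` diverges at that point, so on this data the loss may simply saturate at the `-1e6` clamp. Also, `theta` conventionally names the gamma scale; if that was the intent, passing it by keyword, as in `jax.vmap(jax.scipy.stats.gamma.pdf)(target, k, scale=theta)`, keeps `loc` at its default of 0 and avoids the `x < loc` branch entirely (`jax.vmap` maps keyword arguments over their leading axis).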

But I think we should definitely fix these leaking NaNs in JAX itself! If you're keen to submit a PR, I'd be happy to help or point you in the right direction; otherwise, I can probably fix it myself soon.
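For a fix inside JAX, one possible approach is to apply the `x < loc` mask to `y` *before* the log, as well as to the result. The sketch below is a hypothetical standalone function (`gamma_logpdf_masked` is an illustrative name, not JAX's implementation or any actual patch). Substituting a harmless value (`1.0`, where the log is 0) means the log never sees a negative argument, so `jax_debug_nans` has nothing to report. This is the "double `where`" pattern that the JAX FAQ recommends for keeping gradients NaN-free; otherwise the formula matches the `logpdf` shown in the traceback.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln, xlogy


def gamma_logpdf_masked(x, a, loc=0.0, scale=1.0):
    """Hypothetical NaN-free variant of jax.scipy.stats.gamma.logpdf."""
    outside = x < loc                     # points outside the support [loc, inf)
    y = (x - loc) / scale                 # negative wherever `outside` is True
    y_safe = jnp.where(outside, 1.0, y)   # mask the input: log(1.0) == 0, never NaN
    log_probs = xlogy(a - 1.0, y_safe) - y_safe - gammaln(a) - jnp.log(scale)
    return jnp.where(outside, -jnp.inf, log_probs)  # mask the output as before


jax.config.update("jax_debug_nans", True)  # would raise on any intermediate NaN

# First and fifth rows of the reproduction data (third column used as `loc`).
x = jnp.array([1.0, 1.000424])
a = jnp.array([0.7096346, 0.47818542])
loc = jnp.array([0.7514472, 1.3681346])

print(gamma_logpdf_masked(x, a, loc))
# Gradient with respect to the shape parameter at the out-of-support point.
print(jax.grad(lambda aa: gamma_logpdf_masked(x[1], aa, loc[1]))(a[1]))
```

Masking only the output, as the current code does, hides the NaN from the final value but not from `jax_debug_nans`, and in general not from reverse-mode gradients either, which is the case the FAQ discusses.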