pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

log1m_exp and log_diff_exp functions #1368

Open dylanhmorris opened 2 years ago

dylanhmorris commented 2 years ago

When writing custom distributions, it is often helpful to have numerically stable implementations of log_diff_exp(a, b) := log(exp(a) - exp(b)) and particularly log1m_exp(x) := log(1 - exp(x)). The naive implementations are not stable for many probabilistic programming use cases, so probabilistic programming languages including Stan and PyMC provide numerically stable implementations of these functions (typically following Mächler, 2012).
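To see the instability concretely (a hypothetical example of mine, not from the original report; plain NumPy is used here for illustration, but the same floating-point behavior applies under JAX): once |x| drops below machine epsilon, exp(x) rounds to exactly 1.0 in double precision, so the naive formula takes log(0) and returns -inf, while the expm1-based form keeps full precision.

```python
import numpy as np

x = -1e-20  # a tiny negative exponent

# Naive: exp(-1e-20) rounds to exactly 1.0 in float64, so
# 1 - exp(x) underflows to 0 and the log blows up to -inf.
with np.errstate(divide="ignore"):
    naive = np.log(1.0 - np.exp(x))

# Stable: expm1(x) computes exp(x) - 1 without catastrophic
# cancellation, so log(-expm1(x)) ~= log(1e-20) ~= -46.05,
# which is the correct answer to first order.
stable = np.log(-np.expm1(x))

print(naive)   # -inf
print(stable)  # approximately -46.05
```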

As far as I can tell, NumPyro does not, and they are not present in JAX either.

I wonder whether it would be worth providing them. I have written basic implementations following Machler for my own use. I would happily make a PR including them, but someone more experienced could probably write better/more idiomatic ones.

import jax.numpy as jnp

def log1m_exp(x):
    """
    Numerically stable calculation
    of the quantity log(1 - exp(x)),
    following the algorithm of
    Mächler [1]. This is
    the algorithm used in TensorFlow Probability,
    PyMC, and Stan, but it is not yet
    provided with NumPyro.

    Currently returns NaN for x > 0,
    but may be modified in the future
    to throw a ValueError

    [1] https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
    """
    # return 0. rather than -0. if
    # we get a negative exponent that exceeds
    # the floating point representation
    arr_x = 1.0 * jnp.array(x)
    oob = arr_x < jnp.log(jnp.finfo(
        arr_x.dtype).smallest_normal)
    mask = arr_x > -0.6931472  # approx -log(2)
    more_val = jnp.log(-jnp.expm1(arr_x))
    less_val = jnp.log1p(-jnp.exp(arr_x))

    return jnp.where(
        oob,
        0.,
        jnp.where(
            mask,
            more_val,
            less_val))

def log_diff_exp(a, b):
    """
    Numerically stable calculation of
    log(exp(a) - exp(b)), valid for a >= b.

    Note that following Stan, we want the
    log diff exp of -inf, -inf to be -inf,
    not NaN, because that corresponds to
    log(0 - 0) = log(0) = -inf.
    """
    mask = a > b
    masktwo = (a == b) & (a < jnp.inf)
    return jnp.where(mask,
                     1.0 * a + log1m_exp(
                         1.0 * b - 1.0 * a),
                     jnp.where(masktwo,
                               -jnp.inf,
                               jnp.nan))
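As a quick sanity check of the proposed edge-case behavior, here is a NumPy transcription of the JAX code above (my own, written only for illustration; the logic is essentially the same, with `np.errstate` added to silence the expected warnings from the branch that `where` discards):

```python
import numpy as np

def log1m_exp(x):
    # NumPy transcription of the JAX version for quick checking.
    arr_x = 1.0 * np.asarray(x)
    oob = arr_x < np.log(np.finfo(arr_x.dtype).smallest_normal)
    mask = arr_x > -np.log(2.0)
    with np.errstate(divide="ignore", invalid="ignore"):
        more_val = np.log(-np.expm1(arr_x))
        less_val = np.log1p(-np.exp(arr_x))
    return np.where(oob, 0.0, np.where(mask, more_val, less_val))

def log_diff_exp(a, b):
    a, b = 1.0 * np.asarray(a), 1.0 * np.asarray(b)
    mask = a > b
    masktwo = (a == b) & (a < np.inf)
    with np.errstate(divide="ignore", invalid="ignore"):
        main = a + log1m_exp(b - a)
    return np.where(mask, main, np.where(masktwo, -np.inf, np.nan))

# log_diff_exp(log 3, log 1) should recover log 2
print(log_diff_exp(np.log(3.0), np.log(1.0)))  # approximately 0.6931
# Stan-style convention: log(0 - 0) = -inf, not NaN
print(log_diff_exp(-np.inf, -np.inf))          # -inf
# a < b means taking the log of a negative number: NaN
print(log_diff_exp(0.0, 1.0))                  # nan
```

Both `where` branches are always evaluated, so the `errstate` guards (and, in the JAX version, the explicit `oob` mask) are what keep the discarded branch from polluting the result with warnings or NaNs.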
fehiepsi commented 2 years ago

I think this is a nice approach. We have those computations in various places. I think you can put these utilities in the distributions/util.py file. The implementation looks reasonable to me. How about discussing the details in your PR?

dylanhmorris commented 2 years ago

Sounds good. Will prepare one as soon as I have a chance.