google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[FR] add betaincinv #2399

Open tuchandra opened 4 years ago

tuchandra commented 4 years ago

Summary

Requesting the betaincinv function be added to jax.scipy.special. This is the inverse of the regularized incomplete beta function, or the PPF / quantile function for the Beta distribution.

Use case

I'm coming to JAX from numpyro, in which I'm trying to implement Gaussian copula models (example here). I'm looking for models with Beta marginals, which requires the inverse CDF / quantile function for the Beta distribution. After betainc was added in #1998, we have the CDF, but not the quantile function.

Having both the betainc and betaincinv functions would allow the addition of CDF and PPF methods to jax.scipy.stats.beta, much like the norm class and the scipy implementation. The betaincinv in particular would allow one to transform correlated uniform samples into correlated Beta samples (but the details are unrelated).
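
For concreteness, this is just the correspondence that SciPy already exposes (illustration with SciPy, not JAX code):

import numpy as np
from scipy.stats import beta
from scipy.special import betainc, betaincinv

a, b, x = 2.0, 5.0, 0.3
q = betainc(a, b, x)                        # CDF of Beta(a, b) at x
print(np.isclose(beta.cdf(x, a, b), q))     # True
print(np.isclose(betaincinv(a, b, q), x))   # PPF inverts the CDF: True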

Other thoughts

I have no idea how much work this would take. On one hand, PR #1998 didn't seem like a hugely complex change (though still far above my understanding of JAX, so above all thank you for your hard work). On the other, the SciPy implementation of betaincinv is nasty and I don't know whether it would have to be re-implemented by hand or not.

mattjj commented 4 years ago

Thanks for raising this! Seems like a great idea. cc @neerajprad @fehiepsi

There are two ways to add this to JAX:

  1. Write it in pure Python in terms of the functions in jax/lax/lax.py, like we do most NumPy functions in jax/numpy/lax_numpy.py.
  2. Add it to the XLA client math library, like RegularizedIncompleteBeta was added, introduce a new lax primitive to call it, and define transformation rules for that primitive.

PR #1998 took the latter approach, but IIUC mainly because the author works on TFP and so wanted to use this code from both TF and JAX. Putting it in the XLA client library is a way to do that, since it's shared. But that's the only advantage to writing it in the C++ XLA client library: if you don't care about sharing it with TF then it seems much easier to write in pure Python in JAX (and potentially not have to write any transformation rules for it either).
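
For illustration only (a sketch, not anything anyone has committed to): in the spirit of option 1, one could invert betainc in pure Python, e.g. by simple bisection, since betainc(a, b, x) is monotone in x on [0, 1]. A real implementation would want a good initial guess, Newton refinement, and custom derivative rules, but a minimal version looks like this (betaincinv_bisect is a hypothetical name):

import jax.numpy as jnp
from jax.scipy.special import betainc

def betaincinv_bisect(a, b, p, iters=60):
    # Bisect on x until betainc(a, b, x) brackets p tightly.
    lo = jnp.zeros_like(p)
    hi = jnp.ones_like(p)
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        too_low = betainc(a, b, mid) < p
        lo = jnp.where(too_low, mid, lo)
        hi = jnp.where(too_low, hi, mid)
    return 0.5 * (lo + hi)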

@srvasude is this something you already plan on adding to the XLA client library for TFP?

srvasude commented 4 years ago

Hi! Yep, I am planning on adding the betaincinv function to both TF and XLA. I should hopefully be getting to this soon (I have a prototype written in TF C++, so I should be sending that out shortly; translating to XLA will take a bit of time, but I'll have an implementation to compare against :) ).

In case you want to get something running sooner, I might recommend using Kumaraswamy marginals: https://en.wikipedia.org/wiki/Kumaraswamy_distribution. The Kumaraswamy CDF / inverse CDF are very easy to compute, and its shapes are very similar to those of a Beta distribution.
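
For reference, both the Kumaraswamy CDF and its inverse are closed form, so a JAX version takes only a few lines (a quick sketch, not library code):

import jax.numpy as jnp

def kumaraswamy_cdf(x, a, b):
    # F(x; a, b) = 1 - (1 - x**a)**b for x in [0, 1]
    return 1.0 - (1.0 - x ** a) ** b

def kumaraswamy_icdf(p, a, b):
    # Inverse CDF: x = (1 - (1 - p)**(1/b))**(1/a)
    return (1.0 - (1.0 - p) ** (1.0 / b)) ** (1.0 / a)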

mattjj commented 4 years ago

Awesome, thanks for the expert tips @srvasude !

tuchandra commented 4 years ago

That's great to hear, thanks so much for the replies!

derekpowell commented 3 years ago

Just commenting to say I would also really value this feature and, though you are all most likely already aware, to point out that Stan has gradients implemented for its inc_beta function. I believe this is the implementation.

hawkinsp commented 3 years ago

@srvasude Did you end up adding this anywhere, e.g., to TFP?

dimarkov commented 3 years ago

I found a pure Python implementation of betaincinv here https://malishoaib.wordpress.com/2014/05/30/inverse-of-incomplete-beta-function-computational-statisticians-wet-dream/ and adapted it to JAX's JIT compiler. Here is the code in case someone wants to use it, as I am not that proficient with GitHub pull requests, etc.

import jax.numpy as jnp
from jax.scipy.special import betaln, betainc
from jax import jit

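# update_x performs one damped Newton/Halley step on betainc(a, b, x) - p = 0, using the
# Beta(a, b) density exp(a1*log(x) + b1*log(1 - x) - betaln(a, b)) as the derivative; if the
# step leaves (0, 1), it falls back to the midpoint between the previous iterate and the boundary.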
@jit
def update_x(x, a, b, p, a1, b1, afac):
    err = betainc(a, b, x) - p
    t = jnp.exp(a1 * jnp.log(x) + b1 * jnp.log(1.0 - x) + afac)
    u = err/t
    tmp = u * (a1 / x - b1 / (1.0 - x))
    t = u / (1.0 - 0.5 * jnp.clip(tmp, None, 1.0))
    x -= t
    x = jnp.where(x <= 0., 0.5 * (x + t), x)
    x = jnp.where(x >= 1., 0.5 * (x + t + 1.), x)

    return x, t

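# func_1 and func_2 build the initial guess for the inverse (compute_x below selects between
# them): func_1 uses a normal approximation when both a >= 1 and b >= 1, func_2 a simpler
# tail approximation otherwise.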
@jit
def func_1(a, b, p):
    pp = jnp.where(p < .5, p, 1. - p)
    t = jnp.sqrt(-2. * jnp.log(pp))
    x = (2.30753 + t * 0.27061) / (1.0 + t * (0.99229 + t * 0.04481)) - t
    x = jnp.where(p < .5, -x, x)
    al = (jnp.power(x, 2) - 3.0) / 6.0
    h = 2.0 / (1.0 / (2.0 * a - 1.0) + 1.0 / (2.0 * b - 1.0))
    w = (x * jnp.sqrt(al + h) / h)-(1.0 / (2.0 * b - 1) - 1.0/(2.0 * a - 1.0)) * (al + 5.0 / 6.0 - 2.0 / (3.0 * h))
    return a / (a + b * jnp.exp(2.0 * w))

@jit
def func_2(a, b, p):
    lna = jnp.log(a / (a + b))
    lnb = jnp.log(b / (a + b))
    t = jnp.exp(a * lna) / a
    u = jnp.exp(b * lnb) / b
    w = t + u

    return jnp.where(p < t / w, jnp.power(a * w * p, 1.0 / a), 1. - jnp.power(b * w * (1.0 - p), 1.0 / b))

@jit
def compute_x(p, a, b):
    return jnp.where(jnp.logical_and(a >= 1.0, b >= 1.0), func_1(a, b, p), func_2(a, b, p))

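# betaincinv: clip p to [0, 1], form an initial guess, then run a fixed number of refinement
# iterations, freezing entries that have converged (or sit exactly at 0 or 1) via the stop mask.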
@jit
def betaincinv(a, b, p):
    a1 = a - 1.0
    b1 = b - 1.0

    ERROR = 1e-8

    p = jnp.clip(p, 0., 1.)

    x = jnp.where(jnp.logical_or(p <= 0.0, p >= 1.), p, compute_x(p, a, b))

    afac = - betaln(a, b)
    stop = jnp.logical_or(x == 0.0, x == 1.0)
    for i in range(10):
        x_new, t = update_x(x, a, b, p, a1, b1, afac)
        x = jnp.where(stop, x, x_new)
        stop = jnp.logical_or(jnp.abs(t) < ERROR * x, stop)

    return x

dimarkov commented 3 years ago

Comparing the output with betaincinv from scipy shows good agreement, although there are some deviations on the order of 1e-5 at a few points, which is not an issue for my application. You can do a quick test as follows:

import jax.numpy as jnp
from jax import random
from scipy.special import betaincinv as scp_betaincinv
import matplotlib.pyplot as plt

a = 1/random.gamma(random.PRNGKey(0), 1., shape=(10000,))
b = 1/random.gamma(random.PRNGKey(1), 1., shape=(10000,))
p = random.beta(random.PRNGKey(0), 1., 1., shape=(10000,))

z1 = betaincinv(a, b, p)
z2 = scp_betaincinv(a, b, p)

notequal = ~jnp.isclose(z1, z2)

plt.hist(z1[notequal] - z2[notequal])

However, comparing the execution time with

%timeit betaincinv(a, b, p)
%timeit scp_betaincinv(a, b, p)

I get the following values

57.2 ms ± 181 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
23.1 ms ± 28.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

which makes the jit-ed function slower than the scipy implementation. Maybe a more experienced jax-er could give me some pointers on how to make the function faster.

hawkinsp commented 3 years ago

@dimarkov For your benchmark, you should stick a block_until_ready() on the JAX version, see https://jax.readthedocs.io/en/latest/async_dispatch.html . You should also be careful about whether you are comparing float32 or float64 implementations; note JAX defaults to disabling float64.
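
For example, the JAX line of the benchmark above would become something like (using the betaincinv defined earlier in this thread):

%timeit betaincinv(a, b, p).block_until_ready()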

I doubt we can accept that particular contribution. It is a port of a port of code accompanying a book that requires an explicit license for distribution, if I'm reading things right.

That said, there are other implementations that we could port. One notable version is incbi from Cephes, which we have explicit permission to redistribute.

dimarkov commented 3 years ago

@hawkinsp Thanks. I tried out what you suggested, but the execution time remains the same. I compared only float32 types. Porting the C code to JAX is beyond my skill set and time budget. If people stumble on this post and need a quick solution until the code is ported from C, feel free to use what I posted here.

Joshuaalbert commented 2 years ago

This feature will be relevant to JAXNS as well, where we sample from Beta distributions many, many times, and currently that relies on two while_loops per sample.

leandrolcampos commented 2 years ago

Nightly builds of TensorFlow Probability now include tfp.math.betaincinv. Here's an example:

from jax import jit, vmap
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
import matplotlib.pyplot as plt

a = 0.5
b = 4.
y = jnp.linspace(0, 1, 1000)
plt.plot(y, vmap(jit(lambda z: tfp.math.betaincinv(a, b, z)))(y))

which produces:

[plot of tfp.math.betaincinv(0.5, 4., y) for y in jnp.linspace(0, 1, 1000)]

You can also compute the gradients of tfp.math.betaincinv with respect to all its parameters:

from jax import grad

a = 0.5
b = 4.
y = 0.1
grad(tfp.math.betaincinv, argnums=[0, 1, 2])(a, b, y)

which produces:

(DeviceArray(0.02030255, dtype=float32, weak_type=True),
 DeviceArray(-0.00055659, dtype=float32, weak_type=True),
 DeviceArray(0.04214846, dtype=float32, weak_type=True))

Finally, when comparing tfp.math.betaincinv with scipy.special.betaincinv, note that SciPy only provides the double precision implementation of this function (for example, if the dtype of the arguments is float32, then they are cast to float64 and the output is cast back to float32). See the SciPy/Cephes implementation here.

Qazalbash commented 9 months ago

Dear JAX Development Team,

I am currently developing modules focused on inverse transform sampling, with a specific application in astrophysics. Our project relies heavily on JAX. However, we have encountered a limitation: JAX does not currently provide the inverse of the regularized incomplete beta function (betaincinv), which is critical for our work.

Could you please advise on any existing alternatives within JAX that provide JIT compilation and maintain high precision? Alternatively, do you know if there is a plan to implement this function soon?

Thank you for your assistance and the continuous development of this invaluable tool.

fehiepsi commented 9 months ago

@Qazalbash see the comment https://github.com/google/jax/issues/2399#issuecomment-1225990206 - I think tfp.math.betaincinv should work.

Qazalbash commented 9 months ago

@fehiepsi Thanks, it worked. It does require installing another package, though; it would be great if JAX provided this too.