google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Add scipy.special.hyp2f1? #2991

Open jmsull opened 4 years ago

jmsull commented 4 years ago

It would be great if hyp2f1 could be added under jax.scipy.special! Are there any plans to do this?

My application uses a special case of hyp2f1 that seems like it can be written using special functions already implemented in jax (gammaln and betainc), but it turns out that the scipy implementation of hyp2f1 uses a transformation on the input that prevents hyp2f1 from being written in terms of these other functions. As far as I can tell there is no other convenient substitution in terms of existing jax functions.

Thanks!

mattjj commented 4 years ago

Thanks for the feature request! cc @srvasude just to be in the loop about special functions.

We don't have any current plans to add this, but we love feature requests (and +1's on them) because they help us prioritize.

This might actually be a good first issue for contributors to take on. We'd basically just need to implement this in pure Python + NumPy, and rely on the XLA compiler (via jit) to generate a decently optimized compiled function. Take a look at, for example, ndtr and ndtri in jax/scipy/special.py to see how that might look. (We also rely on XLA to compile functions for gammaln and betainc, but instead of writing those in Python and just using jit to build the XLA computation, the XLA program is built in C++ for the sole purpose of sharing its implementation with TensorFlow. That sharing consideration aside, performance-wise it's just as good to write things in pure JAX NumPy and use jit, and it's easier to develop that way too!)
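
For a sense of what that looks like, here is a minimal sketch of the pure-Python-plus-jit approach (illustrative only, not JAX's code; it just sums the truncated power series, which converges only for |z| < 1):

import jax
import jax.numpy as jnp

@jax.jit
def hyp2f1_series(a, b, c, z):
    # Truncated power series for 2F1(a, b; c; z), valid only for |z| < 1:
    # term_{k+1} = term_k * (a + k)(b + k) / ((c + k)(k + 1)) * z
    def body(k, carry):
        term, acc = carry
        term = term * (a + k) * (b + k) / ((c + k) * (k + 1.0)) * z
        return term, acc + term
    # acc starts at 1 (the k = 0 term); the loop adds the next 100 terms
    _, acc = jax.lax.fori_loop(0, 100, body, (jnp.ones_like(z), jnp.ones_like(z)))
    return acc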

joaogui1 commented 4 years ago

I think adding a link to an implementation of hyp2f1 (or a mathematical formula) using gammaln and betainc may help

jmsull commented 4 years ago

@mattjj Thanks for the info!

@joaogui1 The relationship between hyp2f1 and betainc + gammaln (specified in the jax documentation for betainc, also eqn. 3 here) holds only for the special combination of arguments hyp2f1(a, 1-b, a+1, x), and only when 1 > x > 0. My application requires handling the case where -1 < x < 0. This negative value of the argument is handled in scipy.special.hyp2f1 (and is mentioned here) via a transformation that changes the arguments of hyp2f1 so that when the first argument is a, the third argument is no longer a+1, and the ability to write hyp2f1 in terms of betainc + gammaln goes away.
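
For reference, the 0 < x < 1 special case can indeed be written with existing JAX primitives via the incomplete beta identity B_x(a, b) = (x^a / a) 2F1(a, 1-b; a+1; x); a minimal sketch:

import jax.numpy as jnp
from jax.scipy.special import betainc, gammaln

def hyp2f1_special_case(a, b, x):
    # 2F1(a, 1-b; a+1; x) = a * B(a, b) * I_x(a, b) / x**a for 0 < x < 1,
    # where I_x is the regularized incomplete beta (betainc) and
    # B(a, b) = exp(gammaln(a) + gammaln(b) - gammaln(a + b))
    log_beta = gammaln(a) + gammaln(b) - gammaln(a + b)
    return a * jnp.exp(log_beta) * betainc(a, b, x) / x**a

This breaks for x < 0, which is exactly the case described above.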

I think a more general handling of hyp2f1, as scipy does it, is required.

joaogui1 commented 4 years ago

Hey @jmsull, I started implementing this over the weekend (or rather, translating the implementation you linked to), but life got in the way and I haven't finished it yet. Are you still interested in a JAX implementation of this function? I think I should be done by Friday, modulo some debugging.

jmsull commented 4 years ago

@joaogui1

Still interested - great to hear you started working on it!

FrescoFlacko commented 4 years ago

What is the status of this issue? I want to start making contributions to this project and am wondering if this issue is a good place to start.

joaogui1 commented 4 years ago

@FrescoFlacko life got in the way and I never finished it, but it's a rather large and tedious implementation, so I don't think it's a good first issue

milind-soni commented 3 years ago

> @FrescoFlacko life got in the way and I never finished it, but it's a rather large and tedious implementation, so I don't think it's a good first issue

Can you guide me on how to go about it, please?

joaogui1 commented 3 years ago

@milind-soni I looked around SciPy until I found the implementation and then proceeded to translate it to JAX; I guess that's the general idea. (The implementation was probably written in C, though, so you'll need some understanding of it to do a good port.)

FloList commented 2 years ago

Not sure if there's still any interest in this topic, but the following function is a workaround that uses the method in John Pearson's MSc thesis, Computation of Hypergeometric Functions, based on a Taylor series approximation (see Sec. 4.2). The code below doesn't treat various corner cases, such as when the differences between a, b, and c are integers, when z = exp(i pi / 3), etc., but I thought I'd leave it here in case someone wants to build on it

from functools import partial

import jax
import jax.numpy as jnp
import numpy as np  # used only by the debug-mode map below
from jax.scipy.special import gammaln


# n_max_iter, tol, and debug_mode steer Python-level control flow, so mark them static
@partial(jax.jit, static_argnums=(4, 5, 6))
def hyp2f1(a, b, c, z, n_max_iter=500, tol=1e-10, debug_mode=False):
    """
    Hypergeometric function implemented in jax, based on John Pearson's MSc thesis
    Computation of Hypergeometric Functions, specifically taylora2f1
    :param a: a (1D jax device array)
    :param b: b (1D jax device array)
    :param c: c (1D jax device array)
    :param z: z (1D jax device array), either one value of z for each (a, b, c) or different z-values for a single (a, b, c)
    :param n_max_iter: maximum number of iterations
    :param tol: tolerance
    :param debug_mode: debugging mode, doesn't use jitted operations
    :return: Gauss hypergeometric function 1F2(a, b, c, z)
    """

    a, b, c, z = jnp.atleast_1d(a), jnp.atleast_1d(b), jnp.atleast_1d(c), jnp.atleast_1d(z)
    assert a.ndim == b.ndim == c.ndim == z.ndim == 1

    # If only a single value for a, b, c is provided, but multiple values of z: tile a, b, c
    if a.shape[0] == b.shape[0] == c.shape[0] == 1 and z.shape[0] > 1:
        a = a * jnp.ones_like(z)
        b = b * jnp.ones_like(z)
        c = c * jnp.ones_like(z)

    assert a.shape[0] == b.shape[0] == c.shape[0] == z.shape[0], f"{a.shape[0], b.shape[0], c.shape[0], z.shape[0]}"

    def compute_output(a_loc, b_loc, c_loc, z_loc):

        def cond_fun(val):
            step_, a_, b_, c_, z_, coeff_old_, coeff_new_, sum_new_, tol_, n_max_iter_ = val
            cond_steps = jnp.all(step_ < n_max_iter_)
            cond_tol = jnp.all(jnp.logical_or(jnp.abs(coeff_old_) / jnp.abs(sum_new_) > tol_,
                                              jnp.abs(coeff_new_) / jnp.abs(sum_new_) > tol_))
            return jnp.all(jnp.logical_and(cond_steps, cond_tol))

        def body_fun(val):
            step_, a_, b_, c_, z_, coeff_old_, coeff_new_, sum_new_, tol_, n_max_iter_ = val
            coeff_new_ = (a_ + step_ - 1.0) * (b_ + step_ - 1.0) / (c_ + step_ - 1.0) * z_ / step_ * coeff_old_
            sum_new_ += coeff_new_
            return step_ + 1, a_, b_, c_, z_, coeff_new_, coeff_new_, sum_new_, tol_, n_max_iter_

        init_val = (jnp.asarray(1.0), a_loc * 1.0, b_loc * 1.0, c_loc * 1.0, z_loc * 1.0, jnp.asarray(1.0),
                    a_loc * 1.0, jnp.asarray(1.0), jnp.asarray(tol), jnp.asarray(n_max_iter * 1.0))

        if debug_mode:
            def while_loop(cond_fun, body_fun, init_val):
                val = init_val
                while cond_fun(val):
                    val = body_fun(val)
                return val

        else:
            while_loop = jax.lax.while_loop

        final_step, _, _, _, _, _, _, sum_out, _, _ = while_loop(cond_fun, body_fun, init_val=init_val)
        return sum_out

    def gamma(val):
        return jnp.where(val >= 0, jnp.exp(gammaln(val)), -jnp.exp(gammaln(val)))

    # Set indices for the different cases of the transformation mapping z to within the radius rho < 0.5
    # where Taylor series converges fast
    cases = jnp.where(z < -1,
                      1, jnp.where(z < 0,
                                   2, jnp.where(z <= 0.5,
                                                3, jnp.where(z <= 1,
                                                             4, jnp.where(z <= 2,
                                                                          5, 6)))))

    cases -= 1  # 0-based indexing

    # Define the branches
    def branch_1(this_a, this_b, this_c, this_z):
        term_1 = (1 - this_z) ** (-this_a) * (
                gamma(this_c) * gamma(this_b - this_a) / gamma(this_b) / gamma(this_c - this_a)) \
                 * compute_output(this_a, this_c - this_b, this_a - this_b + 1.0, 1.0 / (1.0 - this_z))
        term_2 = (1 - this_z) ** (-this_b) * (
                gamma(this_c) * gamma(this_a - this_b) / gamma(this_a) / gamma(this_c - this_b)) \
                 * compute_output(this_b, this_c - this_a, this_b - this_a + 1.0, 1.0 / (1.0 - this_z))
        return term_1 + term_2

    def branch_2(this_a, this_b, this_c, this_z):
        return (1 - this_z) ** (-this_a) * compute_output(this_a, this_c - this_b, this_c,
                                                          this_z / (this_z - 1.0))

    def branch_3(this_a, this_b, this_c, this_z):
        return compute_output(this_a, this_b, this_c, this_z)

    def branch_4(this_a, this_b, this_c, this_z):
        term_1 = (gamma(this_c) * gamma(this_c - this_a - this_b)
                  / gamma(this_c - this_a) / gamma(this_c - this_b)) \
                 * compute_output(this_a, this_b, this_a + this_b - this_c + 1.0, 1.0 - this_z)
        term_2 = (1 - this_z) ** (this_c - this_a - this_b) * \
                 (gamma(this_c) * gamma(this_a + this_b - this_c) / gamma(this_a) / gamma(this_b)) \
                 * compute_output(this_c - this_a, this_c - this_b, this_c - this_a - this_b + 1.0, 1.0 - this_z)
        return term_1 + term_2

    def branch_5(this_a, this_b, this_c, this_z):
        term_1 = this_z ** (-this_a) * (gamma(this_c) * gamma(this_c - this_a - this_b)
                                        / gamma(this_c - this_a) / gamma(this_c - this_b)) \
                 * compute_output(this_a, this_a - this_c + 1.0, this_a + this_b - this_c + 1.0, 1.0 - 1.0 / this_z)
        term_2 = this_z ** (this_a - this_c) * (1 - this_z) ** (this_c - this_a - this_b) \
                 * (gamma(this_c) * gamma(this_a + this_b - this_c) / gamma(this_a) / gamma(this_b)) \
                 * compute_output(this_c - this_a, 1.0 - this_a, this_c - this_a - this_b + 1.0, 1.0 - 1.0 / this_z)
        return term_1 + term_2

    def branch_6(this_a, this_b, this_c, this_z):
        term_1 = (-this_z) ** (-this_a) * (
                gamma(this_c) * gamma(this_b - this_a) / gamma(this_b) / gamma(this_c - this_a)) \
                 * compute_output(this_a, this_a - this_c + 1.0, this_a - this_b + 1.0, 1.0 / this_z)
        term_2 = (-this_z) ** (-this_b) * (
                gamma(this_c) * gamma(this_a - this_b) / gamma(this_a) / gamma(this_c - this_b)) \
                 * compute_output(this_b - this_c + 1.0, this_b, this_b - this_a + 1.0, 1.0 / this_z)
        return term_1 + term_2

    branches = [branch_1, branch_2, branch_3, branch_4, branch_5, branch_6]

    def single_computation(val, branches_):
        case_, a_, b_, c_, z_ = val

        if debug_mode:
            def switch_fun(index, branches, *operands):
                index = jnp.clip(index, 0, len(branches) - 1)
                return branches[index](*operands)

        else:
            switch_fun = jax.lax.switch

        return switch_fun(case_, branches_, a_, b_, c_, z_)

    # Compute outputs
    if debug_mode:
        def map_fun(f, xs):
            return np.stack([f(x) for x in zip(*xs)])
    else:
        map_fun = jax.lax.map

    all_outputs = map_fun(lambda val: single_computation(val, branches), (cases, a, b, c, z))

    return all_outputs
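
For example, a quick spot check against SciPy (illustrative values; z = 0.3 exercises the plain Taylor-series branch):

from scipy.special import hyp2f1 as scipy_hyp2f1

print(hyp2f1(0.5, 1.3, 2.4, 0.3))        # workaround above
print(scipy_hyp2f1(0.5, 1.3, 2.4, 0.3))  # SciPy reference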

jguerra-astro commented 1 year ago

@FloList Thank you for starting on this. I'm still trying to wrap my head around it, but it does seem to fail for z <= -1.

I double-checked the branch_1 function and don't see anything wrong. I'm not sure if you or anyone else might see something I don't. Thanks in advance!

Note: the only modification I made was changing gammaln to jax.scipy.special.gammaln.

FloList commented 1 year ago

@jguerra-astro Thanks for taking a look at this. Unfortunately, I haven't looked into it any further, but there is also the function hyp2f1_small_argument provided by the TF Probability JAX substrate, which might be helpful: https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/math/hypergeometric/hyp2f1_small_argument
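
Usage would look roughly like this (assuming tensorflow-probability is installed; see the linked docs for the exact domain of validity):

from tensorflow_probability.substrates import jax as tfp

# computes 2F1(a, b; c; z) for small |z|, per the linked documentation
val = tfp.math.hypergeometric.hyp2f1_small_argument(0.5, 1.3, 2.4, 0.3)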

ColtAllen commented 10 months ago

As one of the contributors to the pymc-marketing library, I've an interest in this as well. It could dramatically reduce training times for one of our models. Our particular hyp2f1 application doesn't involve any unusual edge cases, so the code provided by @FloList may suffice.

ahuang314 commented 3 months ago

@FloList I've taken a look at your code, did some testing, and found several issues. I haven't gotten around to fixing them all yet (and don't know if I will), so I thought I'd just mention what I've spotted in case someone else wants to fix it.

The first problem is that your definition of gamma for negative arguments is not correct. The gamma function isn't an odd function, so you can't just move the negative sign to the outside. This explains why @jguerra-astro got some incorrect results in the case where z <= -1. The branch_4 function also sometimes gets incorrect results for the same reason, and I imagine some of the other branch functions do as well. Using jax.scipy.special.gamma fixes the problem.
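
The fix can be a one-line delegation to jax.scipy.special.gamma. On JAX versions without it, a reflection-formula helper would also work; an illustrative sketch, not library code:

import jax.numpy as jnp
from jax.scipy.special import gammaln

def gamma_fixed(val):
    # Reflection formula: gamma(x) = pi / (sin(pi x) * gamma(1 - x)) for x < 0,
    # which produces the correct alternating sign; exp(gammaln) covers x > 0
    pos = jnp.exp(gammaln(val))
    neg = jnp.pi / (jnp.sin(jnp.pi * val) * jnp.exp(gammaln(1.0 - val)))
    return jnp.where(val > 0, pos, neg)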

The second problem is that when you set up the conditions for cases, you are only looking at Re(z), which doesn't work. Take for example z = 0.3 + 2.5i. The code will execute branch_3, which is incorrect because the series representation of 2F1 diverges outside of the unit disk. Similar problems arise with branch_4 with e.g. z = 0.7 + 2.5i, and potentially other branches as well. The conditions on when to use which branch need to be worked out more carefully.
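
One possible repair (my assumption, following the usual practice of picking the transformation whose mapped argument has the smallest modulus, as discussed in Pearson's thesis) is to select the branch by comparing moduli instead of Re(z):

import jax.numpy as jnp

def select_case(z):
    # moduli of the transformed arguments used by branch_1 ... branch_6 above;
    # the smallest one keeps the Taylor series well inside its convergence disk
    w = jnp.stack([jnp.abs(1.0 / (1.0 - z)),  # branch_1
                   jnp.abs(z / (z - 1.0)),    # branch_2
                   jnp.abs(z),                # branch_3
                   jnp.abs(1.0 - z),          # branch_4
                   jnp.abs(1.0 - 1.0 / z),    # branch_5
                   jnp.abs(1.0 / z)])         # branch_6
    return jnp.argmin(w, axis=0)  # 0-based case index, like `cases` above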

The third problem is that there is a typo in branch_6. I'm guessing you took it directly from equation 4.20 in the MSc thesis you linked, but that equation has a typo on the second line. The arguments b and b-c+1 in 2F1 need to be swapped (check e.g. Abramowitz and Stegun equation 15.3.7). However, even after fixing this issue, branch_6 still gives incorrect results, and I haven't been able to figure out why.

I did code up a function which can calculate 2F1 for cases where z is outside of the unit disk (more specifically, this approach works whenever z is outside of the circle |z - 1/2| = 1/2). This implementation is based on the analytic continuation equations 4.21 and 4.22 in the MSc thesis. It takes into account the e^{\pm i pi/3} issue and does not have any poles when c - a - b is an integer, but it still has poles whenever b - a is an integer.

import jax
import jax.numpy as jnp
from jax.scipy.special import gamma  # correct sign for negative (non-integer) arguments

@jax.jit
def hyp2f1_continuation(a, b, c, z):

    # d0 = 1 and d_{-1} = 0
    prev_da = 1.
    prev_db = 1.
    prev_prev_da = 0.
    prev_prev_db = 0.

    # partial_sum_1 corresponds to the summation on the top line of equation 4.21
    # partial_sum_2 corresponds to the summation on the bottom line of equation 4.21
    # (initialized as complex so the fori_loop carry keeps a consistent dtype)
    partial_sum_1 = 1. + 0.j
    partial_sum_2 = 1. + 0.j

    # If z is on the branch cut, take the value above the branch cut
    z = jnp.where(jnp.imag(z) == 0., jnp.where(jnp.real(z)>=1., z + 0.0000001j, z), z)

    def body_fun(j, val):
        a_, b_, c_, z_, prev_prev_da, prev_prev_db, prev_da, prev_db, partial_sum_1, partial_sum_2 = val

        #------------------------------------------------------------------------------------------------------
        # This section of the function handles the summation on the first line of equation 4.21
        # calculates d_j and the corresponding term in the sum
        d_ja = (j+a_-1.)/(j*(j+a_-b_)) * (((a_+b_+1.)*0.5-c_)*prev_da + 0.25*(j+a_-2.)*prev_prev_da)
        partial_sum_1 += d_ja * (z_ - 0.5)**(-j)

        # updates d_{j-2} and d_{j-1}
        prev_prev_da = prev_da
        prev_da = d_ja
        #------------------------------------------------------------------------------------------------------
        # This section of the function handles the summation on the second line of equation 4.21
        # calculates d_j and the corresponding term in the sum
        d_jb = (j+b_-1.)/(j*(j-a_+b_)) * (((a_+b_+1)*0.5-c_)*prev_db + 0.25*(j+b_-2.)*prev_prev_db)
        partial_sum_2 += d_jb * (z_ - 0.5)**(-j)

        # updates d_{j-2} and d_{j-1}
        prev_prev_db = prev_db
        prev_db = d_jb

        return [a_, b_, c_, z_, prev_prev_da, prev_prev_db, prev_da, prev_db, partial_sum_1, partial_sum_2]

    result = jax.lax.fori_loop(1, 30, body_fun, [a, b, c, z, prev_prev_da, prev_prev_db, prev_da, prev_db, partial_sum_1, partial_sum_2])

    # includes the gamma function prefactors in equation 4.21 to compute the final result of 2F1
    final_result = gamma(c) * (result[8] * gamma(b-a)/gamma(b)/gamma(c-a)*(0.5-z)**(-a) + \
                               result[9] * gamma(a-b)/gamma(a)/gamma(c-b)*(0.5-z)**(-b)
                              )
    return final_result
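
A hypothetical spot check (scipy.special.hyp2f1 accepts complex arguments and can serve as a reference):

import scipy.special

z = -3.0 + 0.5j  # outside the circle |z - 1/2| = 1/2
print(hyp2f1_continuation(0.3, 1.1, 2.2, z))
print(scipy.special.hyp2f1(0.3, 1.1, 2.2, z))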