Open jmsull opened 4 years ago
Thanks for the feature request! cc @srvasude just to be in the loop about special functions.
We don't have any current plans to add this, but we love feature requests (and +1's on them) because they help us prioritize.
This might actually be a good first issue for contributors to take on. We'd basically just need to implement this in pure Python + NumPy, and rely on the XLA compiler (via `jit`) to generate a decently-optimized compiled function. Take a look at, for example, `ndtr` and `ndtri` in jax/scipy/special.py to see how that might look. (We also rely on XLA to compile functions for `gammaln` and `betainc`, but instead of writing those in Python and just using `jit` to build the XLA computation, the XLA program is built in C++ for the sole purpose of sharing its implementation with TensorFlow. Apart from that sharing consideration, performance-wise it's just as good to write things in pure JAX NumPy and use `jit`, and it's easier to develop that way too!)
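As a rough illustration of the "pure JAX NumPy + `jit`" route, here is a minimal sketch that sums the Gauss series with `jax.lax.fori_loop`, so `jit` hands XLA a single compiled loop. It is not JAX's API: `hyp2f1_series` and the fixed term count are assumptions for illustration. It is also only accurate for |z| well inside the unit disk. The rest of the plane needs transformations like the ones discussed later in this thread.

```python
import functools

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64, so the truncated series can be accurate to ~1e-12


@functools.partial(jax.jit, static_argnames="num_terms")
def hyp2f1_series(a, b, c, z, num_terms=60):
    """Hypothetical helper: truncated Gauss series sum_n (a)_n (b)_n / (c)_n * z**n / n!.

    Only accurate when |z| is comfortably inside the unit disk (say |z| <= 0.5).
    """
    # Promote everything to a common float64 (or complex128) dtype and a common shape
    dtype = jnp.result_type(a, b, c, z, jnp.float64)
    a, b, c, z = (jnp.asarray(x, dtype) for x in (a, b, c, z))
    a, b, c, z = jnp.broadcast_arrays(a, b, c, z)

    def body(n, carry):
        term, total = carry
        # Ratio of consecutive terms: t_{n+1} / t_n = (a + n)(b + n) / ((c + n)(n + 1)) * z
        term = term * (a + n) * (b + n) / ((c + n) * (n + 1)) * z
        return term, total + term

    one = jnp.ones_like(z)  # t_0 = 1, and the running sum starts at t_0
    _, total = jax.lax.fori_loop(0, num_terms, body, (one, one))
    return total
```

For example, `hyp2f1_series(1.0, 1.0, 2.0, 0.4)` should agree with the closed form -log(1 - z)/z at z = 0.4. Because the loop lives inside `jit`, XLA compiles it once per input shape and dtype.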
I think adding a link to an implementation of `hyp2f1` (or a mathematical formula) using `gammaln` and `betainc` may help.
@mattjj Thanks for the info!
@joaogui1 The relationship between `hyp2f1` and `betainc` + `gammaln` (specified in the JAX documentation for `betainc`, also eqn. 3 here) is only valid for the special combination of arguments `hyp2f1(a, 1-b, a+1, x)`, and only when 0 < x < 1. My application requires handling the case -1 < x < 0. This negative argument is handled in scipy.special.hyp2f1 (and is mentioned here) via a transformation that changes the arguments of `hyp2f1` such that, when the first argument is a, the third argument is no longer a+1, so the ability to write `hyp2f1` in terms of `betainc` + `gammaln` goes away.
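For the special case above, the identity behind the `betainc` relation is B(x; a, b) = x^a / a * 2F1(a, 1 - b; a + 1; x). `jax.scipy.special.betainc` returns the regularized I_x(a, b) = B(x; a, b) / B(a, b), so B(a, b) has to be multiplied back in via `gammaln`. Here is a minimal sketch under those assumptions. The helper name is hypothetical, and the formula only holds for 0 < x < 1 with a, b > 0:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import betainc, gammaln

jax.config.update("jax_enable_x64", True)


@jax.jit
def hyp2f1_betainc_special_case(a, b, x):
    """Hypothetical helper: 2F1(a, 1 - b; a + 1; x) for 0 < x < 1 and a, b > 0.

    Uses B(x; a, b) = x**a / a * 2F1(a, 1 - b; a + 1; x).
    """
    log_beta = gammaln(a) + gammaln(b) - gammaln(a + b)  # log B(a, b)
    # betainc is regularized: I_x(a, b) = B(x; a, b) / B(a, b)
    return a * x ** (-a) * jnp.exp(log_beta) * betainc(a, b, x)
```

As a sanity check, a = 1, b = 2 gives the polynomial 2F1(1, -1; 2; x) = 1 - x/2. Nothing here helps for -1 < x < 0, which is exactly the limitation described above.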
I think a more general handling of `hyp2f1`, as scipy does it, is required.
Hey @jmsull, I started implementing this over the weekend (or rather translating the implementation you linked to), but life got in the way and I haven't finished it yet. Are you still interested in a JAX implementation of this function? I think I should be done by Friday, modulo some debugging.
@joaogui1
Still interested - great to hear you started working on it!
What is the status of this issue? I want to start making contributions to this project and am wondering if this issue is a good place to start.
@FrescoFlacko life got in the way and I never finished it, but it's a rather large and tedious implementation, so I don't think it's a good first issue
Can you guide me on how to go about it, please?
@milind-soni I looked around SciPy until I found the implementation and then translated it to JAX; I guess that's the general idea. (The implementation was probably written in C, though, so you'll need some understanding of C to do a good port.)
Not sure if there's still any interest in this topic, but the following function is a workaround that uses the method in John Pearson's MSc thesis, based on a Taylor series approximation (see Sec. 4.2). The code below doesn't treat various corner cases, such as when the differences between a, b, and c are integers or when z = exp(i pi / 3), but I thought I'd leave it here in case someone wants to build on it:
```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.special import gammaln


@jax.jit
def hyp2f1(a, b, c, z, n_max_iter=500, tol=1e-10, debug_mode=False):
    """
    Hypergeometric function implemented in jax, based on John Pearson's MSc thesis
    Computation of Hypergeometric Functions, specifically taylora2f1
    :param a: a (1D jax device array)
    :param b: b (1D jax device array)
    :param c: c (1D jax device array)
    :param z: z (1D jax device array), either one value of z for each (a, b, c) or different z-values for a single (a, b, c)
    :param n_max_iter: maximum number of iterations
    :param tol: tolerance
    :param debug_mode: debugging mode using plain Python loops; it is a Python flag that jax.jit
        would trace, so call the function inside `with jax.disable_jit():` when setting it to True
    :return: Gauss hypergeometric function 2F1(a, b, c, z)
    """
    a, b, c, z = jnp.atleast_1d(a), jnp.atleast_1d(b), jnp.atleast_1d(c), jnp.atleast_1d(z)
    assert a.ndim == b.ndim == c.ndim == z.ndim == 1
    # If only a single value for a, b, c is provided, but multiple values of z: tile a, b, c
    if a.shape[0] == b.shape[0] == c.shape[0] == 1 and z.shape[0] > 1:
        a = a * jnp.ones_like(z)
        b = b * jnp.ones_like(z)
        c = c * jnp.ones_like(z)
    assert a.shape[0] == b.shape[0] == c.shape[0] == z.shape[0], f"{a.shape[0], b.shape[0], c.shape[0], z.shape[0]}"

    def compute_output(a_loc, b_loc, c_loc, z_loc):
        def cond_fun(val):
            step_, a_, b_, c_, z_, coeff_old_, coeff_new_, sum_new_, tol_, n_max_iter_ = val
            cond_steps = jnp.all(step_ < n_max_iter_)
            cond_tol = jnp.all(jnp.logical_or(jnp.abs(coeff_old_) / jnp.abs(sum_new_) > tol_,
                                              jnp.abs(coeff_new_) / jnp.abs(sum_new_) > tol_))
            return jnp.all(jnp.logical_and(cond_steps, cond_tol))

        def body_fun(val):
            step_, a_, b_, c_, z_, coeff_old_, coeff_new_, sum_new_, tol_, n_max_iter_ = val
            coeff_new_ = (a_ + step_ - 1.0) * (b_ + step_ - 1.0) / (c_ + step_ - 1.0) * z_ / step_ * coeff_old_
            sum_new_ += coeff_new_
            return step_ + 1, a_, b_, c_, z_, coeff_new_, coeff_new_, sum_new_, tol_, n_max_iter_

        init_val = (jnp.asarray(1.0), a_loc * 1.0, b_loc * 1.0, c_loc * 1.0, z_loc * 1.0, jnp.asarray(1.0),
                    a_loc * 1.0, jnp.asarray(1.0), jnp.asarray(tol), jnp.asarray(n_max_iter * 1.0))
        if debug_mode:
            def while_loop(cond_fun, body_fun, init_val):
                val = init_val
                while cond_fun(val):
                    val = body_fun(val)
                return val
        else:
            while_loop = jax.lax.while_loop
        final_step, _, _, _, _, _, _, sum_out, _, _ = while_loop(cond_fun, body_fun, init_val=init_val)
        return sum_out

    def gamma(val):
        return jnp.where(val >= 0, jnp.exp(gammaln(val)), -jnp.exp(gammaln(val)))

    # Set indices for the different cases of the transformation mapping z to within the radius rho < 0.5
    # where Taylor series converges fast
    cases = jnp.where(z < -1,
                      1, jnp.where(z < 0,
                                   2, jnp.where(z <= 0.5,
                                                3, jnp.where(z <= 1,
                                                             4, jnp.where(z <= 2,
                                                                          5, 6)))))
    cases -= 1  # 0-based indexing

    # Define the branches
    def branch_1(this_a, this_b, this_c, this_z):
        term_1 = (1 - this_z) ** (-this_a) * (
                gamma(this_c) * gamma(this_b - this_a) / gamma(this_b) / gamma(this_c - this_a)) \
            * compute_output(this_a, this_c - this_b, this_a - this_b + 1.0, 1.0 / (1.0 - this_z))
        term_2 = (1 - this_z) ** (-this_b) * (
                gamma(this_c) * gamma(this_a - this_b) / gamma(this_a) / gamma(this_c - this_b)) \
            * compute_output(this_b, this_c - this_a, this_b - this_a + 1.0, 1.0 / (1.0 - this_z))
        return term_1 + term_2

    def branch_2(this_a, this_b, this_c, this_z):
        return (1 - this_z) ** (-this_a) * compute_output(this_a, this_c - this_b, this_c,
                                                          this_z / (this_z - 1.0))

    def branch_3(this_a, this_b, this_c, this_z):
        return compute_output(this_a, this_b, this_c, this_z)

    def branch_4(this_a, this_b, this_c, this_z):
        term_1 = (gamma(this_c) * gamma(this_c - this_a - this_b)
                  / gamma(this_c - this_a) / gamma(this_c - this_b)) \
            * compute_output(this_a, this_b, this_a + this_b - this_c + 1.0, 1.0 - this_z)
        term_2 = (1 - this_z) ** (this_c - this_a - this_b) * \
            (gamma(this_c) * gamma(this_a + this_b - this_c) / gamma(this_a) / gamma(this_b)) \
            * compute_output(this_c - this_a, this_c - this_b, this_c - this_a - this_b + 1.0, 1.0 - this_z)
        return term_1 + term_2

    def branch_5(this_a, this_b, this_c, this_z):
        term_1 = this_z ** (-this_a) * (gamma(this_c) * gamma(this_c - this_a - this_b)
                                        / gamma(this_c - this_a) / gamma(this_c - this_b)) \
            * compute_output(this_a, this_a - this_c + 1.0, this_a + this_b - this_c + 1.0, 1.0 - 1.0 / this_z)
        term_2 = this_z ** (this_a - this_c) * (1 - this_z) ** (this_c - this_a - this_b) \
            * (gamma(this_c) * gamma(this_a + this_b - this_c) / gamma(this_a) / gamma(this_b)) \
            * compute_output(this_c - this_a, 1.0 - this_a, this_c - this_a - this_b + 1.0, 1.0 - 1.0 / this_z)
        return term_1 + term_2

    def branch_6(this_a, this_b, this_c, this_z):
        term_1 = (-this_z) ** (-this_a) * (
                gamma(this_c) * gamma(this_b - this_a) / gamma(this_b) / gamma(this_c - this_a)) \
            * compute_output(this_a, this_a - this_c + 1.0, this_a - this_b + 1.0, 1.0 / this_z)
        term_2 = (-this_z) ** (-this_b) * (
                gamma(this_c) + gamma(this_a - this_b) / gamma(this_a) / gamma(this_c - this_b)) \
            * compute_output(this_b - this_c + 1.0, this_b, this_b - this_a + 1.0, 1.0 / this_z)
        return term_1 + term_2

    branches = [branch_1, branch_2, branch_3, branch_4, branch_5, branch_6]

    def single_computation(val, branches_):
        case_, a_, b_, c_, z_ = val
        if debug_mode:
            def switch_fun(index, branches, *operands):
                # jnp.clip(x, min, max): keep the branch index within range
                index = jnp.clip(index, 0, len(branches) - 1)
                return branches[index](*operands)
        else:
            switch_fun = jax.lax.switch
        return switch_fun(case_, branches_, a_, b_, c_, z_)

    # Compute outputs
    if debug_mode:
        def map_fun(f, xs):
            return np.stack([f(x) for x in zip(*xs)])
    else:
        map_fun = jax.lax.map
    all_outputs = map_fun(lambda val: single_computation(val, branches), (cases, a, b, c, z))
    return all_outputs
```
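For the -1 < z < 0 case from the original request, `branch_2` above relies on the Pfaff transformation 2F1(a, b; c; z) = (1 - z)^(-a) 2F1(a, c - b; c; z/(z - 1)). That transformation maps (-1, 0) into (0, 1/2), where the series converges quickly. Below is a standalone sketch of just that step. It assumes real scalar inputs and a fixed 80-term series, and `hyp2f1_pfaff` and `_series` are hypothetical names. It is checked against the closed forms 2F1(1, 1; 2; z) = -log(1 - z)/z and 2F1(1/2, 1; 3/2; -x^2) = arctan(x)/x.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def _series(a, b, c, w, num_terms=80):
    # Truncated Gauss series for scalar parameters; adequate here because |w| < 1/2
    n = jnp.arange(num_terms, dtype=jnp.float64)
    ratios = (a + n) * (b + n) / ((c + n) * (n + 1.0)) * w  # t_{n+1} / t_n
    return 1.0 + jnp.sum(jnp.cumprod(ratios))  # t_0 + t_1 + ... + t_num_terms


@jax.jit
def hyp2f1_pfaff(a, b, c, z):
    """Hypothetical helper for real -1 < z < 0 via the Pfaff transformation
    2F1(a, b; c; z) = (1 - z)**(-a) * 2F1(a, c - b; c; z / (z - 1)).
    """
    w = z / (z - 1.0)  # maps (-1, 0) into (0, 1/2)
    return (1.0 - z) ** (-a) * _series(a, c - b, c, w)
```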
@FloList Thank you for starting on this. I'm still trying to wrap my head around it, but it does seem to fail for z <=-1.
I double-checked the branch_1 function and don't see anything wrong. I'm not sure if you or anyone else might see something I don't. Thanks in advance!
Note: the only modification I made was changing gammaln to jax.scipy.special.gammaln
@jguerra-astro Thanks for taking a look at this - I haven't looked into this any further unfortunately, but there is also the function hyp2f1_small_argument provided by the TF probability Jax substrate, which might be helpful: https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/math/hypergeometric/hyp2f1_small_argument
As one of the contributors to the pymc-marketing library, I have an interest in this as well. It could dramatically reduce training times for one of our models. Our particular hyp2f1 application doesn't involve any unusual edge cases, so the code provided by @FloList may suffice.
@FloList I've taken a look at your code, did some testing, and found several issues. I haven't gotten around to fixing them all yet (and don't know if I will), so I thought I'd just mention what I've spotted in case someone else wants to fix it.
The first problem is that your definition of `gamma` for negative arguments is not correct. The gamma function isn't an odd function, so you can't just move the negative sign to the outside. This explains why @jguerra-astro got some incorrect results in the case where z <= -1. The `branch_4` function also sometimes gets incorrect results for the same reason, and I imagine some of the other branch functions do as well. Using jax.scipy.special.gamma fixes the problem.
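To make the first problem concrete: `gammaln` returns log|Γ(x)|, so the sign has to be restored separately. For x < 0 that sign alternates between unit intervals. Γ is negative on (-1, 0), positive on (-2, -1), negative on (-3, -2), and so on, i.e. the sign is (-1)^floor(x). The posted `gamma` helper negates every negative argument. That happens to be right on (-1, 0) but wrong on (-2, -1). Below is a minimal sketch of a sign-aware version. The name `gamma_signed` is hypothetical; in practice `jax.scipy.special.gamma`, as suggested above, is the simpler choice.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln

jax.config.update("jax_enable_x64", True)


def gamma_signed(x):
    """Hypothetical helper: Gamma(x) for real, non-integer x, including x < 0.

    gammaln gives log|Gamma(x)|; for x < 0 the sign is (-1)**floor(x).
    """
    x = jnp.asarray(x)
    odd_interval = jnp.mod(jnp.floor(x), 2) == 1  # floor(x) = -1, -3, -5, ...
    sign = jnp.where((x < 0) & odd_interval, -1.0, 1.0)
    return sign * jnp.exp(gammaln(x))


def gamma_sign_flip(x):
    # The sign convention from the posted snippet, kept only for comparison
    return jnp.where(x >= 0, jnp.exp(gammaln(x)), -jnp.exp(gammaln(x)))
```

With Γ(-1.5) = 4√π/3 > 0, `gamma_signed(-1.5)` is positive, while `gamma_sign_flip(-1.5)` returns the negated value.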
The second problem is that when you set up the conditions for `cases`, you only look at Re(z), which doesn't work. Take, for example, z = 0.3 + 2.5i. The code will execute `branch_3`, which is incorrect because the series representation of 2F1 diverges outside the unit disk. Similar problems arise with `branch_4` for e.g. z = 0.7 + 2.5i, and potentially with other branches as well. The conditions for when to use which branch need to be worked out more carefully.
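One way to make the branch choice depend on the complex value of z rather than on Re(z) is sketched below. This is an assumption on our part, not necessarily what the thesis prescribes. The idea is to compute the argument w(z) of each standard transformation (z, 1 - z, 1/z, 1/(1 - z), z/(z - 1), 1 - 1/z) and pick the one with the smallest modulus, since that is where the Taylor series converges fastest. `select_transformation` is a hypothetical name. The returned index would then pick the matching connection formula, e.g. with `jax.lax.switch`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


@jax.jit
def select_transformation(z):
    """Hypothetical helper: index and value of the candidate argument with the smallest |w(z)|.

    Decides on the complex modulus of each candidate, not on Re(z).
    Assumes z is not exactly 0 or 1 (those cause a division by zero).
    """
    z = jnp.asarray(z, dtype=jnp.complex128)
    candidates = jnp.stack([
        z,                # 0: direct series
        1.0 - z,          # 1: argument 1 - z
        1.0 / z,          # 2: argument 1 / z
        1.0 / (1.0 - z),  # 3: argument 1 / (1 - z)
        z / (z - 1.0),    # 4: Pfaff, argument z / (z - 1)
        1.0 - 1.0 / z,    # 5: argument 1 - 1 / z
    ])
    index = jnp.argmin(jnp.abs(candidates))
    return index, candidates[index]
```

For z = 0.3 + 2.5i this picks 1/(1 - z) (|w| ≈ 0.39) instead of the direct series. Near z = e^{±iπ/3} every candidate has |w| ≈ 1, which is the corner case mentioned earlier in the thread, so a different expansion is needed there.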
The third problem is that there is a typo in `branch_6`. I'm guessing you took it directly from equation 4.20 in the MSc thesis you linked, but that equation has a typo on the second line. The arguments b and b-c+1 in 2F1 need to be swapped (check e.g. Abramowitz and Stegun equation 15.3.7). However, even after fixing this issue, `branch_6` still gives incorrect results, and I haven't been able to figure out why.
I did code up a function that can calculate 2F1 for cases where z is outside the unit disk (more specifically, this approach works whenever z is outside the circle |z - 1/2| = 1/2). This implementation is based on the analytic continuation equations 4.21 and 4.22 in the MSc thesis. It takes into account the e^{±iπ/3} issue and does not have any poles when c - a - b is an integer, but it still has poles whenever b - a is an integer.
```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gamma


@jax.jit
def hyp2f1_continuation(a, b, c, z):
    # If z is on the branch cut, take the value just above the branch cut
    z = jnp.where(jnp.imag(z) == 0., jnp.where(jnp.real(z) >= 1., z + 0.0000001j, z), z)
    # d0 = 1 and d_{-1} = 0
    prev_da = 1.
    prev_db = 1.
    prev_prev_da = 0.
    prev_prev_db = 0.
    # partial_sum_1 corresponds to the summation on the top line of equation 4.21
    # partial_sum_2 corresponds to the summation on the bottom line of equation 4.21
    # Both start at d_0 = 1 with z's (complex) dtype, so the fori_loop carry keeps a fixed type
    partial_sum_1 = jnp.ones_like(z)
    partial_sum_2 = jnp.ones_like(z)

    def body_fun(j, val):
        a_, b_, c_, z_, prev_prev_da, prev_prev_db, prev_da, prev_db, partial_sum_1, partial_sum_2 = val
        # ------------------------------------------------------------------------------------------
        # This section handles the summation on the first line of equation 4.21:
        # it calculates d_j and the corresponding term in the sum
        d_ja = (j+a_-1.)/(j*(j+a_-b_)) * (((a_+b_+1.)*0.5-c_)*prev_da + 0.25*(j+a_-2.)*prev_prev_da)
        partial_sum_1 += d_ja * (z_ - 0.5)**(-j)
        # updates d_{j-2} and d_{j-1}
        prev_prev_da = prev_da
        prev_da = d_ja
        # ------------------------------------------------------------------------------------------
        # This section handles the summation on the second line of equation 4.21:
        # it calculates d_j and the corresponding term in the sum
        d_jb = (j+b_-1.)/(j*(j-a_+b_)) * (((a_+b_+1)*0.5-c_)*prev_db + 0.25*(j+b_-2.)*prev_prev_db)
        partial_sum_2 += d_jb * (z_ - 0.5)**(-j)
        # updates d_{j-2} and d_{j-1}
        prev_prev_db = prev_db
        prev_db = d_jb
        return [a_, b_, c_, z_, prev_prev_da, prev_prev_db, prev_da, prev_db, partial_sum_1, partial_sum_2]

    result = jax.lax.fori_loop(1, 30, body_fun, [a, b, c, z, prev_prev_da, prev_prev_db, prev_da, prev_db,
                                                 partial_sum_1, partial_sum_2])
    # includes the gamma function prefactors in equation 4.21 to compute the final result of 2F1
    final_result = gamma(c) * (result[8] * gamma(b-a)/gamma(b)/gamma(c-a)*(0.5-z)**(-a)
                               + result[9] * gamma(a-b)/gamma(a)/gamma(c-b)*(0.5-z)**(-b))
    return final_result
```
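For reference, the quantity that `hyp2f1_continuation` evaluates can be written out as follows. This is transcribed from the code (the loop runs j = 1, ..., 29), not re-derived from or checked against equations 4.21/4.22 of the thesis, so treat it as a reading aid:

$$
{}_2F_1(a,b;c;z) \approx \Gamma(c)\left[\frac{\Gamma(b-a)}{\Gamma(b)\,\Gamma(c-a)}\left(\tfrac{1}{2}-z\right)^{-a}\sum_{j=0}^{29} d_j^{(a)}\left(z-\tfrac{1}{2}\right)^{-j} + \frac{\Gamma(a-b)}{\Gamma(a)\,\Gamma(c-b)}\left(\tfrac{1}{2}-z\right)^{-b}\sum_{j=0}^{29} d_j^{(b)}\left(z-\tfrac{1}{2}\right)^{-j}\right]
$$

$$
d_0^{(a)} = 1,\quad d_{-1}^{(a)} = 0,\qquad
d_j^{(a)} = \frac{j+a-1}{j\,(j+a-b)}\left[\left(\frac{a+b+1}{2}-c\right)d_{j-1}^{(a)} + \frac{j+a-2}{4}\,d_{j-2}^{(a)}\right]
$$

Here d_j^{(b)} follows from the same recurrence with a and b exchanged. The individual terms become singular when b - a is an integer, through Γ(b - a), Γ(a - b) and 1/(j + a - b). No Γ(c - a - b) factor appears, which is consistent with the description of the poles above.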
It would be great if hyp2f1 could be added under jax.scipy.special! Are there any plans to do this?
My application uses a special case of hyp2f1 that seems like it can be written using special functions already implemented in JAX (gammaln and betainc), but it turns out that the scipy implementation of hyp2f1 uses a transformation on the input that prevents hyp2f1 from being written in terms of these other functions. As far as I can tell, there is no other convenient substitution in terms of existing JAX functions.
Thanks!