jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Second derivative of scipy.special.gammainc #7922

Open sirmurphalot opened 3 years ago

sirmurphalot commented 3 years ago

I need to take the derivative of a Gamma CDF with respect to the shape parameter alpha. I can do this pretty easily with

from jax import scipy, grad
def gamma_cdf_term(alpha, data_value):
    return scipy.special.gammainc(alpha, data_value)

and JAX's grad function. The problem is that this Gamma CDF derivative is part of a density that I feed into TensorFlow Probability's NoUTurnSampler, which requires a second derivative with respect to alpha. Here is where I hit a wall:

NotImplementedError: Differentiation rule for 'igamma_grad_a' not implemented

which I assume means that this second derivative is not available.
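For anyone who wants to reproduce this, here is a minimal sketch. The input values are illustrative and not taken from the original report, and whether the second call still raises depends on the installed JAX version, so the sketch handles both outcomes.

```python
from jax import grad
from jax.scipy.special import gammainc


def gamma_cdf_term(alpha, data_value):
    # Regularized lower incomplete gamma function, i.e. the Gamma(alpha, 1) CDF.
    return gammainc(alpha, data_value)


# Illustrative inputs (assumed for this sketch, not from the report).
alpha, data_value = 2.0, 1.5

# First derivative with respect to alpha (argnums=0 is the default).
first = grad(gamma_cdf_term)(alpha, data_value)
print("d/d alpha:", first)

# Second derivative with respect to alpha: the step reported to fail.
try:
    second = grad(grad(gamma_cdf_term))(alpha, data_value)
    outcome = "ok"
    print("d2/d alpha2:", second)
except NotImplementedError as err:
    outcome = "not implemented"
    print("second derivative unavailable:", err)
```

The first call goes through the derivative that JAX already provides for igamma. The nested call needs a differentiation rule for that derivative itself, which is where the error above comes from.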

My questions: is my assumption correct? If so, I read that you base new features on user feedback, so let this be my vote that the second derivative get implemented :). If not, can anyone provide guidance on how I might get this second derivative?

zhangqiaorjc commented 3 years ago

@mattjj

hawkinsp commented 3 years ago

@srvasude perhaps also

apaszke commented 3 years ago

Yes, it looks like we're missing higher-order derivatives for igamma. If we need to define more operations to do this, it might be a challenging task, because both igamma and its derivative are implemented at the XLA level.

hawkinsp commented 3 years ago

Well, it shouldn't be too hard to implement the derivative if someone figures out the math. We use the HLO built in C++ as a convenience and so we can share it with TensorFlow, but there's no functionality improvement over building the same HLO from Python.
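Until such a rule exists, one possible workaround is to build the function from ordinary JAX operations so that autodiff can differentiate it twice. The sketch below is not how JAX implements igamma. It approximates the defining integral with Gauss-Legendre quadrature after the substitution t = x * s**(1/alpha). The helper name `gammainc_quad`, the node count, and the test inputs are all assumptions made for illustration, and the accuracy is only as good as the quadrature.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.special import gammainc, gammaln

# Gauss-Legendre nodes and weights, mapped from [-1, 1] to [0, 1].
# 128 nodes is an arbitrary choice for this sketch.
_nodes, _weights = np.polynomial.legendre.leggauss(128)
S = jnp.asarray(0.5 * (_nodes + 1.0))
W = jnp.asarray(0.5 * _weights)


def gammainc_quad(alpha, x):
    """Approximate P(alpha, x) for alpha > 0, x > 0 (hypothetical helper)."""
    # With t = x * s**(1/alpha), the defining integral becomes
    # P(alpha, x) = x**alpha / Gamma(alpha + 1) * int_0^1 exp(-x * s**(1/alpha)) ds.
    integral = jnp.sum(W * jnp.exp(-x * S ** (1.0 / alpha)))
    prefactor = jnp.exp(alpha * jnp.log(x) - gammaln(alpha + 1.0))
    return prefactor * integral


# Illustrative inputs (assumed, not from the thread).
alpha, x = 2.0, 1.5

value = gammainc_quad(alpha, x)
first = jax.grad(gammainc_quad)(alpha, x)
second = jax.grad(jax.grad(gammainc_quad))(alpha, x)

print("quadrature value:", value, "built-in value:", gammainc(alpha, x))
print("first derivative in alpha:", first)
print("second derivative in alpha:", second)
```

Because every operation here already has a differentiation rule, nesting `jax.grad` works without touching `igamma_grad_a`. The trade-off is speed and accuracy: the integrand is not smooth at s = 0 when alpha is not 1, so convergence in the number of nodes is only algebraic, and results should be checked against `gammainc` over the parameter range of interest.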

sirmurphalot commented 3 years ago

I'm not sure if this is what you need (since there is an integral term), but I have worked out a form for the second derivative. I think you can swap the partial derivative and the integral, as I do with Leibniz's integral rule, but someone might want to check me.

[Image: gammamath]
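For readers without access to the image, here is one way such a derivation can be written out. It is a sketch and not a transcription of the attachment. It assumes alpha > 0 and a fixed x > 0, writes P for the regularized lower incomplete gamma function (what `gammainc` computes), gamma for its unnormalized form, and psi and psi_1 for the digamma and trigamma functions. Since the integration limits do not depend on alpha, Leibniz's rule reduces to differentiating under the integral sign.

```latex
\begin{align*}
P(\alpha, x) &= \frac{\gamma(\alpha, x)}{\Gamma(\alpha)}, \qquad
\gamma(\alpha, x) = \int_0^x t^{\alpha-1} e^{-t}\,dt, \\
\frac{\partial^k \gamma}{\partial \alpha^k}
  &= \int_0^x t^{\alpha-1} (\ln t)^k e^{-t}\,dt \qquad (k = 1, 2), \\
\frac{\partial P}{\partial \alpha}
  &= \frac{1}{\Gamma(\alpha)} \frac{\partial \gamma}{\partial \alpha}
     - \psi(\alpha)\, P, \\
\frac{\partial^2 P}{\partial \alpha^2}
  &= \frac{1}{\Gamma(\alpha)} \frac{\partial^2 \gamma}{\partial \alpha^2}
     - \frac{2\,\psi(\alpha)}{\Gamma(\alpha)} \frac{\partial \gamma}{\partial \alpha}
     + \left(\psi(\alpha)^2 - \psi_1(\alpha)\right) P.
\end{align*}
```

The last line follows from differentiating the third with Gamma'(alpha) = Gamma(alpha) psi(alpha). The integral terms have no elementary closed form in general, so an implementation would still need a series or quadrature to evaluate them.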