jessebett opened this issue 4 years ago
To anyone who wants to follow along (or contribute!): these rules are all found at the bottom of `jax/experimental/jet.py`.
If we want to be comprehensive, should we also include `sinh`, `cosh`, `asinh`, `acosh`, `digamma`?
I think `digamma`, `igamma`, `igammac`, and `lgamma` are all impossible due to Hölder's theorem.
Wow, great point Jacob. I suppose there's the slim possibility that the ODE coefficients would be expressible in terms of the gamma function itself?
Yeah I think that's possible! But I think such a scenario wouldn't characterize the gamma function, see here.
If there's interest in defining jet rules for numerical linear algebra, I would suggest referencing AlgoPy and Sebastian Walter's PhD thesis.
AlgoPy seems to have higher order auto-diff implementations for the full range of JAX's linear algebra primitives: `linear_solve` (which covers both `custom_linear_solve` and `triangular_solve`), `cholesky`, `lu`, `qr`, `eigh`, and `svd`.
@shoyer this is an awesome reference! Thanks for making me aware of it; I'm looking through it now and agree that these are great sources for those linalg primitives!
Hi all!
Sorry to jump in from the side on this discussion. I am a big fan of `jet`, and some more primitives would be super useful to me!
Are there still plans to add some of the missing pieces in the future?
For example, `jnp.roll` requires the `dynamic_slice` rule, which stops me from using `jet` on the Lorenz96 ODE problem (there are workarounds for Lorenz96, one of which is sketched after the traceback below, but I found the `roll` implementation to be the most efficient/readable).
```python
import jax.numpy as jnp
from jax.experimental.jet import jet

def f_lorenz96(y, c=1.):
    A = jnp.roll(y, shift=-1)
    B = jnp.roll(y, shift=2)
    C = jnp.roll(y, shift=1)
    D = y
    return (A - B) * C - D + c

y0 = jnp.ones(11)  # values don't matter
jet(f_lorenz96, (y0,), ((y0,),))
```
which raises a `KeyError`:
```
The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "example.py", line 16, in <module>
    jet(f_lorenz96, (y0,), ((y0,),))
  File "/home/kraemer/Projects/high-dimensional-ode-solver/sth/env/lib/python3.8/site-packages/jax/experimental/jet.py", line 57, in jet
    out_primals, out_terms = jet_fun(jet_subtrace(f), order).call_wrapped(primals, series)
  File "/home/kraemer/Projects/high-dimensional-ode-solver/sth/env/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "example.py", line 6, in f_lorenz96
    A = jnp.roll(y, shift=-1)
  File "/home/kraemer/Projects/high-dimensional-ode-solver/sth/env/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5409, in roll
    return _roll(a, shift, axis)
  File "/home/kraemer/Projects/high-dimensional-ode-solver/sth/env/lib/python3.8/site-packages/jax/experimental/jet.py", line 140, in process_call
    result = call_primitive.bind(f_jet, *primals_and_series, **new_params)
  File "/home/kraemer/Projects/high-dimensional-ode-solver/sth/env/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5386, in _roll
    return lax.reshape(_roll(ravel(a), shift, axis=0), a_shape)
  File "/home/kraemer/Projects/high-dimensional-ode-solver/sth/env/lib/python3.8/site-packages/jax/experimental/jet.py", line 140, in process_call
    result = call_primitive.bind(f_jet, *primals_and_series, **new_params)
  File "/home/kraemer/Projects/high-dimensional-ode-solver/sth/env/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5400, in _roll
    a = lax.dynamic_slice_in_dim(a, a_shape[i] - x, a_shape[i], axis=i)
  File "/home/kraemer/Projects/high-dimensional-ode-solver/sth/env/lib/python3.8/site-packages/jax/experimental/jet.py", line 126, in process_primitive
    rule = jet_rules[primitive]
KeyError: dynamic_slice
```
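For reference, a minimal sketch of one such roll-free workaround, assuming `jet` already has rules for the (linear) `concatenate` and `slice` primitives; the function name is just illustrative:
```python
import jax.numpy as jnp

def f_lorenz96_noroll(y, c=1.):
    # Build the shifted vectors with concatenate + static slicing instead of jnp.roll.
    A = jnp.concatenate([y[1:], y[:1]])    # same as jnp.roll(y, shift=-1)
    B = jnp.concatenate([y[-2:], y[:-2]])  # same as jnp.roll(y, shift=2)
    C = jnp.concatenate([y[-1:], y[:-1]])  # same as jnp.roll(y, shift=1)
    return (A - B) * C - y + c
```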
I could make a PR with a little guidance (for example, if pointed to the correct primitive). Also, please let me know if this is too off-topic for this discussion and I should open a new issue instead. :)
Thanks a lot!
Python: 3.8.10. `pip freeze`:
```
absl-py==1.0.0
flatbuffers==2.0
jax==0.2.28
jaxlib==0.1.76
numpy==1.22.2
opt-einsum==3.3.0
scipy==1.7.3
six==1.16.0
typing-extensions==4.0.1
```
That's great that you're using jets! I can help you debug. I'm not sure this will work, but my guess is that, like `slice`, `dynamic_slice` is linear. In those cases, the jet rule is to simply apply the same operation to all the elements of the series. We made a generic helper for such functions called `def_linear`.
Here it is applied to slice: https://github.com/google/jax/blob/main/jax/experimental/jet.py#L240
So I think you might just be able to add a line for `dynamic_slice` everywhere you see a line for `slice`.
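As a rough sketch of what that registration might look like (the helper name here follows the comment above and may be spelled differently in `jet.py`, so treat this as an assumption rather than the exact API):
```python
from jax import lax
from jax.experimental import jet

# Hypothetical: register dynamic_slice with the same generic linear rule used for
# slice. Check jet.py for the helper's exact name and for whether dynamic_slice's
# start-index arguments need special handling.
jet.def_linear(lax.dynamic_slice_p)
```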
Great, thanks a lot for the quick reply! I will give this a try.
Would you like me to add a test (e.g. to https://github.com/google/jax/blob/7a6986c4c8fd8469bae36306efb0417b0a2f6d8c/tests/jet_test.py#L306) and make a pull request?
Yes, please!
I think `dynamic_slice` can now be checked off in the list above. I cannot do it myself, which is why I comment.
If you need more help, please let me know. Otherwise I will come back next time I need a jet primitive that has not been implemented yet :)
Thanks for the support!
I checked it for you! :)
Thanks @pnkraemer ! This module is getting to have a real "stone soup" vibe, it's great.
I've had another look, and many of the primitives above are already implemented. It probably makes sense to check them off in the todo list above. :)
More specifically (in the order of appearance):
- `expit`: not sure whether this is still a todo or not.

@pnkraemer Thanks for going through and making this list! I can confirm that the jet rule for `expit` is implemented and works based on these tests:
https://github.com/google/jax/blob/main/tests/jet_test.py#L278-L281
Have you considered a generic polynomial composition routine to propagate Taylor series? It could be used for g(f(x)) whenever the higher derivatives of g(z) with respect to z are available and fill in cases where specialized routines aren't available. I implemented a basic one that runs in O(K^3) time and tested it for `lgamma` and `digamma`:
```python
from functools import partial
from math import factorial as fact

import jax
import jax.numpy as jnp
import jax.scipy.special
import numpy as np
from jax import lax
from jax.experimental.jet import jet_rules


def poly_mul(f, g):
    '''Multiply polynomials, truncating to larger degree'''
    f, g = (f, g) if len(f) >= len(g) else (g, f)  # ensure f has larger degree
    k = len(f)
    out = [jnp.zeros_like(f[0]) for i in range(k)]
    # naive convolution
    for i in range(k):
        g_lim = np.minimum(len(g), k - i)
        for j in range(g_lim):
            out[i + j] += f[i] * g[j]
    return out


def poly_compose(g, f):
    '''Polynomial composition via Horner's method (recursive)'''
    if not g:
        return [jnp.zeros_like(f[0]) for term in f]
    res = poly_mul(f, poly_compose(g[1:], f))
    return [g[0] + res[0]] + res[1:]


def _series_to_taylor(f):
    # Add zero term and scale with factorials
    return [jnp.zeros_like(f[0])] + [term / fact(i + 1) for i, term in enumerate(f)]


def _taylor_to_series(f):
    # Unscale and remove constant term
    return [term * fact(i + 1) for i, term in enumerate(f[1:])]


def _compose_taylor(g, f):
    # g and f are sequences of derivatives:
    # convert to Taylor polynomials and compose
    g = _series_to_taylor(g)
    f = _series_to_taylor(f)
    h = poly_compose(g, f)
    return _taylor_to_series(h)


def _faa_di_bruno_rule(fun, local_deriv_fun, primals_in, series_in):
    x, = primals_in
    series, = series_in
    primals_out = fun(x)
    series_local = local_deriv_fun(x, len(series))
    series_out = _compose_taylor(series_local, series) if series else []
    return primals_out, series_out


_lgamma_local_derivs = lambda x, k: [jax.scipy.special.polygamma(n, x) for n in range(k)]
jet_rules[lax.lgamma_p] = partial(_faa_di_bruno_rule, lax.lgamma, _lgamma_local_derivs)

_digamma_local_derivs = lambda x, k: [jax.scipy.special.polygamma(n, x) for n in range(1, k + 1)]
jet_rules[lax.digamma_p] = partial(_faa_di_bruno_rule, lax.digamma, _digamma_local_derivs)
```
The implementation could probably be optimized. Also, for large K asymptotically faster algorithms for composition (e.g., Brent and Kung, 1978) might be beneficial.
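A minimal usage sketch, assuming the registrations above have been run (`gammaln` and the input values here are my own illustration, not part of the snippet):
```python
import jax.numpy as jnp
from jax.experimental.jet import jet
from jax.scipy.special import gammaln  # implemented via lax.lgamma

# Propagate a second-order Taylor series through lgamma at a couple of points.
x = jnp.array([1.5, 2.5])
primals_out, series_out = jet(gammaln, (x,), ([jnp.ones_like(x), jnp.zeros_like(x)],))
```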
Hi @dsheldon thanks for your comment, this is awesome :) Really clean implementation and nice examples!
Feel free to make a PR for this :D
I believe @dougalm, @duvenaud, @jessebett, @mattjj worked on an implementation of FDB with some optimisations here, but we never managed to merge this into the jet API. So feel free to have a peek at that code if you're keen and incorporate some of those ideas into your PR too :)
IIRC @jessebett implemented some ideas from Brent and Kung, but I can't find the code for that unfortunately. @jessebett would you be able to provide a pointer?
Hi @jacobjinkelly, thanks for the response! Sure, I'd be happy to make a PR. Before I do: did I make the correct assumption about the inputs to the propagation rules? I.e., that `series` in the `_faa_di_bruno_rule` above will always be a list of arrays (or array-like objects) with the same shape?
I think the other implementation you pointed to is iterating through integer partitions, which, to my understanding, would take super-polynomial time. With polynomial composition it's possible to implement Faa di Bruno's rule in polynomial time — e.g., $O(k^3)$ for the naive algorithm above. One of the reasons for my comment was to point out this possibility.
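For concreteness, a sketch of the comparison (standard statements, not quoted from anyone above): the partition form of Faà di Bruno's formula is

$$\frac{d^k}{dx^k}\, g(f(x)) \;=\; \sum_{\pi \in \Pi_k} g^{(|\pi|)}(f(x)) \prod_{B \in \pi} f^{(|B|)}(x),$$

where $\Pi_k$ is the set of partitions of $\{1,\dots,k\}$, whose size (the Bell number) grows super-polynomially in $k$, whereas composing degree-$k$ truncated Taylor polynomials term by term costs $O(k^3)$ with naive polynomial multiplication.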
We used one of the Brent-Kung algorithms in a paper, so I have some old code, but I'm guessing the crossover point where it will be substantially faster than Horner's rule doesn't occur until the order of differentiation is pretty large. (And for very large $k$ it's hard to get around numerical issues due to dealing with terms like $k!$ in floating point anyway.)
Yes I think that's the correct assumption for the type of `series`. If you're not sure, you can always double check by using the `unary_check` utility in `jet_test.py`.
> I think the other implementation you pointed to is iterating through integer partitions, which, to my understanding, would take super-polynomial time. With polynomial composition ...

Ok wow I didn't realise that, that's super interesting!
> substantially faster than Horner's rule doesn't occur until the order of differentiation is pretty large ...

Agreed!
Hi @dsheldon,
As Jacob already said, this is a very nice implementation!
It's been some time so I'm a bit fuzzy on the details, sorry in advance if I'm missing your main points.
> Have you considered a generic polynomial composition routine to propagate Taylor series?
Yes! Early implementations we worked with exclusively relied on Faa Di Bruno and generic polynomial propagation. We included some notes on the relationship between Faa Di Bruno formula and propagation of truncated Taylor polynomials in the appendices of our workshop preprint.
As you noted, Brent and Kung (1978) suggest asymptotically faster algorithms for general polynomial composition (following from the algorithmic complexity of inverting power series). However, as described in Griewank and Walther (2008), Chapter 13.2 on Taylor Polynomial Propagation (specifically the subsection on Nonlinear Univariate Elementals), generic polynomial composition addresses the general case where the Taylor coefficients are unrelated. In practice, though, the coefficients for "elemental functions of interest" (e.g. primitives like `exp` and `sin`) are all solutions of linear ODEs and satisfy the identity in equation 13.7. This allows application of Brent and Kung's Theorem 5.1 to derive more efficient algorithms by exploiting relationships between coefficients.
So at some point we moved from generic Faa Di Bruno based propagation rules to more optimized rules for primitives that are a consequence of this special relationship among their coefficients. You can find some examples of these rules and their complexity in tables 13.1 and 13.2 of Griewank and Walther.
To be clear, Brent and Kung discuss algorithms for efficient propagation of power series with better complexity than classical polynomial propagation, but those improvements are really only relevant at larger orders of differentiation. You already point out this is numerically tricky due to $k!$. Instead, we are appealing to a restriction from the general problem of polynomial composition to those with special structure, which is a separate result also discussed in Brent and Kung.
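As a concrete sketch of the kind of rule this special structure buys (a standard textbook example, not taken from those tables): for $v(t) = \exp(u(t))$, which satisfies the linear ODE $v' = u'\,v$, the truncated Taylor coefficients obey

$$v_0 = \exp(u_0), \qquad v_k = \frac{1}{k} \sum_{j=1}^{k} j\, u_j\, v_{k-j} \quad (k \ge 1),$$

so each coefficient costs $O(k)$ work and a degree-$K$ series costs $O(K^2)$ overall, versus $O(K^3)$ for generic composition.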
> did I make the correct assumption about the inputs to the propagation rules? I.e., that `series` in the `_faa_di_bruno_rule` above will always be a list of arrays (or array-like objects) with the same shape?
Yes you got it. To help orient through the current code, you can see/modify the enforcement of the series input type here and the enforcement that the series have the same shape here.
> I think the other implementation you pointed to is iterating through integer partitions, which, to my understanding, would take super-polynomial time.
Yes, Jacob linked to an implementation preceding our use of Faa Di Bruno rules that relies on integer partitions. Your understanding is correct, this was very non-optimal! A subsequent version that uses Faa Di Bruno exclusively is somewhere in the commit/branch history, but I'd have to do some digging to find it...
> It could be used for g(f(x)) whenever the higher derivatives of g(z) with respect to z are available and fill in cases where specialized routines aren't available.
I agree, this would be great to have! At some point we discussed keeping the generic Faa Di Bruno rule as a fallback for primitives without special rules. I'm not quite sure what happened that dropped the generic rule, but I suspect we just got distracted by working out special rules for primitives we encountered in practice and the general case became increasingly rare. As Jacob said, if you're up for it a PR would be very welcome!
Thanks again for your comment and clean code!
OK, PR created --- hopefully done correctly, if not please let me know!
@jessebett, thanks for your comments, I think we're on the same page regarding the options! The optimized rules (usually O(k^2)) when the function is the solution of an ODE are preferable and applicable in almost all cases. I made my comment because the notes explaining jet mention (1) the intractable direct evaluation of Faa di Bruno's formula, and (2) the efficient special case for common primitives that are the solution of a linear ODE, but they don't mention that a generic rule based on polynomial composition runs in O(k^3) (or better) time, which is slower but not disastrously slower than the specialized rules. So it seemed like such a rule could be helpful for the few primitives that weren't already covered by special cases. From your response, it sounds like you already had this in previous versions.
@dsheldon thanks for the tips!
I remember trying polynomial composition, but I thought it didn't give the correct answers (even though I thought it should work), and I never figured out why! I think I messed up `_series_to_taylor` and `_taylor_to_series` (referring to the names in your PR). It's great if this works generically!
@jacobjinkelly @duvenaud @mattjj
Here's the list of all the `ad.jvp_primitives`, which probably reflects what we'd need for coverage by `jet`. Copied from #2363 since that was more about merging into master!