jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Outstanding Primitive Rules for Higher Order Automatic Differentiation (jet) #2431

Open jessebett opened 4 years ago

jessebett commented 4 years ago

@jacobjinkelly @duvenaud @mattjj

Here's the list of all the ad.jvp_primitives; it probably reflects what we'd need for coverage by jet.

Copied from #2363 since that was more about merging into master!

jessebett commented 4 years ago

To anyone who wants to follow along (or contribute!): these rules are all found at the bottom of jax/experimental/jet.py

jacobjinkelly commented 4 years ago

If we want to be comprehensive should we also include sinh, cosh, asinh, acosh, digamma?

jacobjinkelly commented 4 years ago

I think digamma, igamma, igammac, and lgamma are all impossible due to Hölder's theorem.

duvenaud commented 4 years ago

Wow, great point Jacob. I suppose there's the slim possibility that the ODE coefficients would be expressible in terms of the gamma function itself?

jacobjinkelly commented 4 years ago

Yeah I think that's possible! But I think such a scenario wouldn't characterize the gamma function; see here.

shoyer commented 4 years ago

If there's interest in defining jet rules for numerical linear algebra, I would suggest referencing AlgoPy and Sebastian Walter's PhD thesis.

AlgoPy seems to have higher order auto-diff implementations for the full range of JAX's linear algebra primitives: linear_solve (which covers both custom_linear_solve and triangular_solve), cholesky, lu, qr, eigh and svd.

jessebett commented 4 years ago

@shoyer this is an awesome reference! Thanks for making me aware of it. I'm looking through it now and agree that these are great sources for those linalg primitives!

pnkraemer commented 2 years ago

Hi all!

Sorry to jump in from the side on this discussion. I am a big fan of jet, and some more primitives would be super useful to me! Are there still plans to add some of the missing pieces in the future?

For example, jnp.roll requires the dynamic_slice rule, which stops me from using jet on the Lorenz96 ODE problem (there are workarounds for Lorenz96, but I found the roll implementation to be most efficient/readable).

Example:

import jax.numpy as jnp
from jax.experimental.jet import jet 

def f_lorenz96(y, c=1.):
    A = jnp.roll(y, shift=-1)
    B = jnp.roll(y, shift=2)
    C = jnp.roll(y, shift=1)
    D = y
    return (A - B) * C - D + c

y0 = jnp.ones(11)  # values don't matter

jet(f_lorenz96, (y0,), ((y0,),)) 

which raises a KeyError:

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "example.py", line 16, in <module>
    jet(f_lorenz96, (y0,), ((y0,),))
  File "/home/kraemer/Projects/high-dimensional-ode-solver/sth/env/lib/python3.8/site-packages/jax/experimental/jet.py", line 57, in jet
    out_primals, out_terms = jet_fun(jet_subtrace(f), order).call_wrapped(primals, series)
  File "/home/kraemer/Projects/high-dimensional-ode-solver/sth/env/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "example.py", line 6, in f_lorenz96
    A = jnp.roll(y, shift=-1)
  File "/home/kraemer/Projects/high-dimensional-ode-solver/sth/env/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5409, in roll
    return _roll(a, shift, axis)
  File "/home/kraemer/Projects/high-dimensional-ode-solver/sth/env/lib/python3.8/site-packages/jax/experimental/jet.py", line 140, in process_call
    result = call_primitive.bind(f_jet, *primals_and_series, **new_params)
  File "/home/kraemer/Projects/high-dimensional-ode-solver/sth/env/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5386, in _roll
    return lax.reshape(_roll(ravel(a), shift, axis=0), a_shape)
  File "/home/kraemer/Projects/high-dimensional-ode-solver/sth/env/lib/python3.8/site-packages/jax/experimental/jet.py", line 140, in process_call
    result = call_primitive.bind(f_jet, *primals_and_series, **new_params)
  File "/home/kraemer/Projects/high-dimensional-ode-solver/sth/env/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 5400, in _roll
    a = lax.dynamic_slice_in_dim(a, a_shape[i] - x, a_shape[i], axis=i)
  File "/home/kraemer/Projects/high-dimensional-ode-solver/sth/env/lib/python3.8/site-packages/jax/experimental/jet.py", line 126, in process_primitive
    rule = jet_rules[primitive]
KeyError: dynamic_slice

I could make a PR with a little guidance (for example, if pointed to the correct primitive). Also, please let me know if this is too off-topic for this discussion and I should open a new issue instead. :)

Thanks a lot!

In case it is necessary:

Python: 3.8.10

pip freeze:

absl-py==1.0.0
flatbuffers==2.0
jax==0.2.28
jaxlib==0.1.76
numpy==1.22.2
opt-einsum==3.3.0
scipy==1.7.3
six==1.16.0
typing-extensions==4.0.1
duvenaud commented 2 years ago

That's great that you're using jets! I can help you debug. I'm not sure this will work, but my guess is that, like slice, dynamic_slice is linear. In those cases, the jet rule is to simply apply the same operation to all the elements of the series. We made a generic helper for such functions called def_linear.

Here it is applied to slice: https://github.com/google/jax/blob/main/jax/experimental/jet.py#L240

So I think you might just be able to add a line for dynamic_slice everywhere you see a line for slice.
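To sketch what such a rule looks like (untested, and `linear_jet_rule` is just an illustrative name here, not the actual helper in jet.py):

```python
import jax.numpy as jnp
from jax.experimental.jet import jet

# For a linear primitive p, the k-th Taylor coefficient of p(x) is just
# p applied to the k-th coefficient of x, so the rule maps the primitive
# over the primal and over every series term independently.
def linear_jet_rule(prim, primals_in, series_in, **params):
  x, = primals_in
  series, = series_in
  primal_out = prim.bind(x, **params)
  series_out = [prim.bind(t, **params) for t in series]
  return primal_out, series_out

# The linearity property can be checked on slicing, which already has
# such a rule: the output series terms are just the sliced input terms.
f = lambda x: x[1:3]
x0 = jnp.arange(4.0)
t1 = jnp.array([1.0, 2.0, 3.0, 4.0])
primal_out, terms_out = jet(f, (x0,), ((t1,),))
```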

pnkraemer commented 2 years ago

Great, thanks a lot for the quick reply! I will give this a try.

Would you like me to add a test (e.g. to https://github.com/google/jax/blob/7a6986c4c8fd8469bae36306efb0417b0a2f6d8c/tests/jet_test.py#L306) and make a pull request?

duvenaud commented 2 years ago

Yes, please!

pnkraemer commented 2 years ago

I think dynamic_slice can now be checked off in the list above. I cannot edit the checklist myself, which is why I'm commenting.

If you need more help, please let me know. Otherwise I will come back next time I need a jet primitive that has not been implemented yet :)

Thanks for the support!

apaszke commented 2 years ago

I checked it for you! :)

duvenaud commented 2 years ago

Thanks @pnkraemer ! This module is getting to have a real "stone soup" vibe, it's great.

pnkraemer commented 2 years ago

I've had another look, and many of the primitives above are already implemented. It probably makes sense to check them off in the todo list above. :)

More specifically (in the order of appearance):

jacobjinkelly commented 2 years ago

@pnkraemer Thanks for going through and making this list! I can confirm that the jet rule for expit is implemented and works based on these tests: https://github.com/google/jax/blob/main/tests/jet_test.py#L278-L281

dsheldon commented 1 year ago

Have you considered a generic polynomial composition routine to propagate Taylor series? It could be used for g(f(x)) whenever the higher derivatives of g(z) with respect to z are available and fill in cases where specialized routines aren't available. I implemented a basic one that runs in O(K^3) time and tested it for lgamma and digamma:

# (imports assumed by the snippets below)
from functools import partial
import jax
import jax.numpy as jnp
from jax import lax
from jax.experimental.jet import jet_rules
from scipy.special import factorial as fact

def poly_mul(f, g):
  '''Multiply polynomials, truncating to the larger degree'''
  f, g = (f, g) if len(f) >= len(g) else (g, f)  # ensure f has the larger degree
  k = len(f)
  out = [jnp.zeros_like(f[0]) for _ in range(k)]
  # naive O(k^2) convolution
  for i in range(k):
    g_lim = min(len(g), k - i)
    for j in range(g_lim):
      out[i + j] += f[i] * g[j]
  return out

def poly_compose(g, f):
  '''Polynomial composition via Horner's method (recursive)'''
  if not g:
    return [jnp.zeros_like(f[0]) for _ in f]
  res = poly_mul(f, poly_compose(g[1:], f))
  return [g[0] + res[0]] + res[1:]

def _series_to_taylor(f):
  # Prepend a zero constant term and scale the k-th derivative by 1/k!
  return [jnp.zeros_like(f[0])] + [term / fact(i + 1) for i, term in enumerate(f)]

def _taylor_to_series(f):
  # Unscale by k! and drop the constant term
  return [term * fact(i + 1) for i, term in enumerate(f[1:])]

def _compose_taylor(g, f):
  # g and f are sequences of derivatives; convert both to truncated
  # Taylor polynomials and compose them
  g = _series_to_taylor(g)
  f = _series_to_taylor(f)
  h = poly_compose(g, f)
  return _taylor_to_series(h)

def _faa_di_bruno_rule(fun, local_deriv_fun, primals_in, series_in):
  x, = primals_in
  series, = series_in
  primals_out = fun(x)
  series_local = local_deriv_fun(x, len(series))
  series_out = _compose_taylor(series_local, series) if series else []
  return primals_out, series_out

_lgamma_local_derivs = lambda x, k: [jax.scipy.special.polygamma(n, x) for n in range(k)]
jet_rules[lax.lgamma_p] = partial(_faa_di_bruno_rule, lax.lgamma, _lgamma_local_derivs)

_digamma_local_derivs = lambda x, k: [jax.scipy.special.polygamma(n, x) for n in range(1, k+1)]
jet_rules[lax.digamma_p] = partial(_faa_di_bruno_rule, lax.digamma, _digamma_local_derivs)

The implementation could probably be optimized. Also, for large K asymptotically faster algorithms for composition (e.g., Brent and Kung, 1978) might be beneficial.
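As a quick sanity check of the composition idea (a self-contained sketch in plain Python, with simplified truncation-aware helpers rather than the exact functions above): composing $g(z) = \exp(z)$ with $f(x) = x^2$ at $x = 1$ reproduces the closed-form derivatives of $h(x) = \exp(x^2)$, namely $2e, 6e, 20e$.

```python
from math import e, factorial

def poly_mul_trunc(f, g, k):
  # multiply polynomials, truncating to degree k
  out = [0.0] * (k + 1)
  for i, fi in enumerate(f):
    for j, gj in enumerate(g):
      if i + j <= k:
        out[i + j] += fi * gj
  return out

def poly_compose_trunc(g, f, k):
  # Horner's method: g(f) = g[0] + f * (g[1] + f * (...))
  res = [0.0] * (k + 1)
  for c in reversed(g):
    res = poly_mul_trunc(f, res, k)
    res[0] += c
  return res

k = 3
# derivatives of f(x) = x**2 at x = 1, and of g(z) = exp(z) at z = f(1) = 1
f_derivs = [2.0, 2.0, 0.0]   # f', f'', f'''
g_derivs = [e, e, e]         # g', g'', g'''

# derivative lists -> Taylor polynomials with a zero constant term
f_taylor = [0.0] + [d / factorial(i + 1) for i, d in enumerate(f_derivs)]
g_taylor = [0.0] + [d / factorial(i + 1) for i, d in enumerate(g_derivs)]

h_taylor = poly_compose_trunc(g_taylor, f_taylor, k)
h_derivs = [c * factorial(i + 1) for i, c in enumerate(h_taylor[1:])]
# closed form for h(x) = exp(x**2) at x = 1: [2e, 6e, 20e]
```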

jacobjinkelly commented 1 year ago

Hi @dsheldon thanks for your comment, this is awesome :) Really clean implementation and nice examples!

Feel free to make a PR for this :D

I believe @dougalm, @duvenaud, @jessebett, @mattjj worked on an implementation of FDB with some optimisations here, but we never managed to merge this into the jet API. So feel free to have a peek at that code if you're keen and incorporate some of those ideas into your PR too :)

IIRC @jessebett implemented some ideas from Brent and Kung, but I can't find the code for that unfortunately. @jessebett would you be able to provide a pointer?

dsheldon commented 1 year ago

Hi @jacobjinkelly, thanks for the response! Sure, I'd be happy to make a PR. Before I do: did I make the correct assumption about the inputs to the propagation rules? I.e., that series in the _faa_di_bruno_rule above will always be a list of arrays (or array-like objects) with the same shape?

I think the other implementation you pointed to is iterating through integer partitions, which, to my understanding, would take super-polynomial time. With polynomial composition it's possible to implement Faa di Bruno's rule in polynomial time — e.g., $O(k^3)$ for the naive algorithm above. One of the reasons for my comment was to point out this possibility.

We used one of the Brent-Kung algorithms in a paper, so I have some old code, but I'm guessing the crossover point where it will be substantially faster than Horner's rule doesn't occur until the order of differentiation is pretty large. (And for very large $k$ it's hard to get around numerical issues due to dealing with terms like $k!$ in floating point anyway.)

jacobjinkelly commented 1 year ago

Yes I think that's the correct assumption for the type of series. If you're not sure, you can always double check by using the unary_check utility in jet_test.py.

I think the other implementation you pointed to is iterating through integer partitions, which, to my understanding, would take super-polynomial time. With polynomial composition ...

Ok wow I didn't realise that, that's super interesting!

substantially faster than Horner's rule doesn't occur until the order of differentiation is pretty large ...

Agreed!

jessebett commented 1 year ago

Hi @dsheldon,

As Jacob already said, this is a very nice implementation!

It's been some time so I'm a bit fuzzy on the details, sorry in advance if I'm missing your main points.

Have you considered a generic polynomial composition routine to propagate Taylor series?

Yes! Early implementations we worked with relied exclusively on Faa Di Bruno and generic polynomial propagation. We included some notes on the relationship between the Faa Di Bruno formula and the propagation of truncated Taylor polynomials in the appendices of our workshop preprint.

As you noted, Brent and Kung (1978) suggests asymptotically faster algorithms for the general polynomial composition (following from algorithmic complexity for inverting power series). However, as described in Griewank and Walther (2008) Chapter 13.2 on Taylor Polynomial Propagation, specifically the subsection on Nonlinear Univariate Elementals, generic polynomial composition addresses the general case where the Taylor coefficients are unrelated. But in practice the coefficients for "elemental functions of interest" (e.g. primitives like exp and sin) all are solutions of linear ODEs and satisfy the identity in equation 13.7. This allows application of Brent and Kung Theorem 5.1 to derive more efficient algorithms by exploiting relationships between coefficients.

So at some point we moved from generic Faa Di Bruno based propagation rules to more optimized rules for primitives that are a consequence of this special relationship among their coefficients. You can find some examples of these rules and their complexity in tables 13.1 and 13.2 of Griewank and Walther.

To be clear, Brent and Kung discuss algorithms for efficient propagation of power series with better complexity than classical polynomial propagation, but those improvements are really only relevant at larger orders of differentiation. You already point out this is numerically tricky due to $k!$. Instead, we are appealing to a restriction from the general problem of polynomial composition to those with special structure, which is a separate result also discussed in Brent and Kung.
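To make that special structure concrete: for $v = \exp(u)$, differentiating $v' = v\,u'$ and matching Taylor coefficients gives the classical recurrence $k\,v_k = \sum_{j=1}^{k} j\,u_j\,v_{k-j}$, which propagates all $K$ coefficients in $O(K^2)$. A minimal sketch (`exp_taylor` is just an illustrative name):

```python
import math

def exp_taylor(u):
  # u: Taylor coefficients u_0..u_K of the input series.
  # Returns the Taylor coefficients of exp(u) via the O(K^2)
  # recurrence k * v_k = sum_{j=1}^{k} j * u[j] * v[k-j].
  v = [math.exp(u[0])]
  for k in range(1, len(u)):
    v.append(sum(j * u[j] * v[k - j] for j in range(1, k + 1)) / k)
  return v
```

For example, `exp_taylor([0.0, 1.0, 0.0, 0.0])` recovers the Taylor coefficients of $\exp(t)$ itself: $[1, 1, 1/2, 1/6]$.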

did I make the correct assumption about the inputs to the propagation rules? I.e., that series in the _faa_di_bruno_rule above will always be a list of arrays (or array-like objects) with the same shape?

Yes you got it. To help orient through the current code, you can see/modify the enforcement of the series input type here and the enforcement that the series have the same shape here.

I think the other implementation you pointed to is iterating through integer partitions, which, to my understanding, would take super-polynomial time.

Yes, Jacob linked to an implementation preceding our use of Faa Di Bruno rules that relies on integer partitions. Your understanding is correct, this was very non-optimal! A subsequent version that uses Faa Di Bruno exclusively is somewhere in the commit/branch history, but I'd have to do some digging to find it...

It could be used for g(f(x)) whenever the higher derivatives of g(z) with respect to z are available and fill in cases where specialized routines aren't available.

I agree, this would be great to have! At some point we discussed keeping the generic Faa Di Bruno rule as a fallback for primitives without special rules. I'm not quite sure what happened that dropped the generic rule, but I suspect we just got distracted by working out special rules for primitives we encountered in practice and the general case became increasingly rare. As Jacob said, if you're up for it a PR would be very welcome!

Thanks again for your comment and clean code!

dsheldon commented 1 year ago

OK, PR created --- hopefully done correctly; if not, please let me know!

@jessebett, thanks for your comments, I think we're on the same page regarding the options! The optimized rules (usually O(k^2)) when the function is the solution of an ODE are preferable and applicable in almost all cases. I made my comment because the notes explaining jet mention (1) the intractable direct evaluation of Faa di Bruno's formula, and (2) the efficient special case for common primitives that are solutions of a linear ODE, but they don't mention that a generic rule based on polynomial composition runs in O(k^3) (or better) time, which is slower but not disastrously slower than the specialized rules. So it seemed like such a rule could be helpful for the few primitives that weren't already covered by special cases. From your response, it sounds like you already had this in previous versions.

mattjj commented 1 year ago

@dsheldon thanks for the tips!

I remember trying polynomial composition, but I thought it didn't give the correct answers (even though I thought it should work), and I never figured out why! I think I messed up _series_to_taylor and _taylor_to_series (referring to the names in your PR). It's great if this works generically!
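For future readers, the convention those two helpers are meant to follow (as I understand it from the code above) is that a jet "series" stores derivatives while a Taylor polynomial stores factorial-scaled coefficients with a constant term, so the conversions should round-trip exactly (a minimal sketch with illustrative names):

```python
from math import factorial

# A jet "series" stores derivatives [f'(x), f''(x), ...]; the Taylor
# polynomial stores [f(x), f'(x)/1!, f''(x)/2!, ...]. These mirror
# _series_to_taylor / _taylor_to_series in the snippet above.
def series_to_taylor(series):
  return [0.0] + [d / factorial(i + 1) for i, d in enumerate(series)]

def taylor_to_series(taylor):
  return [c * factorial(i + 1) for i, c in enumerate(taylor[1:])]

round_trip = taylor_to_series(series_to_taylor([2.0, 6.0, 24.0]))
```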