jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

custom transposition #9129

Open froystig opened 2 years ago

froystig commented 2 years ago

Support custom transposition, i.e. the ability to register a custom "transpose rule" for any given function.

A function f marked for custom transposition and its transpose t take two arguments each. Their signatures are related as:

f :: r -> a -> b
t :: r -> b -> a

where the r argument represents residuals. Typically, we would use the notation a -o b to mean "a structurally linear function a -> b" and say that if t is a transpose of f then:

f :: r -> a -o b
t :: r -> b -o a

But we won't actually require structural linearity here, only "numerical" linearity. That is a key point of the customization: JAX can only derive transposes of structurally linear functions on its own, so a custom rule is exactly what makes a merely numerically linear function transposable.
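For instance (a hypothetical illustration, not part of this proposal), a function can be numerically linear while being built from nonlinear operations, in which case linear_transpose has no way to derive a transpose automatically:

import jax
import jax.numpy as jnp

def ident(z):
  # numerically the identity map (for z != 0), hence linear, but its
  # trace contains nonlinear operations (a product of two traced values,
  # and a division), so it is not structurally linear
  return z * z / z

# jax.linear_transpose(ident, 1.)(1.) would fail during transposition;
# custom transposition would instead let us register the transpose
# (here, the identity itself) by hand.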

Example usage would look something like:

from functools import partial
from jax import custom_transpose, linear_transpose, numpy as jnp

def transpose_unary(f, x_example):
  def transposed(y):
    x, = linear_transpose(f, x_example)(y)
    return x
  return transposed

T = lambda f: transpose_unary(f, 0.)

# -- one degree of custom transposition

@custom_transpose
def f(_, z):
  return 2. * z

@f.def_transpose
def ft(_, z):
  return 3. * z

f = partial(f, ())
print(f(1.))              # 2.
print(T(f)(1.))           # 3.
print(T(T(f))(1.))        # 3.
print(T(T(T(f)))(1.))     # 3.
print(T(T(T(T(f))))(1.))  # 3. ...

# -- two degrees of custom transposition

@custom_transpose
def f(_, z):
  return 2. * z

@f.def_transpose
@custom_transpose
def ft(_, z):
  return 3. * z

@ft.def_transpose
def ftt(_, z):
  return 7. * z

f = partial(f, ())
print(f(1.))              # 2.
print(T(f)(1.))           # 3.
print(T(T(f))(1.))        # 7.
print(T(T(T(f)))(1.))     # 7.
print(T(T(T(T(f))))(1.))  # 7. ...

# -- symmetrically registered transposes (arbitrary degree)

@custom_transpose
def f(_, z):
  return 2. * z

@custom_transpose
def g(_, z):
  return 3. * z

f.def_transpose(g)
g.def_transpose(f)

f = partial(f, ())
print(f(1.))              # 2.
print(T(f)(1.))           # 3.
print(T(T(f))(1.))        # 2.
print(T(T(T(f)))(1.))     # 3.
print(T(T(T(T(f))))(1.))  # 2. ...

# -- recursively registered transposes (arbitrary degree)

@custom_transpose
def f(c, z):
  return c * z

@f.def_transpose
def ft(c, z):
  return f(c + 1., z)

g = partial(f, 1.)
print(g(1.))              # 1.
print(T(g)(1.))           # 2.
print(T(T(g))(1.))        # 3.
print(T(T(T(g)))(1.))     # 4.
print(T(T(T(T(g))))(1.))  # 5. ...

This is likely to subsume linear_call from #5781.
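For instance, a linear_call-style helper might be expressible along the following lines (a sketch; linear_call2 is a hypothetical name, not the #5781 API):

def linear_call2(f, f_transpose, residual_args, linear_args):
  # register f_transpose as the transpose rule for f, then apply
  g = custom_transpose(f)
  g.def_transpose(f_transpose)
  return g(residual_args, linear_args)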

Application: fewer primitives

Among other things, lax.custom_root and lax.custom_linear_solve may no longer need to be primitives. When used together with jax.custom_jvp, this would enable customization of both forward- and reverse-mode AD (custom_vjp currently disallows forward-mode). To illustrate this with a simplified linear solve function:

import jax
from jax import custom_jvp
long_iterative_solve = jnp.linalg.solve  # stand-in for an iterative solver, for illustration

@custom_jvp
@custom_transpose
def solve(A, b):
  """find x such that A @ x = b"""
  return long_iterative_solve(A, b)

@solve.def_transpose
def solve_transpose(A, tb):
  return solve(A.T, tb)

@solve.defjvp
def solve_jvp(primals, tangents):
  A, b = primals
  tA, tb = tangents
  x = solve(A, b)
  tx = solve(A, tb - tA @ x)    # automatically transposed for VJP
  return x, tx

JAX derives VJPs from JVPs by linearization and transposition. In this case, the system will pick up the custom transpose of solve and the derived VJP will also carry out a solve (against the transposed design matrix). More generally, forward-mode (JVP) behavior is altered by custom_jvp as usual, and reverse-mode (VJP) behavior is altered by any custom transposes in the dependence path of the tangent output (tx).
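Hypothetical usage, under the proposed semantics, with a concrete A and b:

A = jnp.array([[3., 0.], [0., 5.]])
b = jnp.ones(2)

# forward mode: handled entirely by the custom JVP
print(jax.jvp(lambda b: solve(A, b), (b,), (b,)))

# reverse mode: derived by linearizing the custom JVP and transposing;
# transposing tx picks up solve_transpose, i.e. a solve against A.T
print(jax.grad(lambda b: solve(A, b).sum())(b))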

Application: upgrading custom VJPs

We could imagine re-implementing our jax.custom_vjp functionality on top of custom_jvp and custom_transpose. A sketch:

def disallow_jvp(*_):
  raise TypeError("can't apply forward-mode AD (jvp) to a custom_vjp function.")

def custom_vjp2(fun, fwd, bwd):
  tan_fn = custom_transpose(disallow_jvp)
  tan_fn.def_transpose(bwd)
  fun = custom_jvp(fun)

  @fun.defjvp
  def jvp(primals, tangents):
    outs, residuals = fwd(*primals)
    return outs, tan_fn(residuals, tangents)

  return fun
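Hypothetical usage (fn, fwd, and bwd are illustrative names):

def fn(x):  return jnp.sin(x)
def fwd(x): return jnp.sin(x), jnp.cos(x)   # (output, residuals)
def bwd(res, ct): return (res * ct,)        # output cotangent -> input cotangents

fn2 = custom_vjp2(fn, fwd, bwd)
print(jax.grad(fn2)(0.5))   # cos(0.5), computed by bwd
# jax.jvp(fn2, (0.5,), (1.,)) would raise: forward mode is disallowed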

This recovers the current custom_vjp behavior. It also opens up possibilities for forward-mode AD support in the presence of custom_vjp—something the current implementation doesn't support. There are at least two options for defining forward-mode behavior...

Upgrade option 1: derive a JVP by transposing the VJP

This requires that the bwd function supplied to custom_vjp be structurally linear (in its cotangent argument).

def custom_vjp3(fun, fwd, bwd):
  def bwd_transposed(res, tan):
    # transpose bwd, which is structurally linear in its cotangent
    # argument, to obtain a tangent map. Nb: this is a sketch --
    # linear_transpose really wants an example output cotangent as its
    # second argument, for which tan stands in here.
    out_tan, = linear_transpose(partial(bwd, res), tan)(tan)
    return out_tan
  bwd_t = custom_transpose(bwd_transposed)
  bwd_t.def_transpose(bwd)
  fun = custom_jvp(fun)

  @fun.defjvp
  def jvp(primals, tangents):
    outs, residuals = fwd(*primals)
    return outs, bwd_t(residuals, tangents)

  return fun
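For instance, a typical bwd (for fn(x) = sin(x) with residual res = cos(x); illustrative names) is structurally linear in its cotangent argument, so linear_transpose applies to it:

def bwd(res, ct):
  return (res * ct,)  # multiplication by a constant: linear in ct

res = jnp.cos(0.5)
tan_map = linear_transpose(partial(bwd, res), 0.)
print(tan_map((1.,)))  # (cos(0.5),) -- bwd transposed back into a tangent map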

Upgrade option 2: derive a JVP by linearizing the "forward" function

This requires that the fwd function supplied to custom_vjp be (forward-mode) automatically differentiable.

def custom_vjp4(fun, fwd, bwd):
  fun = custom_jvp(fun)
  fwd_out = lambda *primals: fwd(*primals)[0]
  fwd_res = lambda *primals: fwd(*primals)[1]

  @fun.defjvp
  def jvp(primals, tangents):
    outs, tan_fn = jax.linearize(fwd_out, *primals)
    residuals = fwd_res(*primals)
    tan_fn_aug = custom_transpose(lambda res, tan: tan_fn(*tan))
    tan_fn_aug.def_transpose(bwd)
    return outs, tan_fn_aug(residuals, tangents)

  return fun

(This repeats some work by interpreting fwd twice: once to compute primals and to linearize, and then again to grab the user-defined custom residuals. We could avoid this by using our "auxiliary output" machinery in AD to obtain those residuals concurrently with the linearization process, e.g. if we exposed the has_aux option of ad.linearize up through jax.linearize.)
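With such an option, the JVP could be computed in a single pass over fwd. A sketch, assuming jax.linearize accepts has_aux and treats fwd's second output (the residuals) as auxiliary:

def custom_vjp5(fun, fwd, bwd):
  fun = custom_jvp(fun)

  @fun.defjvp
  def jvp(primals, tangents):
    # one trace of fwd yields the primal outputs, the tangent map, and
    # the user-defined residuals all at once
    outs, tan_fn, residuals = jax.linearize(fwd, *primals, has_aux=True)
    tan_fn_aug = custom_transpose(lambda res, tan: tan_fn(*tan))
    tan_fn_aug.def_transpose(bwd)
    return outs, tan_fn_aug(residuals, tangents)

  return fun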

On linearity assumptions

Although a custom_transpose'd function and its transpose rule need not be structurally linear, JAX may assume that they are nonetheless mathematically linear. This assumption is somewhat inevitable, and we should highlight it in documentation. An example consequence of this assumption is that if one writes:

f = custom_transpose(f)
f.def_transpose(g)

for some functions f and g, then JAX's AD system may consider the following a valid JVP for f:

def f_jvp(primals, tangents):
  r, x = primals
  _, tx = tangents
  out_primals = f(r, x)
  out_tangents = f(r, tx)   # the residual argument r stays primal
  return out_primals, out_tangents

This is hopefully an unsurprising requirement, since the notion of transposition only applies to linear maps to begin with. That said, we could imagine applications that might willingly break the linearity requirement. An example is the following (arguably) natural implementation of "gradient clipping", as is somewhat common in neural network training:

@custom_transpose
def clip(threshold, x):
  return jnp.clip(x, -threshold, threshold)  # nb: not linear!
clip.def_transpose(clip)

@custom_jvp
def clip_tangent(threshold, x):
  return x

@clip_tangent.defjvp
def clip_tangent_jvp(primals, tangents):
  threshold, x = primals
  _, tx = tangents
  y = clip_tangent(threshold, x)
  ty = clip(threshold, tx)
  return y, ty

If clip is itself ever automatically differentiated, the caller might be surprised that its derivative comes out as threshold past the threshold, rather than 0, even though the derivative of clip was never explicitly customized.
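Concretely, under the assumed-linear JVP described above, one might see (hypothetical):

g = partial(clip, 0.5)
print(g(2.0))            # 0.5, ordinary clipping
print(jax.grad(g)(2.0))  # 0.5: the transpose rule clips the incoming
                         # cotangent 1.0 to the threshold, rather than
                         # producing the true derivative 0.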

shoyer commented 2 years ago

Very nice, I'm excited about this!

Note that custom_jvp already suffices for lax.custom_root (supposing that the linear solve is transposable), but not for lax.custom_linear_solve (which needs transposition).

mattjj commented 2 years ago

Another major motivator is odeint: right now it has a custom_vjp rule, precluding forward-mode differentiation. But we know how to decompose ODE differentiation into forward mode, partial evaluation, and transposition. We just need a way to register custom transposes for things like linear ODE solves!

froystig commented 2 years ago

I've added a section to the issue description that covers how custom transposition, once realized, might allow for a re-implementation of jax.custom_vjp, either as it behaves today or with forward-mode AD support in some form.

froystig commented 2 years ago

I've added another section highlighting the linearity requirement on the target function.