jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Slow autodiff at runtime -- unable to performantly linearise a linear function #9215

Open patrick-kidger opened 2 years ago

patrick-kidger commented 2 years ago

JAX is unable to obtain the Jacobian of a linear function in an efficient manner.

import jax
import jax.numpy as jnp
import jax.random as jrandom
import timeit

A = jrandom.normal(jrandom.PRNGKey(0), (10000, 10000))
y = jrandom.normal(jrandom.PRNGKey(1), (10000,))
def f(x):
    return A @ x
jac = jax.jit(jax.jacfwd(f))
print(timeit.timeit(lambda: jac(y), number=1))  # compilation
print(timeit.timeit(lambda: jac(y).block_until_ready(), number=1))  # runtime
print(timeit.timeit(lambda: jnp.array(A).block_until_ready(), number=1))  # ideal runtime
assert jnp.all(jac(y) == A)

On CPU with JAX version 0.2.26 and jaxlib version 0.1.75 I get the printout:

5.859835999999632
4.756322400000499  
0.08406309999918449

So despite the fact that the matrix A already exists, and that all jac needs to do is return [a copy of] this existing matrix, it still takes two orders of magnitude too long to accomplish this task.

I'm providing this as a "benchmark MWE" -- my concern is that if even this example fails, how much else is also silently taking far too long?

(On GPU with JAX version 0.2.20 and jaxlib version 0.1.71 I don't get a printout at all -- compilation time exceeds several minutes. I've confirmed that the install is otherwise working. I don't know what's going on with that either.)


Various notes -- here is the jaxpr for jac (note the identity matrix being materialised via two iotas and an eq, then fed into a dot_general with A):

{ lambda a:f32[10000,10000]; b:f32[10000]. let
    c:f32[10000,10000] = xla_call[
      call_jaxpr={ lambda ; d:f32[10000,10000] e:f32[10000]. let
          f:i32[10000,10000] = iota[
            dimension=0
            dtype=int32
            shape=(10000, 10000)
          ]
          g:i32[10000,10000] = add f 0
          h:i32[10000,10000] = iota[
            dimension=1
            dtype=int32
            shape=(10000, 10000)
          ]
          i:bool[10000,10000] = eq g h
          j:f32[10000,10000] = convert_element_type[
            new_dtype=float32
            weak_type=False
          ] i
          k:f32[10000,10000] = slice[
            limit_indices=(10000, 10000)
            start_indices=(0, 0)
            strides=None
          ] j
          _:f32[10000] = dot_general[
            dimension_numbers=(((1,), (0,)), ((), ()))
            precision=None
            preferred_element_type=None
          ] d e
          l:f32[10000,10000] = dot_general[
            dimension_numbers=(((1,), (1,)), ((), ()))
            precision=None
            preferred_element_type=None
          ] d k
          m:f32[10000,10000] = slice[
            limit_indices=(10000, 10000)
            start_indices=(0, 0)
            strides=None
          ] l
        in (m,) }
      name=jacfun
    ] a b
  in (c,) }
PhilipVinc commented 2 years ago

I have seen something similar on other occasions. I suspect that A is captured in JAX's IR as a constant, and that somehow XLA is unrolling it and trying to optimise the huge blob that results. If that's true, I'd guess that if you look at how the compile time scales with the size of A, you'll see that it scales at least quadratically...
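
A quick sketch of the scaling check I have in mind (the sizes here are just illustrative):

import timeit
import jax
import jax.random as jrandom

for n in (1000, 2000, 4000):
    A = jrandom.normal(jrandom.PRNGKey(0), (n, n))
    y = jrandom.normal(jrandom.PRNGKey(1), (n,))
    jac = jax.jit(jax.jacfwd(lambda x: A @ x))
    # The first call includes compilation, which is the cost we care about here.
    print(n, timeit.timeit(lambda: jac(y).block_until_ready(), number=1))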

zhangqiaorjc commented 2 years ago

@mattjj does JAX know that jac(y) is just A? Even so, we still need to do an xla_call and incur XLA compilation overhead, right?

mattjj commented 2 years ago

The compilation time is slow on all backends. Taking AD out of the picture: the slow compilation does not reproduce without the large constant array, e.g. not with:

import timeit

import jax
import jax.numpy as jnp
import jax.random as jrandom
from jax import lax

x = jrandom.normal(jrandom.PRNGKey(1), (10000,))

@jax.jit
def f(x):
  i0 = lax.broadcasted_iota(jnp.int32, (10000, 10000), 0) + 0
  i1 = lax.broadcasted_iota(jnp.int32, (10000, 10000), 1)
  eye = (i0 == i1).astype(x.dtype)
  return jnp.dot(eye, x)

print(jax.make_jaxpr(f)(x))
print(timeit.timeit(lambda: f(x), number=1))  # compilation

But it does reproduce like this:

A = jrandom.normal(jrandom.PRNGKey(0), (10000, 10000))

@jax.jit
def f():
  i0 = lax.broadcasted_iota(jnp.int32, (10000, 10000), 0) + 0
  i1 = lax.broadcasted_iota(jnp.int32, (10000, 10000), 1)
  eye = (i0 == i1).astype(jnp.float32)
  return jnp.dot(A, eye)

print(jax.make_jaxpr(f)())
print(timeit.timeit(lambda: f(), number=1))  # compilation

In fact, it even reproduces like this:

A = jrandom.normal(jrandom.PRNGKey(0), (10000, 10000))
y = jrandom.normal(jrandom.PRNGKey(1), (10000,))

@jax.jit
def g(_):
  2 * _  # don't let jax skip xla compilation
  return A
print(timeit.timeit(lambda: g(y), number=1))  # 6.272046981030144

I'm not sure what's up with this compilation time. I'll raise it as an XLA issue.


As for runtime, indeed there is a missing AD optimization here which we've long known about but has never seemed very practically important. I'd phrase it like this: for a matrix A of shape (N, N), does jax.jacfwd(lambda x: A @ x)(x) require O(1) time to evaluate or O(N^3)? Currently the answer is O(N^3). That is, we basically evaluate A @ jnp.eye(A.shape[0]) rather than just A.
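
Concretely, jacfwd is essentially vmap-of-jvp over the standard basis, so a rough sketch of what happens for this f (jacfwd_via_basis is just an illustrative name, not the real implementation) is:

import jax
import jax.numpy as jnp

def jacfwd_via_basis(f, x):
    # Push each standard basis vector through the JVP; for f(x) = A @ x each
    # JVP is A @ e_i, so this is effectively computing A @ jnp.eye(N).
    basis = jnp.eye(x.size, dtype=x.dtype)
    push = lambda v: jax.jvp(f, (x,), (v,))[1]
    return jax.vmap(push)(basis).T  # stack the columns A @ e_i into the Jacobian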

The reason I say this optimization is uninteresting is that it only applies to pretty much just this exact case (of a function which consists of a single dot). It doesn't apply to general nonlinear functions, of course, but it doesn't even apply to general linear functions.

To effect the optimization we could imagine having jacfwd create symbolic identity matrices, and maybe slightly more generally symbolic standard basis vectors, and then having some operations (like dot) know how to handle them specially. But so far no one's really asked for it. Are there really situations where we'd want to automatically discover that such special dense linearity structure exists, rather than a user just exploiting it directly? If you have a real application, I'd love to hear it!

patrick-kidger commented 2 years ago

Definitely glad to hear that the compilation time is being worked on; that's frequently more of an issue for JAX programs than runtime, from what I hear!


That is, we basically evaluate A @ jnp.eye(A.shape[0]) rather than just A.

Yeah, this is what I got from the jaxpr. I inspected the resulting XLA as well, and interestingly enough -- on problems small enough that they compile in a reasonable time -- this multiplication-by-overlapping-iotas gets optimised out on the GPU. Just not on the CPU.
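
(For reference, a sketch of one way to do that inspection, reusing f and y from the MWE above; the exact API may differ across JAX versions:)

# jax.xla_computation shows the pre-optimization HLO; to see the optimized
# module, dumping with XLA_FLAGS="--xla_dump_to=/tmp/xla_dump" is one option.
import jax
print(jax.xla_computation(jax.jacfwd(f))(y).as_hlo_text())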

I do think I would describe this as an interesting optimisation. For example, consider a 2-layer MLP given by x -> A @ σ(B @ x). This has Jacobian A @ diag[σ'((B @ x)_0), ..., σ'((B @ x)_n)] @ B. Computing this Jacobian as A @ diag[σ'((B @ x)_0), ..., σ'((B @ x)_n)] @ B @ I represents a lot of extra work. So I think at minimum this optimisation applies to every nonlinear function that starts with a linear map, and with all the neural networks flying around the place, that's probably most of them.
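
To make the saving concrete, here is a hedged sketch (σ = tanh and the helper name are just illustrative) of computing that Jacobian directly, with no identity matrix anywhere:

import jax
import jax.numpy as jnp

def mlp_jacobian(A, B, x, sigma=jnp.tanh):
    # Jacobian of x -> A @ sigma(B @ x), i.e. A @ diag(sigma'(B @ x)) @ B,
    # without ever materialising an identity or diagonal matrix.
    s_prime = jax.vmap(jax.grad(sigma))(B @ x)  # elementwise sigma'
    return (A * s_prime) @ B                    # scale A's columns by sigma', then multiply by B

# Should match jax.jacfwd(lambda x: A @ sigma(B @ x))(x), at much lower cost.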

I do suspect symbolic basis vectors are probably overkill. That sounds like a lot of new abstractions for just this one use case. I feel like the best way to tackle this is to improve the compiler -- either in XLA:CPU (make the current XLA:GPU optimisation backend-agnostic) or by improving the facilities around jaxpr manipulation. There's already DCE in a few places for jaxprs; maybe it's time to start adding pattern matching as well?
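
As a very rough illustration of the kind of pattern matching I mean (not a real JAX pass; just walking the equations of a jaxpr like the one above):

import jax

def has_materialised_identity(closed_jaxpr):
    # Flag the iota / eq / convert_element_type chain that jacfwd emits when it
    # materialises the identity matrix (see the jaxpr earlier in this issue).
    names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
    return "iota" in names and "eq" in names and "convert_element_type" in names

# e.g. has_materialised_identity(jax.make_jaxpr(jax.jacfwd(f))(y)) should flag the MWE above.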


As for applications: absolutely I have applications! This scenario appears when solving differential equations. The mathematics is pretty interesting, so I'll write out a bit of it as I reckon you'll find it interesting too.

When you take a step of e.g. Euler's method on an ODE dy(t) = f(y(t)) dt, this is implemented as y_{n+1} = y_n + f(y_n) (t_{n+1} - t_n). The interaction between f(y_n) and t_{n+1} - t_n is a scalar multiplication. But when solving an SDE dy(t) = f(y(t)) dw(t), we end up writing down y_{n+1} = y_n + f(y_n) (w_{n+1} - w_n), where now the analogous interaction is a matrix-vector product. If I wrote down the more general SDE dy(t) = [f(y(t)), g(y(t))] . [dt, dw(t)], then the interaction is a "dot product" that is a multiplication in one element and a matrix-vector product in the other. In general, the interaction is a bilinear map.

The nature of this interaction is user-specified and unknown to me (the library author), as it's a detail of the input differential equation. So at the level of the library it's an abstract bilinear map, which the diffeq-solving code is generic with respect to. In Diffrax this is AbstractTerm.prod.
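
To illustrate the shape of that abstraction (names here are purely illustrative, not Diffrax's actual API): the solver step is written against a user-supplied prod, and only prod knows whether the interaction is a scalar multiplication or a matrix-vector product.

def euler_step(y, vector_field, control_increment, prod):
    # ODE:  prod = lambda f, dt: f * dt   (vector times scalar)
    # SDE:  prod = lambda g, dw: g @ dw   (matrix times vector)
    return y + prod(vector_field(y), control_increment)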

That works 99% of the time. But there are a few cases where we do need to materialise this abstract bilinear map explicitly.

Some solvers need this when backpropagating via the backwards-in-time continuous adjoint ODE. This is discussed in this comment and the Jacobian calculation occurs here.

Meanwhile, another example is the Itô version of Milstein's method. This is actually a very complicated algorithm once you have abstract bilinear maps and PyTrees floating around. If you're feeling masochistic then the full algorithm is here. (And very heavily commented, to be as readable as possible.) In fact this very issue is actually referenced here. The use of Jacobians to materialise the abstract (bi)linear maps occurs here and here.