patrick-kidger opened this issue 2 years ago.
I have seen something similar on other occasions. I suspect that `A` is captured in JAX's IR as a constant, and that XLA is somehow unrolling it and trying to optimise the huge blob that results. If that's true, then if you look at how the compile time scales with the size of `A`, I'd guess you'll see that it scales at least quadratically...
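The scaling check suggested above could be sketched like this. It is a minimal sketch, assuming a JAX version with the ahead-of-time `jax.jit(...).lower(...).compile()` API; the helper name `compile_seconds` and the sizes are illustrative choices. Whether the closed-over array really ends up embedded as a literal constant in the compiled program can depend on the JAX version, so treat the numbers as a probe of the hypothesis rather than a confirmation of it.

```python
import time

import jax
import jax.numpy as jnp


def compile_seconds(n: int) -> float:
    """Seconds spent tracing, lowering and compiling a jitted function that
    closes over an (n, n) array (nothing is executed)."""
    A = jax.random.normal(jax.random.PRNGKey(0), (n, n))  # captured by the closure below
    y = jnp.ones((n,), dtype=A.dtype)

    @jax.jit
    def g(y):
        return A @ y

    start = time.perf_counter()
    g.lower(y).compile()  # AOT path: compile only, no execution
    return time.perf_counter() - start


# A fresh A and g are created per size, so the in-memory jit cache is not reused.
for n in (250, 500, 1000, 2000):
    print(n, compile_seconds(n))
```

Plotting these timings against `n` (or `n**2`, the number of elements in the constant) would show whether compile time tracks the size of the embedded array.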
@mattjj does JAX know that `jac(A)` is just `A`? Even so, we still need to do an `xla_call` and incur XLA compilation overhead, right?
The compilation time is slow on all backends. Taking AD out of the compilation-time story, slow compilation does not reproduce without the large constant array, e.g. not with:
```python
import timeit

import jax
import jax.numpy as jnp
import jax.random as jrandom
from jax import lax

x = jrandom.normal(jrandom.PRNGKey(1), (10000,))

@jax.jit
def f(x):
    # Build a 10000x10000 identity from iotas inside the traced function (no large constant).
    i0 = lax.broadcasted_iota(jnp.int32, (10000, 10000), 0) + 0
    i1 = lax.broadcasted_iota(jnp.int32, (10000, 10000), 1)
    eye = (i0 == i1).astype(x.dtype)
    return jnp.dot(eye, x)

print(jax.make_jaxpr(f)(x))
print(timeit.timeit(lambda: f(x), number=1))  # compilation
```
But it does reproduce like this:
```python
import timeit
import jax
import jax.numpy as jnp
import jax.random as jrandom
from jax import lax

A = jrandom.normal(jrandom.PRNGKey(0), (10000, 10000))

@jax.jit
def f():
    i0 = lax.broadcasted_iota(jnp.int32, (10000, 10000), 0) + 0
    i1 = lax.broadcasted_iota(jnp.int32, (10000, 10000), 1)
    eye = (i0 == i1).astype(jnp.float32)
    return jnp.dot(A, eye)  # A is a large closed-over constant

print(jax.make_jaxpr(f)())
print(timeit.timeit(lambda: f(), number=1))  # compilation
```
In fact, it even reproduces like this:
```python
import timeit
import jax
import jax.random as jrandom

A = jrandom.normal(jrandom.PRNGKey(0), (10000, 10000))
y = jrandom.normal(jrandom.PRNGKey(1), (10000,))

@jax.jit
def g(_):
    2 * _  # don't let jax skip xla compilation
    return A

print(timeit.timeit(lambda: g(y), number=1))  # 6.272046981030144
```
I'm not sure what's up with this compilation time. I'll raise it as an XLA issue.
As for runtime, there is indeed a missing AD optimization here, which we've long known about but which has never seemed very practically important. I'd phrase it like this: for a matrix `A` of shape `(N, N)`, does `jax.jacfwd(lambda x: A @ x)(x)` require O(1) time to evaluate or O(N^3)? Currently the answer is O(N^3). That is, we basically evaluate `A @ jnp.eye(A.shape[0])` rather than just `A`.
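A small sketch of that point, assuming a tiny `N` so it runs instantly: the Jacobian of a purely linear map is the matrix itself, and what `jacfwd` produces matches `A @ jnp.eye(N)` because it pushes every standard basis vector through the JVP.

```python
import jax
import jax.numpy as jnp

N = 4
A = jax.random.normal(jax.random.PRNGKey(0), (N, N))
x = jax.random.normal(jax.random.PRNGKey(1), (N,))


def linear(x):
    return A @ x


# jacfwd vmaps a JVP over the N standard basis vectors, so the work done is
# equivalent to A @ eye(N): an (N, N) by (N, N) matmul, O(N^3) in total.
J = jax.jacfwd(linear)(x)

# For a linear map the Jacobian does not depend on x and is just A.
print(jnp.allclose(J, A))
print(jnp.allclose(J, A @ jnp.eye(N)))
```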
The reason I say this optimization is uninteresting is that it applies to pretty much only this exact case (of a function which consists of a single `dot`). It doesn't apply to general nonlinear functions, of course, but it doesn't even apply to general linear functions.
To effect the optimization, we could imagine having `jacfwd` create symbolic identity matrices (and maybe, slightly more generally, symbolic standard basis vectors), and then having some operations (like `dot`) know how to handle them specially. But so far no one's really asked for it. Are there really situations where we'd want to automatically discover that such special dense linearity structure exists, rather than a user just exploiting it directly? If you have a real application, I'd love to hear it!
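To make the "symbolic identity" idea concrete, here is a toy sketch outside of JAX's tracing machinery. `SymbolicEye` and `dot` are hypothetical names invented for illustration; JAX has no such API, and a real version would have to live inside `jacfwd` and the batching/JVP rules rather than in user code.

```python
import jax.numpy as jnp


class SymbolicEye:
    """Hypothetical marker for an n x n identity matrix that is never materialised."""

    def __init__(self, n: int):
        self.n = n


def dot(a, b):
    """Hypothetical dot that short-circuits when either operand is a SymbolicEye."""
    if isinstance(b, SymbolicEye):
        if a.shape[-1] != b.n:
            raise ValueError("shape mismatch")
        return a  # A @ I == A: a shape check instead of an O(N^3) matmul
    if isinstance(a, SymbolicEye):
        if b.shape[0] != a.n:
            raise ValueError("shape mismatch")
        return b  # I @ B == B
    return jnp.dot(a, b)  # ordinary dense path


A = jnp.arange(9.0).reshape(3, 3)
print(dot(A, SymbolicEye(3)) is A)  # True: the existing array is returned unchanged
```

The catch mattjj points to is visible even here: only operations that are taught about `SymbolicEye` benefit, so the marker has to survive through every primitive between `jacfwd` and the `dot`.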
Definitely glad to hear that the compilation time is being worked on; that's frequently more of an issue for JAX programs than runtime, from what I hear!
> That is, we basically evaluate `A @ jnp.eye(A.shape[0])` rather than just `A`.
Yeah, this is what I got from the jaxpr. I inspected the resulting XLA as well, and interestingly enough -- on small enough problems that it compiles in time -- this multiplication-by-overlapping-iotas gets optimised out on the GPU. Just not on the CPU.
I do think I would describe this as an interesting optimisation. For example, consider a 2-layer MLP given by `x -> A @ σ(B @ x)`. This has Jacobian `A @ diag[σ'((B @ x)_0), ..., σ'((B @ x)_n)] @ B`. Computing this Jacobian as `A @ diag[σ'((B @ x)_0), ..., σ'((B @ x)_n)] @ B @ I` represents a lot of extra work. So I think at minimum this optimisation applies to every nonlinear function that starts with a linear map, and with all the neural networks flying around the place, that's probably most of them.
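A sketch checking the MLP Jacobian formula above, assuming `σ = tanh` and small illustrative layer sizes. It compares what `jacfwd` computes (which includes the trailing `@ I`) against the closed form, where scaling the rows of `B` by `σ'(B @ x)` stands in for multiplying by the diagonal matrix.

```python
import jax
import jax.numpy as jnp

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
n_in, n_hidden, n_out = 5, 7, 3
A = jax.random.normal(k1, (n_out, n_hidden))
B = jax.random.normal(k2, (n_hidden, n_in))
x = jax.random.normal(k3, (n_in,))
sigma = jnp.tanh  # assumed activation, chosen only for illustration


def mlp(x):
    return A @ sigma(B @ x)


# Forward-mode Jacobian: effectively A @ diag(σ'(Bx)) @ B @ I.
J_auto = jax.jacfwd(mlp)(x)

# Closed form A @ diag(σ'(Bx)) @ B, without the identity and without
# building the diagonal matrix explicitly.
dsigma = jax.vmap(jax.grad(sigma))(B @ x)
J_closed = A @ (dsigma[:, None] * B)

print(J_auto.shape)  # (3, 5)
print(jnp.allclose(J_auto, J_closed, atol=1e-4))
```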
I do suspect symbolic basis vectors are probably overkill; that sounds like a lot of new abstractions just for this one use case. I feel like the best way to tackle this is to improve the compiler: either in XLA:CPU (make the current XLA:GPU optimisation backend-agnostic) or by improving the facilities around jaxpr manipulation. There's already DCE in a few places for jaxprs; maybe it's time to start adding pattern matching as well?
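As a starting point for the jaxpr pattern-matching idea, here is a sketch that walks a jaxpr (including nested sub-jaxprs such as those inside `pjit` equations) and collects primitive names. It only inspects; an actual rewrite pass would look for a `dot_general` whose operand is built from `iota`/`eq` and substitute the other operand. The helper name `primitive_names` is hypothetical, and sub-jaxprs are found by duck typing (`.eqns` on a `Jaxpr`, `.jaxpr` on a `ClosedJaxpr`) rather than through any dedicated JAX API.

```python
import jax
import jax.numpy as jnp


def primitive_names(jaxpr):
    """Recursively collect primitive names from a jaxpr and any nested sub-jaxprs."""
    names = []
    for eqn in jaxpr.eqns:
        names.append(eqn.primitive.name)
        for param in eqn.params.values():
            # Some params hold a tuple of jaxprs (e.g. cond branches).
            candidates = param if isinstance(param, (tuple, list)) else (param,)
            for p in candidates:
                sub = getattr(p, "jaxpr", p)  # ClosedJaxpr -> Jaxpr; otherwise unchanged
                if hasattr(sub, "eqns"):
                    names.extend(primitive_names(sub))
    return names


A = jnp.arange(16.0).reshape(4, 4)
x = jnp.ones(4)
closed = jax.make_jaxpr(jax.jacfwd(lambda v: A @ v))(x)
print(primitive_names(closed.jaxpr))
```

The exact list of primitives printed depends on the JAX version (for example how `jnp.eye` is traced), which is one reason a robust pattern matcher is more work than it first appears.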
As for applications: absolutely I have applications! This scenario appears when solving differential equations. The mathematics is pretty interesting, so I'll write out a bit of it, as I reckon you'll find it interesting too.
When you make a step of e.g. Euler's method on an ODE `dy(t) = f(y(t)) dt`, this is implemented as `y_{n+1} = y_n + f(y_n) (t_{n+1} - t_n)`. The interaction between `f(y_n)` and `t_{n+1} - t_n` is multiplication. But when solving an SDE `dy(t) = f(y(t)) dw(t)`, we end up writing down `y_{n+1} = y_n + f(y_n) (w_{n+1} - w_n)`, where now the analogous interaction is a matrix-vector product. If I wrote down the more general SDE `dy(t) = [f(y(t)), g(y(t))] . [dt, dw(t)]`, then the interaction is a "dot product" which is a multiplication in one element and a matrix-vector product in the other element. In general, the interaction is a bilinear map.
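A minimal sketch of one Euler step written generically over that bilinear interaction. The names `euler_step`, `ode_prod`, `sde_prod` and `combined_prod` are hypothetical and illustrative only; this is not Diffrax's API, though the `prod` argument plays the role the next paragraph describes for `AbstractTerm.prod`.

```python
import jax.numpy as jnp


def euler_step(y, vf, control, prod):
    """y_{n+1} = y_n + prod(vf(y_n), control), with `prod` bilinear."""
    return y + prod(vf(y), control)


# ODE dy = f(y) dt: control is dt, interaction is scalar multiplication.
def ode_prod(fy, dt):
    return fy * dt


# SDE dy = g(y) dw: g(y) is a (d, m) matrix, dw an (m,) increment -> matvec.
def sde_prod(gy, dw):
    return gy @ dw


# dy = [f(y), g(y)] . [dt, dw]: multiplication in one slot, matvec in the other.
def combined_prod(fg, control):
    fy, gy = fg
    dt, dw = control
    return fy * dt + gy @ dw


y = jnp.array([1.0, 2.0])
f = lambda y: -y
g = lambda y: 0.5 * jnp.eye(2)
dw = jnp.array([0.2, -0.2])

print(euler_step(y, f, 0.1, ode_prod))  # [0.9, 1.8]
print(euler_step(y, g, dw, sde_prod))   # [1.1, 1.9]
print(euler_step(y, lambda y: (f(y), g(y)), (0.1, dw), combined_prod))  # [1.0, 1.7]
```

The solver body (`euler_step`) never needs to know which interaction it is using, which is what lets the library stay generic.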
The nature of this interaction is user-specified and unknown to me (the library author), as it's a detail of the input differential equation. So at the level of the library it's an abstract bilinear map, with respect to which the diffeq-solving code is generic. In Diffrax this is `AbstractTerm.prod`.
That works 99% of the time. But there are a few cases where we do need to materialise this abstract bilinear map explicitly.
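One way to materialise such an abstract map, and the pattern that hits the slow path discussed above, is to take a Jacobian with respect to the control: `prod(vf_value, ·)` is linear in the control, so its Jacobian is its matrix representation. `materialise` is a hypothetical helper name, not a Diffrax function.

```python
import jax
import jax.numpy as jnp


def materialise(prod, vf_value, control_example):
    """Matrix representation of the linear map c -> prod(vf_value, c).

    jacfwd pushes one basis vector per control dimension through `prod`,
    so for a matvec this recomputes vf_value @ eye(m) instead of returning vf_value.
    """
    return jax.jacfwd(lambda c: prod(vf_value, c))(control_example)


gy = jnp.arange(6.0).reshape(2, 3)  # e.g. a diffusion matrix g(y) of shape (d, m)
M = materialise(lambda g, c: g @ c, gy, jnp.zeros(3))
print(jnp.allclose(M, gy))
```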
Some solvers need this when backpropagating via the backwards-in-time continuous adjoint ODE. This is discussed in this comment and the Jacobian calculation occurs here.
Meanwhile another example is when using the Itô version of Milstein's method. This is actually a very complicated algorithm when you have abstract bilinear maps and PyTrees floating around. If you're feeling masochistic then the full algorithm is here. (And very heavily commented to be as readable as possible.) In fact this very issue is actually referenced here. The use of Jacobians to materialise the abstract (bi)linear maps occurs here and here.
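For orientation, here is the textbook scalar Itô Milstein step for `dX = a(X) dt + b(X) dW`, namely `X_{n+1} = X_n + a dt + b dW + (1/2) b b' (dW^2 - dt)`. This is a simplified sketch, not the Diffrax algorithm: in the scalar case `b'(x)` is a single JVP, whereas with vector-valued states, matrix-valued diffusions and abstract bilinear maps the corresponding derivative term is where Jacobians have to be materialised.

```python
import jax
import jax.numpy as jnp


def milstein_step(x, a, b, dt, dw):
    """Scalar Itô Milstein step for dX = a(X) dt + b(X) dW."""
    # jvp with unit tangent gives b(x) and b'(x) in one forward pass.
    b_val, db_dx = jax.jvp(b, (x,), (jnp.ones_like(x),))
    return x + a(x) * dt + b_val * dw + 0.5 * b_val * db_dx * (dw**2 - dt)


# Example: geometric Brownian motion with drift 0.1 and volatility 0.2.
x1 = milstein_step(jnp.array(1.0), lambda x: 0.1 * x, lambda x: 0.2 * x, 0.01, 0.05)
print(x1)
```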
JAX is unable to obtain the Jacobian of a linear function in an efficient manner.
On CPU with JAX version 0.2.26 and jaxlib version 0.1.75 I get the printout:
So despite the fact that the matrix `A` already exists, and that all `jac` needs to do is return [a copy of] this existing matrix, it still takes two orders of magnitude too long to accomplish this task.

I'm providing this as a "benchmark MWE" -- my concern is that if even this example fails, how much else is also silently taking far too long?
(On GPU with JAX version 0.2.20 and jaxlib version 0.1.71 I don't get a printout at all -- compilation time exceeds several minutes. I've confirmed that the install is otherwise working. I don't know what's going on with that either.)
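A sketch of the kind of comparison described here, not the original MWE (whose code is not reproduced above). It assumes a reduced size `N = 500` so that it compiles quickly, mirrors the issue's `jac` name, uses `jnp.array` as the copy baseline, excludes compilation via a warm-up call, and waits on results with `block_until_ready` because JAX dispatches asynchronously. No timings are claimed; they depend on hardware and version.

```python
import timeit

import jax
import jax.numpy as jnp

N = 500  # assumed reduced size; the issue uses 10000
A = jax.random.normal(jax.random.PRNGKey(0), (N, N))
y = jax.random.normal(jax.random.PRNGKey(1), (N,))

jac = jax.jit(jax.jacfwd(lambda x: A @ x))


def bench(fn, number=10):
    """Average seconds per call, excluding the first (compiling) call."""
    fn().block_until_ready()  # warm-up: triggers compilation
    return timeit.timeit(lambda: fn().block_until_ready(), number=number) / number


t_jac = bench(lambda: jac(y))
t_copy = bench(lambda: jnp.array(A))  # baseline: copy the existing matrix
print(f"jacfwd: {t_jac:.2e} s/call, copy: {t_copy:.2e} s/call")
```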
Various notes:

- On GPU the compilation happens in a reasonable amount of time if the size `10000` is reduced to e.g. `10`. I'm a little surprised: shouldn't compilation time just be O(1) in the size?
- The CPU compilation time of 5.86 seconds seems surprisingly long as well.
- For context, I'm doing the above procedure -- using `jax.jacfwd` to obtain the matrix representation of a linear function -- on every step of a differential equation solver, so the above issue is hit repeatedly.
- The benchmark version uses `jnp.array` to perform a copy, for fairness. (I don't think the benchmark does anything clever -- e.g. copy elision via copy-on-write -- as otherwise it would be even faster, so I think this is a fair comparison.)
- Here are the results of `jax.make_jaxpr(jac)(y)`: