jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Function cache for `grad` #2095

Open tomhennigan opened 4 years ago

tomhennigan commented 4 years ago

A JAX user I am working with was confused by the performance of the following in Colab:

from jax import grad, jit
import jax.numpy as jnp
x = jnp.ones([5000,5000])
f = lambda x: jnp.mean(jnp.dot(x, x.T))

%timeit -n10 -r3 grad(jit(f))(x).block_until_ready()
# 10 loops, best of 3: 40 ms per loop

%timeit -n10 -r3 jit(grad(f))(x).block_until_ready()
# 10 loops, best of 3: 935 ms per loop

We can fix this by caching the transformed function ourselves, but the user was confused about why some transformations cache their result per function and others do not:

jgf = jit(grad(f))
%timeit -n10 -r3 jgf(x).block_until_ready()
# 10 loops, best of 3: 26 ms per loop
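To illustrate why hoisting the transformation out of the timed call helps, here is a toy pure-Python model rather than JAX's actual implementation. It assumes, as described later in this thread, that jit caches compilations keyed on the identity of the function it receives, while grad returns a fresh function object on every call. The names `toy_jit`, `toy_grad` and `COMPILE_COUNT` are hypothetical, and `toy_grad` uses a finite difference purely as a stand-in for autodiff.

```python
COMPILE_COUNT = 0  # hypothetical counter standing in for XLA compilations
_compiled = {}     # toy compile cache keyed on the function object itself


def toy_jit(fun):
    """Hypothetical jit: 'compiles' fun once per distinct function object."""
    def wrapped(*args):
        global COMPILE_COUNT
        if fun not in _compiled:      # cache miss: simulate a compilation
            COMPILE_COUNT += 1
            _compiled[fun] = fun      # stand-in for a compiled executable
        return _compiled[fun](*args)
    return wrapped


def toy_grad(fun, eps=1e-6):
    """Hypothetical grad: central finite difference, NEW closure per call."""
    def grad_fun(x):
        return (fun(x + eps) - fun(x - eps)) / (2 * eps)
    return grad_fun


f = lambda x: x ** 2

# jit(grad(f)) rebuilt on every iteration: a new grad(f) object each time,
# so every lookup in the identity-keyed cache misses.
for _ in range(10):
    toy_jit(toy_grad(f))(3.0)
jit_of_grad = COMPILE_COUNT

# grad(jit(f)) rebuilt on every iteration: the cache key is f itself,
# which never changes, so only the first lookup misses.
COMPILE_COUNT = 0
for _ in range(10):
    toy_grad(toy_jit(f))(3.0)
grad_of_jit = COMPILE_COUNT

# jit(grad(f)) built once outside the loop: one wrapper, one cache entry.
COMPILE_COUNT = 0
jgf = toy_jit(toy_grad(f))
for _ in range(10):
    jgf(3.0)
hoisted = COMPILE_COUNT

print(jit_of_grad, grad_of_jit, hoisted)  # 10 1 1
```

In this model the asymmetry between the two orderings comes entirely from which object ends up as the cache key.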

Would it be reasonable for JAX to cache the higher-order function returned by grad(f) internally (consistent with jit)?

j-towns commented 4 years ago

Another related thing that surprised one of my colleagues is the fact that

%timeit jit(jit(f))(x)

is slower than

%timeit jit(f)(x)

Putting a @memoize decorator on jax.api.jit and jax.api.grad results in less surprising timings:

In [3]: x = jnp.ones([500,500])
   ...: f = lambda x: jnp.mean(jnp.dot(x, x.T))
   ...:
   ...: %timeit -n10 -r3 grad(jit(f))(x).block_until_ready()
   ...: %timeit -n10 -r3 jit(grad(f))(x).block_until_ready()
13.2 ms ± 8.21 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
5.66 ms ± 1.84 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
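A minimal sketch of what memoizing a transformation could look like, assuming the transformation takes a single hashable function argument. This is not the `@memoize` used in the experiment above; `toy_grad` and `memo_grad` are hypothetical names, and `functools.lru_cache` is swapped in as the memoizer.

```python
import functools


def toy_grad(fun, eps=1e-6):
    """Hypothetical stand-in for grad: central finite difference of fun.
    Returns a new closure on every call, like an unmemoized transformation."""
    def grad_fun(x):
        return (fun(x + eps) - fun(x - eps)) / (2 * eps)
    return grad_fun


# Memoize the transformation on its (hashable) function argument. Bounding
# maxsize limits how many user functions the cache keeps alive.
memo_grad = functools.lru_cache(maxsize=256)(toy_grad)

f = lambda x: x ** 3

print(toy_grad(f) is toy_grad(f))    # False: a fresh wrapper every time
print(memo_grad(f) is memo_grad(f))  # True: one cached wrapper per function
print(round(memo_grad(f)(2.0), 4))   # 12.0 (d/dx x**3 at x = 2)
```

One design trade-off: any cache keyed on function objects holds strong references to them, so an unbounded cache can keep short-lived lambdas alive indefinitely.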

After inserting @memoize, some tests fail, for example because static_args are passed to jit as an (unhashable) list, along with a few other issues, though I expect these wouldn't be too difficult to fix.
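The unhashable-list failure comes from the cache key itself: a memoizer such as `functools.lru_cache` must hash every argument. A sketch of one possible fix, assuming static arguments can be normalized to a tuple before lookup; `_make_wrapper`, `naive_jit` and `memo_jit` are hypothetical names, not JAX APIs.

```python
import functools


def _make_wrapper(fun, static_args):
    """Hypothetical stand-in for building a jitted wrapper of fun.
    static_args only matters here as part of the cache key."""
    def wrapped(*args):
        return fun(*args)
    return wrapped


# Naive memoization: static_args is used in the cache key exactly as passed.
naive_jit = functools.lru_cache(maxsize=None)(_make_wrapper)

# Normalizing memoization: turn an int or a list into a tuple before lookup.
_cached_make_wrapper = functools.lru_cache(maxsize=None)(_make_wrapper)


def memo_jit(fun, static_args=()):
    if isinstance(static_args, int):
        static_args = (static_args,)
    return _cached_make_wrapper(fun, tuple(static_args))


f = lambda x, n: x * n

try:
    naive_jit(f, [1])  # a list inside the key cannot be hashed
except TypeError:
    print("naive memoization raised TypeError")

print(memo_jit(f, [1]) is memo_jit(f, (1,)))  # True: same cache entry
print(memo_jit(f, 1)(2, 3))                   # 6
```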

If we're going to make this change for jit and grad, we might as well do it for all of the function transformations. As an alternative, we could add a note to the docs explaining the current behaviour and advising that you shouldn't compose JAX transformations inside a loop. However I'd be getting a bit worried that we're going to confuse users by saying "some stuff is cached/memoized and some stuff isn't". It might be a lot clearer to just be able to say "JAX memoizes everything it can"...

mattjj commented 4 years ago

@j-towns can you unpack why jit(jit(f))(x) is slower?

mattjj commented 4 years ago

This does sound like a reasonable idea...

j-towns commented 4 years ago

> @j-towns can you unpack why jit(jit(f))(x) is slower?

jit wraps f using linear_util.wrap, then passes it down into jax.interpreters.xla, where it ends up being passed to xla._xla_callable, which has a linear_util.cache decorator. linear_util.cache uses the function identity and the transformations applied to it as the cache key, so if you do

jit(f)(x)
jit(f)(x)

the second jit call will internally reuse the memoized result of the first, because xla._xla_callable has the cache decorator. However, the result of the second jit(f) will not be identical to the first, because at the top level jit re-wraps f; that is, it returns a new function object even when it is internally reusing the compiled version of f.

Therefore

In [7]: jit(f) is jit(f)
Out[7]: False
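The two-level behaviour described above (a cached inner compile step behind a top-level wrapper that is rebuilt every time) can be mimicked with a toy pure-Python model. This is a sketch under assumptions, not JAX's code: `toy_xla_callable` is a hypothetical stand-in for xla._xla_callable, using `functools.lru_cache` keyed on the function object plus a tuple of transformations, and `toy_jit` is a hypothetical top-level wrapper.

```python
import functools

COMPILATIONS = []  # one entry per simulated compilation


@functools.lru_cache(maxsize=None)
def toy_xla_callable(fun, transforms):
    """Hypothetical stand-in for xla._xla_callable: cached on the function
    object plus the (hashable) tuple of transformations applied to it."""
    COMPILATIONS.append((fun, transforms))
    return fun  # stand-in for a compiled executable


def toy_jit(fun):
    """Hypothetical top-level jit: builds a NEW wrapper object on every call."""
    def wrapped(*args):
        return toy_xla_callable(fun, ())(*args)  # inner lookup is cached
    return wrapped


f = lambda x: x + 1

print(toy_jit(f)(1), toy_jit(f)(1))  # 2 2
print(len(COMPILATIONS))             # 1: the second call hit the cache
print(toy_jit(f) is toy_jit(f))      # False: distinct wrapper objects
```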

That's why if you do

jit(jit(f))(x)
jit(jit(f))(x)

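then, when xla._xla_callable is called within the outer jit of the second jit(jit(f)), there is no cache hit: the inner jit(f) is a new function object, so the outer cache key differs from the one used the first time.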
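A toy pure-Python model of the jit(jit(f)) miss, and of how memoizing the top-level wrapper would remove it. It assumes, as explained in this thread, a compile cache keyed on the identity of the function passed in; `toy_compile`, `toy_jit` and `memo_jit` are hypothetical names, and `functools.lru_cache` stands in for both caches.

```python
import functools

COMPILATIONS = []  # one entry per simulated (expensive) compilation


@functools.lru_cache(maxsize=None)
def toy_compile(fun):
    """Hypothetical compile cache keyed on the identity of fun."""
    COMPILATIONS.append(fun)
    return fun


def toy_jit(fun):
    """Hypothetical jit: returns a fresh wrapper object on every call."""
    def wrapped(*args):
        return toy_compile(fun)(*args)
    return wrapped


f = lambda x: x * 2

# Unmemoized: each jit(jit(f)) builds a new inner wrapper, so the outer
# lookup misses every time (the inner lookup on f hits after the first).
for _ in range(5):
    toy_jit(toy_jit(f))(1)
unmemoized = len(COMPILATIONS)  # 1 (for f) + 5 (one per new inner wrapper)

# Memoized top-level jit: toy_jit(f) returns the same wrapper every time,
# so both lookups hit after the first iteration.
toy_compile.cache_clear()
COMPILATIONS.clear()
memo_jit = functools.lru_cache(maxsize=None)(toy_jit)
for _ in range(5):
    memo_jit(memo_jit(f))(1)
memoized = len(COMPILATIONS)    # 1 (for f) + 1 (for the single inner wrapper)

print(unmemoized, memoized)  # 6 2
```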

j-towns commented 4 years ago

I think whatever approach we end up with, we should probably write some notes about the behaviour in the docs somewhere, because this seems to be a common pitfall for new JAXers.

josephrocca commented 3 years ago

> However I'd be getting a bit worried that we're going to confuse users by saying "some stuff is cached/memoized and some stuff isn't". It might be a lot clearer to just be able to say "JAX memoizes everything it can"

Yep, I think it'd be a good idea to cache all function transformations where that is possible and makes sense. I got tripped up by this here (which originated from this issue): I was used to writing jit(f)(x), so I assumed that I could write jit(grad(f))(x), but the latter requires recompilation every time (obvious in hindsight, knowing that grad doesn't cache).

lukemetz commented 2 years ago

Have there been any updates on this? I am working with an example that is ~30x slower, with half of the time on each iteration spent in trace_to_subjaxpr_nounits and half of that in a toposort. Being able to cache more of these computations would be awesome!