tomhennigan opened 4 years ago:
Another related thing that surprised one of my colleagues is the fact that `%timeit jit(jit(f))(x)` is slower than `%timeit jit(f)(x)`.
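For reference, a minimal IPython/Colab repro of the comparison (a sketch; exact numbers will vary by machine, and `block_until_ready()` just makes the asynchronous dispatch measurable):

```python
import jax.numpy as jnp
from jax import jit

x = jnp.ones([500, 500])
f = lambda x: jnp.mean(jnp.dot(x, x.T))

%timeit jit(f)(x).block_until_ready()       # hits the internal compilation cache on every call
%timeit jit(jit(f))(x).block_until_ready()  # the outer jit misses its cache on every call
```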
Putting a `@memoize` decorator on `jax.api.jit` and `jax.api.grad` results in less surprising timings:
```python
In [3]: x = jnp.ones([500,500])
   ...: f = lambda x: jnp.mean(jnp.dot(x, x.T))
   ...:
   ...: %timeit -n10 -r3 grad(jit(f))(x).block_until_ready()
   ...: %timeit -n10 -r3 jit(grad(f))(x).block_until_ready()
13.2 ms ± 8.21 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
5.66 ms ± 1.84 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
```
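Something along these lines (a hypothetical sketch of the decorator, not the actual patch):

```python
import functools

def memoize(transform):
    # Hypothetical sketch: key the cache on the function object plus the
    # transformation options, so that e.g. a memoized jit(f) returns the
    # identical wrapper on repeated calls.
    cache = {}
    @functools.wraps(transform)
    def memoized(fun, *args, **kwargs):
        # Unhashable options (e.g. a list static_argnums) break this key,
        # which is where the test failures mentioned below come from.
        key = (fun, args, tuple(sorted(kwargs.items())))
        if key not in cache:
            cache[key] = transform(fun, *args, **kwargs)
        return cache[key]
    return memoized
```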
After inserting the `@memoize` there are some test failures from `static_args` being passed as an (unhashable) list to `jit`, and a few other things, though I expect they wouldn't be too difficult to fix.
If we're going to make this change for `jit` and `grad` we might as well do it for all of the function transformations. As an alternative, we could add a note to the docs to explain the current behaviour, and that you shouldn't compose JAX transformations inside a loop. However I'd be getting a bit worried that we're going to confuse users by saying "some stuff is cached/memoized and some stuff isn't". It might be a lot clearer to just be able to say "JAX memoizes everything it can"...
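For example, here is the pitfall such a docs note would warn about (a sketch illustrating which compositions currently cache and which don't):

```python
from jax import grad, jit
import jax.numpy as jnp

f = lambda x: jnp.mean(jnp.dot(x, x.T))
x = jnp.ones([500, 500])

for _ in range(10):
    jit(f)(x)        # fine: the compiled computation is cached on f's identity
    jit(grad(f))(x)  # slow: grad(f) is a fresh object each iteration,
                     # so jit re-traces and recompiles every time
```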
@j-towns can you unpack why `jit(jit(f))(x)` is slower?
This does sound like a reasonable idea...
> @j-towns can you unpack why `jit(jit(f))(x)` is slower?
`jit` wraps `f` using `linear_util.wrap`, then passes it down into `jax.interpreters.xla`, where it ends up being passed to `xla._xla_callable`, which has a `linear_util.cache` decorator. `linear_util.cache` uses the function identity and the transformations applied to it as the cache key, so if you do

```python
jit(f)(x)
jit(f)(x)
```

the second `jit` call will use the memoized result of the first internally, because `xla._xla_callable` has the cache decorator. The result of the second `jit(f)` will not, however, be identical to the first, because at the top level `jit` re-wraps `f`, i.e. it returns a new function object even when it's internally re-using the compiled version of `f`.
Therefore

```python
In [7]: jit(f) is jit(f)
Out[7]: False
```
That's why if you do

```python
jit(jit(f))(x)
jit(jit(f))(x)
```

then when `xla._xla_callable` is called within the outer `jit` of the second `jit(jit(f))`, there isn't a cache hit...
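A toy model of an identity-keyed cache (purely illustrative, nothing like the real internals) makes the failure mode visible:

```python
compiled = {}

def toy_jit(fun):
    # Cache keyed on the identity of `fun`, in the spirit of linear_util.cache.
    def wrapped(*args):
        if fun not in compiled:
            print("compiling", fun)
            compiled[fun] = fun  # stand-in for the compiled executable
        return compiled[fun](*args)
    return wrapped

f = lambda x: x + 1
toy_jit(f)(1); toy_jit(f)(1)  # "compiling" printed once: same f both times
toy_jit(toy_jit(f))(1)        # inner f is cached, but the outer lookup...
toy_jit(toy_jit(f))(1)        # ...misses again: each toy_jit(f) is a new object
```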
I think whatever approach we end up with, we should probably write some notes about the behaviour in the docs somewhere, cos this seems to be a common pitfall for new JAXers.
> However I'd be getting a bit worried that we're going to confuse users by saying "some stuff is cached/memoized and some stuff isn't". It might be a lot clearer to just be able to say "JAX memoizes everything it can"
Yep, I think it'd be a good idea to cache all function transformations where it's possible and makes sense. I got tripped up by this here (which originated from this issue): I was used to writing `jit(f)(x)` and so I assumed that I could write `jit(grad(f))(x)`, but the latter requires recompilation every time (obvious in hindsight, knowing that `grad` doesn't cache).
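The identity check makes that concrete:

```python
from jax import grad

f = lambda x: x ** 2.0
print(grad(f) is grad(f))  # False: each grad call returns a new function object,
                           # so jit(grad(f)) keys its cache on a different object every time
```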
Have there been any updates on this? I am working with an example which is ~30x slower, with half of the time in `trace_to_subjaxpr_nounits` and half of that in a `toposort` each iteration. Being able to cache more of these computations would be awesome!
A JAX user I am working with was confused by the performance of the following in Colab. We can fix this by caching the transformed function, but it caused some confusion as to why some transformations cache the result per function and others do not:
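(The original Colab snippet isn't reproduced above; the following is a hypothetical reconstruction of the pattern being discussed.)

```python
from jax import grad, jit
import jax.numpy as jnp

f = lambda x: jnp.sum(x ** 2)
x = jnp.ones(1000)

# Surprisingly slow: re-traces and recompiles on every call,
# because each grad(f) is a new function object.
for _ in range(10):
    jit(grad(f))(x)

# The fix: cache the transformed function once, then reuse it.
df = jit(grad(f))
for _ in range(10):
    df(x)
```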
Would it be reasonable to cache the higher-order function returned by `grad(f)` inside JAX (aligned with `jit`)?
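In the meantime, a user-side workaround might look like this (a sketch; `cached_grad` is a hypothetical helper, not part of JAX):

```python
import functools
import jax

@functools.lru_cache(maxsize=None)
def cached_grad(fun):
    # Memoize grad per function object so repeated calls return the same
    # wrapper, letting jit's identity-keyed cache hit.
    return jax.grad(fun)

f = lambda x: x ** 2.0
assert cached_grad(f) is cached_grad(f)  # stable identity
df = jax.jit(cached_grad(f))             # compiled once, reused thereafter
```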