google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Precompile is slower than regular jit #9503

Open grjzwaan opened 2 years ago

grjzwaan commented 2 years ago

I'm seeing a big difference in execution time between jax.jit() and jax.jit().lower().compile(): jitting and then executing is faster than precompiling. I expected the performance to be the same.

Perhaps I overlooked something, but I cannot explain the difference.

This code is run on the CPU:

import jax
import jax.numpy as np

# Fake function
def scan(a):

    b0 = 5.

    def update(b_prev, a):
        b_next = np.log(a) + b_prev
        return b_next, a * np.sqrt(b_prev)

    b_last, b = jax.lax.scan(update, b0, a)

    return b

N = 1000
dummy = np.ones((N,))
a = np.ones((N,)) * 2.   # 'real' data

✔️ Timing the function without jit:

%timeit scan(a)

32.5 ms ± 340 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

✔️ Jitting speeds it up:

%timeit jax.jit(scan)(a)

38.4 µs ± 381 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

✔️ First jitting, then executing it once before timing:

scan_jitted = jax.jit(scan)
scan_jitted(a)
%timeit scan_jitted(a)

12.7 µs ± 55.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

❌ Precompiling then executing?

scan_compiled = jax.jit(scan).lower(a).compile()
# scan_compiled(a)  <-- a warm-up call first doesn't help
%timeit scan_compiled(a)

35.5 µs ± 506 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

grjzwaan commented 2 years ago

It seems related to the size of the array 🤔. If I increase N to 4 * 1024, then:

# scan_compiled(a)  <-- a warm-up call first doesn't help
%timeit scan_compiled(a)

95.5 µs ± 810 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
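For anyone who wants to reproduce the comparison outside IPython, here is a minimal script sketch. It assumes a plain `timeit` harness in place of `%timeit`, uses `jnp` for `jax.numpy` (the issue aliases it as `np`), and calls `.block_until_ready()` on every result, because JAX may dispatch work asynchronously and a timing loop could otherwise measure only the dispatch. The helpers `time_per_call` and `compare` are hypothetical names for this sketch, not part of JAX.

```python
import timeit

import jax
import jax.numpy as jnp


def scan(a):
    # Same toy recurrence as in the issue.
    b0 = 5.0

    def update(b_prev, x):
        b_next = jnp.log(x) + b_prev
        return b_next, x * jnp.sqrt(b_prev)

    _, b = jax.lax.scan(update, b0, a)
    return b


def time_per_call(fn, a, number=1000):
    """Hypothetical helper: mean wall-clock seconds per call of fn(a)."""
    fn(a).block_until_ready()  # warm-up call, excluded from the timing
    total = timeit.timeit(lambda: fn(a).block_until_ready(), number=number)
    return total / number


def compare(n, number=1000):
    """Hypothetical helper: time the jit path and the lower().compile() path."""
    a = jnp.ones((n,)) * 2.0
    scan_jitted = jax.jit(scan)
    scan_compiled = jax.jit(scan).lower(a).compile()
    return {
        "jit": time_per_call(scan_jitted, a, number),
        "aot": time_per_call(scan_compiled, a, number),
    }


if __name__ == "__main__":
    # The two array sizes used in this thread.
    for n in (1000, 4 * 1024):
        t = compare(n)
        print(f"N={n}: jit {t['jit'] * 1e6:.1f} us, "
              f"lower().compile() {t['aot'] * 1e6:.1f} us")
```

Absolute numbers depend on the hardware and the JAX version. The interesting quantity is the gap between the two paths, which the maintainers attribute below to per-call dispatch overhead rather than to the compiled code itself.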

hawkinsp commented 2 years ago

.lower(...).compile() uses a less optimized dispatch code path than the regular jit path, which we've optimized heavily.

grjzwaan commented 2 years ago

Thanks!

Out of interest: does this mean that the resulting compiled function is the same, and only the interface for calling it is different (i.e., less optimized)?

apaszke commented 2 years ago

Yes, exactly. The code that checks that the arguments you pass are consistent with what the executable expects is less optimized in the AOT API, which is significantly less mature than the regular jit path.
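Both points in this answer can be sketched in a few lines, assuming a recent JAX. First, the two paths should return the same values, since they run executables compiled from the same lowered computation. Second, the AOT executable checks its arguments against the shapes and dtypes it was lowered for and rejects anything else, whereas `jax.jit` would trace and compile a new version. The exact exception type is an assumption (a `TypeError` in the versions we have looked at), so the sketch also accepts `ValueError`. Here `jnp` is `jax.numpy` and `np` is plain NumPy, unlike the aliases in the issue.

```python
import jax
import jax.numpy as jnp
import numpy as np


def scan(a):
    # Same toy recurrence as in the issue.
    b0 = 5.0

    def update(b_prev, x):
        b_next = jnp.log(x) + b_prev
        return b_next, x * jnp.sqrt(b_prev)

    _, b = jax.lax.scan(update, b0, a)
    return b


a = jnp.ones((1000,)) * 2.0
scan_jitted = jax.jit(scan)
scan_compiled = jax.jit(scan).lower(a).compile()

# Same computation on both call paths, so the outputs should match.
same_result = np.allclose(scan_jitted(a), scan_compiled(a))

# The precompiled executable only accepts inputs matching the avals it was
# lowered for; a different shape is rejected instead of triggering a recompile.
try:
    scan_compiled(jnp.ones((1001,)) * 2.0)
    rejected_new_shape = False
except (TypeError, ValueError):  # assumed exception types
    rejected_new_shape = True
```

This argument check runs on every call through the AOT path, which is where the extra per-call overhead described above comes from.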