grjzwaan opened this issue 2 years ago (status: Open)
It seems related to the size of the array 🤔. If I increase `N = 4 * 1024`, then:
# scan_compiled(a)  <-- running it once beforehand doesn't help
%timeit scan_compiled(a)
95.5 µs ± 810 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
`.lower(...).compile()` uses a less optimized dispatch code path than the regular `jit` path, which we've optimized heavily.
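If the gap comes from per-call dispatch rather than from the compiled code itself, it should behave like a roughly fixed cost per call: large relative to the work for tiny inputs, smaller relative to it for big ones. A minimal sketch for checking this on your own machine, assuming a hypothetical `step` function in place of the original (unshown) one. Each path is warmed up first, and every call waits on `block_until_ready()` so that compilation and JAX's asynchronous dispatch don't skew the numbers. Timings depend on hardware and JAX version, so no expected output is given.

```python
import timeit

import jax
import jax.numpy as jnp


def step(x):
    # Hypothetical stand-in workload: the original function is not shown.
    return jnp.cumsum(x) * 2.0


def time_per_call(fn, x, number=1000):
    """Average wall-clock seconds per call, waiting for the result each time."""
    fn(x).block_until_ready()  # warm up: trace/compile outside the timed loop
    total = timeit.timeit(lambda: fn(x).block_until_ready(), number=number)
    return total / number


jitted = jax.jit(step)
results = {}
for n in (8, 4 * 1024):
    x = jnp.arange(n, dtype=jnp.float32)
    # AOT path: an executable fixed to this input's shape and dtype.
    compiled = jax.jit(step).lower(x).compile()
    results[n] = {
        "jit": time_per_call(jitted, x),
        "aot": time_per_call(compiled, x),
    }

for n, t in results.items():
    print(f"N={n}: jit {t['jit'] * 1e6:.1f} µs, "
          f"lower().compile() {t['aot'] * 1e6:.1f} µs")
```

If the difference between the two columns stays roughly constant while the absolute times grow with `N`, that is consistent with a dispatch cost rather than slower compiled code.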
Thanks!
Out of interest: does this mean that the resulting compiled function is the same, but the interface for calling it is different (i.e., less optimized)?
Yes, exactly. The code that checks that the arguments you're passing are consistent with what the executable expects is less optimized in the AOT API, which is significantly less mature than the regular `jit` path.
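One way to see that the computation is the same while the calling interface differs: the `Compiled` object only accepts the exact argument shapes and dtypes it was lowered for and checks them on each call, whereas the `jit` wrapper retraces and recompiles for a new shape. A sketch with a hypothetical `scale` function; because the exact exception type raised on a mismatch could differ between JAX versions, the code catches both `TypeError` and `ValueError`.

```python
import jax
import jax.numpy as jnp


def scale(x):
    # Hypothetical example function.
    return x * 3.0


x = jnp.ones(4, dtype=jnp.float32)

jitted = jax.jit(scale)
compiled = jax.jit(scale).lower(x).compile()

# Same computation, so the same result for the input it was compiled for.
same_result = bool(jnp.array_equal(jitted(x), compiled(x)))

# The jit wrapper retraces and recompiles when it sees a new shape...
y = jnp.ones(8, dtype=jnp.float32)
jit_handles_new_shape = jitted(y).shape == (8,)

# ...whereas the AOT executable rejects arguments that don't match the
# shapes/dtypes it was compiled for.
try:
    compiled(y)
    aot_rejects_new_shape = False
except (TypeError, ValueError):
    aot_rejects_new_shape = True

print(same_result, jit_handles_new_shape, aot_rejects_new_shape)
```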
I'm seeing a big difference in computation time between `jax.jit()` and `jax.jit().lower().compile()`. Jitting and then executing is faster than precompiling, although I expected the performance to be the same. Perhaps I overlooked something, but I cannot explain the difference.
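The issue's `scan` and `a` are not shown. A minimal reconstruction of the two setups being compared, assuming a hypothetical `scan` that computes a running sum with `jax.lax.scan` and a hypothetical `N = 1024`: `jax.jit(scan)` compiles lazily on its first call, while `.lower(a)` returns a `jax.stages.Lowered` object and `.compile()` turns it into a `jax.stages.Compiled` executable specialized to `a`'s shape and dtype.

```python
import jax
import jax.numpy as jnp
from jax.stages import Compiled, Lowered


def scan(a):
    # Hypothetical stand-in for the issue's `scan` (its body is not shown):
    # a running sum written with jax.lax.scan.
    def body(carry, x):
        carry = carry + x
        return carry, carry

    _, out = jax.lax.scan(body, jnp.zeros((), a.dtype), a)
    return out


N = 1024  # hypothetical size
a = jnp.arange(N, dtype=jnp.float32)

# Path 1: jit, which traces and compiles on the first call.
scan_jit = jax.jit(scan)

# Path 2: ahead-of-time staging: lower for `a`'s shape/dtype, then compile.
lowered = jax.jit(scan).lower(a)
scan_compiled = lowered.compile()

out_jit = scan_jit(a)
out_aot = scan_compiled(a)
print(isinstance(lowered, Lowered), isinstance(scan_compiled, Compiled))
print(bool(jnp.array_equal(out_jit, out_aot)))
```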
This code is run on the CPU:
✔️ Time the function that is not jitted:
%timeit scan(a)
✔️ Jitting speeds it up
%timeit jax.jit(scan)(a)
✔️ First jitting then executing it once:
❌ Precompiling then executing?
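For comparing the last two steps, it helps to separate the one-time costs (tracing, lowering, compiling) from the steady-state cost per call, and to wait on results with `block_until_ready()`: JAX dispatches asynchronously, so a bare call can return before the computation finishes. A rough sketch with a hypothetical `scan` and single measurements; for real numbers, repeat the steady-state calls (e.g. with `%timeit`) instead of timing each once.

```python
import time

import jax
import jax.numpy as jnp


def scan(a):
    # Hypothetical stand-in for the issue's function.
    return jnp.cumsum(jnp.sin(a))


a = jnp.linspace(0.0, 1.0, 4 * 1024, dtype=jnp.float32)


def seconds(thunk):
    """Run `thunk` once and return (elapsed seconds, its result)."""
    start = time.perf_counter()
    result = thunk()
    return time.perf_counter() - start, result


# One-time AOT costs, measured separately.
t_lower, lowered = seconds(lambda: jax.jit(scan).lower(a))
t_compile, scan_compiled = seconds(lowered.compile)

scan_jit = jax.jit(scan)
# First jit call: includes any tracing/compilation JAX hasn't already cached.
t_first_jit, _ = seconds(lambda: scan_jit(a).block_until_ready())
# Later calls reuse the cached executable (jit) or the AOT executable.
t_steady_jit, _ = seconds(lambda: scan_jit(a).block_until_ready())
t_steady_aot, _ = seconds(lambda: scan_compiled(a).block_until_ready())

timings = {
    "lower": t_lower,
    "compile": t_compile,
    "first jit call": t_first_jit,
    "steady jit call": t_steady_jit,
    "steady AOT call": t_steady_aot,
}
for name, t in timings.items():
    print(f"{name:>16}: {t * 1e6:10.1f} µs")
```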