TonyZhou729 opened this issue 7 months ago (status: open).
Just a note: I re-ran this example with jax==0.4.34, and the issue persists.
Sorry I missed this originally. First of all, when you run these kinds of microbenchmarks, be sure to follow the recommendations at FAQ: Benchmarking JAX code. In particular, you should wrap the computation of interest in jax.block_until_ready to ensure you're measuring computation time rather than just dispatch time:
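import time
import jax

for i in range(5):
    s = time.time()
    # block_until_ready waits until the asynchronously dispatched computation finishes
    B = jax.block_until_ready(main())
    print(time.time() - s)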
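As a minimal sketch of the same idea, here is a small timing helper that can measure either completed computation or dispatch alone. The names `time_call` and `example` are hypothetical, and `example` is only a stand-in workload for the issue's `main()`. It uses `time.perf_counter`, which is intended for measuring short intervals, and keeps the first (compiling) call out of the measurements with an explicit warm-up.

```python
import time

import jax
import jax.numpy as jnp


def time_call(fn, *args, n_runs=5, block=True):
    """Return wall-clock times (seconds) for n_runs calls of fn(*args).

    With block=True the timer stops only after the device has finished the
    computation; with block=False only the asynchronous dispatch is timed.
    """
    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        out = fn(*args)
        if block:
            out = jax.block_until_ready(out)
        times.append(time.perf_counter() - start)
    return times


@jax.jit
def example(x):
    # hypothetical workload standing in for the issue's main()
    return jnp.tanh(x @ x.T).sum()


x = jnp.ones((512, 512))
example(x).block_until_ready()  # warm-up call: triggers JIT compilation
blocked = time_call(example, x)
dispatch_only = time_call(example, x, block=False)
print(min(blocked), min(dispatch_only))
```

Comparing the two lists usually shows why unblocked timings can be misleading: the dispatch-only numbers can return before the work is actually done.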
Still, even with this I'm seeing the same general behavior you are. It looks like adding the cond leads the XLA compiler to choose different fusions, which in turn change the performance characteristics of the compiled computation. You can see this by using the Ahead-of-time compilation tools to output the compiled HLO:
print(main.lower().compile().as_text())
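A minimal sketch of how one might compare the compiled HLO of two variants side by side, assuming hypothetical functions `step_direct` and `step_cond` (not from the issue) that perform the same work with and without an always-true `lax.cond`. The fusion counter is only a rough text heuristic, not an XLA API; diffing the two HLO strings directly is the more reliable way to inspect what changed.

```python
import jax
import jax.numpy as jnp
from jax import lax


def work(x):
    # hypothetical per-step computation
    return jnp.sin(x) * 2.0 + 1.0


def step_direct(x, i):
    # call the function directly; i is unused here
    return work(x)


def step_cond(x, i):
    # same work, wrapped in a cond whose runtime predicate is always true
    return lax.cond(i > -1, work, lambda y: y, x)


def compiled_hlo(fn, *args):
    """Return the optimized HLO text XLA produced for jax.jit(fn) on args."""
    return jax.jit(fn).lower(*args).compile().as_text()


def count_fusions(hlo_text):
    # rough heuristic: count HLO instruction lines that call a fusion
    return sum(1 for line in hlo_text.splitlines() if " fusion(" in line)


x = jnp.linspace(0.0, 1.0, 256)
i = jnp.int32(3)
hlo_direct = compiled_hlo(step_direct, x, i)
hlo_cond = compiled_hlo(step_cond, x, i)
print(count_fusions(hlo_direct), count_fusions(hlo_cond))
```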
For case 1, the output is this:
Description
Hi,
We noticed a strange speed-up when a trivial lax.cond statement is used to call a function instead of calling the function directly.
In the reproduction of the issue below, we apply JIT to a main() function that contains a lax.scan() loop. In each scan step, we wrap the function call in a lax.cond() whose predicate is that the loop index i (which runs over Ny steps, from 0 to Ny - 1) is greater than -1. This predicate is always true, yet this seemingly unnecessary cond somehow produces a speed-up.
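The original reproduction code is not included here, so the following is only a minimal sketch of the two variants described above, assuming a hypothetical per-step function `step_fn` and a hypothetical `NY`. One body calls `step_fn` directly; the other routes the call through `lax.cond` with the always-true predicate `i > -1`. Both must produce the same result, since the false branch is never taken.

```python
import jax
import jax.numpy as jnp
from jax import lax

NY = 1000  # hypothetical number of scan steps


def step_fn(x):
    # hypothetical per-step work standing in for the issue's function
    return jnp.sin(x) * 0.5 + 1.0


def body_direct(carry, i):
    # variant without cond: call the function directly
    return step_fn(carry), None


def body_cond(carry, i):
    # variant with cond: i runs from 0 to NY - 1, so i > -1 is always true
    new = lax.cond(i > -1, step_fn, lambda x: x, carry)
    return new, None


@jax.jit
def main_direct(x0):
    out, _ = lax.scan(body_direct, x0, jnp.arange(NY))
    return out


@jax.jit
def main_cond(x0):
    out, _ = lax.scan(body_cond, x0, jnp.arange(NY))
    return out


x0 = jnp.ones(128)
print(jnp.max(jnp.abs(main_direct(x0) - main_cond(x0))))
```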
When using case 1 in loop_in_main() and calling main() 5 times, we observe runtimes of (in seconds):
Switching to case 2, we see:
In both cases, the first run is slower because it includes JIT compilation. We checked that this speed-up scales with Ny, the number of steps in lax.scan. In our own code, which does more computation in each step, the speed-up is even more significant.
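To check how the gap scales with Ny, a sweep along these lines could be used. This is a minimal sketch with hypothetical names (`step_fn`, `run`, `best_time`) and arbitrary sizes, not the issue's code. `ny` and `use_cond` are marked static so each combination is compiled separately, and each measurement excludes compilation via a warm-up call and blocks on the result.

```python
import functools
import time

import jax
import jax.numpy as jnp
from jax import lax


def step_fn(x):
    # hypothetical per-step work
    return jnp.cos(x) * 0.5


@functools.partial(jax.jit, static_argnums=(1, 2))
def run(x0, ny, use_cond):
    def body(carry, i):
        if use_cond:
            # always-true predicate, as described in the issue
            carry = lax.cond(i > -1, step_fn, lambda c: c, carry)
        else:
            carry = step_fn(carry)
        return carry, None

    out, _ = lax.scan(body, x0, jnp.arange(ny))
    return out


def best_time(fn, *args, n_runs=5):
    jax.block_until_ready(fn(*args))  # warm-up: compile for these static args
    best = float("inf")
    for _ in range(n_runs):
        start = time.perf_counter()
        jax.block_until_ready(fn(*args))
        best = min(best, time.perf_counter() - start)
    return best


x0 = jnp.ones(1024)
results = {}
for ny in (100, 1000, 10000):
    # (time without cond, time with cond) in seconds
    results[ny] = (best_time(run, x0, ny, False), best_time(run, x0, ny, True))
    print(ny, *results[ny])
```

Taking the minimum over several runs reduces the influence of one-off noise; whether a gap appears at all will depend on the backend and the per-step workload.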
Thank you in advance for your help and comments! Tony
System info (python version, jaxlib version, accelerator, etc.)