nshepperd opened this issue 1 year ago
for _ in range(n): ...
Keep in mind pure Python loops get statically unrolled, so what XLA receives as input is a "concatenation" of n bodies of the loop above. This section of the JAX Sharp Bits page has more context: Python control flow + JIT. See also this section on JAX control flow primitives that are built on top of XLA-native constructs: Structured control flow primitives. The entire page is worth reading.
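A minimal illustration of that unrolling (a hypothetical function, not the code from this issue): printing the jaxpr of a traced Python loop shows one copy of the loop body per iteration, and that is exactly what gets handed to XLA.

import jax
import jax.numpy as jnp

def f(x, n=4):
    for _ in range(n):  # unrolled at trace time, not at run time
        x = jnp.sin(x) + 1.0
    return x

# The printed jaxpr contains n copies of the sin/add pair.
print(jax.make_jaxpr(lambda x: f(x, n=4))(jnp.ones(3)))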
Hi @nshepperd – I think this is working more-or-less as expected. As @andportnoy mentioned, Python control flow is statically unrolled in JAX programs, and in general XLA compilation cost is roughly quadratic in the unrolled size of the program. Because of this, a key recommendation for using JAX effectively is to avoid large Python loops within transformations like JIT. So, for example, if you're looping over operations on array values, you can often re-express your computation in terms of vmap or other vectorized expressions. If you're looping over steps in a pipeline, you might use fori_loop. If you're looping over e.g. expensive training steps (where the runtime of the loop body is much longer than the typical operation dispatch time), then a useful pattern is to JIT-compile just the loop body and iterate outside JIT over calls to that function.
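A minimal sketch of that last pattern (with a hypothetical loop body, not code from this thread): compile the body once and drive the loop from Python, so the compiled program size stays constant no matter how many iterations you run. An XLA-native loop via fori_loop achieves the same thing while keeping the loop inside JIT.

import jax
import jax.numpy as jnp

@jax.jit
def step(x):                      # hypothetical loop body
    return jnp.sin(x) * 2.0 + 1.0

x = jnp.ones((1024,))
for _ in range(1000):             # Python loop outside JIT: no unrolling
    x = step(x)

# Alternative: keep the loop inside JIT, but as a single XLA while-loop
# rather than n unrolled copies of the body.
@jax.jit
def run(x, n):
    return jax.lax.fori_loop(0, n, lambda i, v: jnp.sin(v) * 2.0 + 1.0, x)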
Hope that helps!
Yes, I'm aware of vmap/scan/etc. The loop being unrolled is intentional, as it is a straightforward way to build a computation graph with a given number of nodes. I don't believe compilation runtime should be cubic for any program, since that can easily make large model evaluations infeasible.
I also object to compilation runtimes being quadratic, but I accept that improving the runtime to O(n log n) is most likely quite a bit harder.
Well, I figured out why this is taking so long, with a bit of stack trace sampling. GpuInstructionFusion is extremely inefficient on this case. In particular, InstructionFusion iterates over all ops (which is expected), looking for fusible ops, but when checking fusibility for each op it calls xla::gpu::SharedMemoryUsage, which is itself O(n^2) in the size of any HloOpcode::kFusion op encountered (and that size grows over time toward the full size of the program in this case, where everything can be fused because it's all elementwise ops).
The reason xla::gpu::SharedMemoryUsage is quadratic is that it enumerates the instructions inside the fused computation op and invokes FindNonTrivialHero for each of them, which itself does a BFS over those same instructions. This seems very easily fixable just by memoizing the BFS within FindNonTrivialHero for this case.
(Gpu)InstructionFusion as a whole seems like it should be doable in O(n) by computing SharedMemoryUsage more carefully and updating it incrementally where necessary.
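A schematic toy in plain Python (not XLA code, just the cost pattern described above): a pass that recomputes a whole-fusion property from scratch for every candidate op it considers is quadratic in the fusion size, while maintaining the same property incrementally keeps the pass linear.

SHARED_MEMORY_LIMIT = 48 * 1024   # stand-in for the per-fusion budget

def fuse_quadratic(op_costs):
    fusion = []
    for cost in op_costs:
        # Recompute usage by walking the entire fusion for each candidate:
        # O(k) per candidate, O(n^2) overall as the fusion grows.
        usage = sum(fusion)
        if usage + cost <= SHARED_MEMORY_LIMIT:
            fusion.append(cost)
    return fusion

def fuse_linear(op_costs):
    fusion, usage = [], 0
    for cost in op_costs:
        # Maintain usage incrementally: O(1) per candidate, O(n) overall.
        if usage + cost <= SHARED_MEMORY_LIMIT:
            fusion.append(cost)
            usage += cost
    return fusion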
Thanks for digging in: for what it's worth, all of the XLA compilation-related code is maintained in https://github.com/openxla/xla. You might have some luck raising an issue there.
I agree this is certainly an XLA bug. O(n^2) algorithms in compilation are not acceptable.
SharedMemoryUsage was specifically written with a memoization cache in mind, cf. https://github.com/openxla/xla/blob/main/xla/service/gpu/gpu_fusible.cc#L731-L735.
I wonder: is the cache not triggering, or does a new fusion invalidate the search?
@olegshyshkov WDYT?
We didn't use the cache for SharedMemoryUsage in PriorityFusion until recently. I fixed it in https://github.com/openxla/xla/commit/5b0e626589b08de741c14c2568cf1b942ace55ba. I can't say anything about GpuInstructionFusion; it shouldn't be used anymore.
@nshepperd could you verify that the bug is fixed now?
Description
See the attached test case. Compiling a relatively simple function such as the one below:
jax.jit(f).lower(x,y).compile()
takes time that increases roughly proportionally to the third power of n. Lowering without compiling is fast. With an x = _optimization_barrier(x) inserted in the loop, the time complexity appears instead to be quadratic, which I think suggests this is an issue with xla (poorly optimized hlo passes)?
I don't think this is just an academic problem, as I have had compiles simply fail to terminate (within the time allowed by my patience) before, which is unsurprising when these runtimes are extrapolated to n=10k or more. In those cases I always just adjusted my code until it was able to mysteriously compile again.
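For reference, a minimal sketch of the kind of function involved (an assumption about the attached test case, which isn't reproduced here): a chain of elementwise ops built by an unrolled Python loop, compiled at a few sizes to observe how compile time scales with n.

import time
import jax
import jax.numpy as jnp

def make_f(n):
    def f(x, y):
        for _ in range(n):        # statically unrolled: n ops in the HLO
            x = x * y + jnp.sin(x)
        return x
    return f

x = jnp.ones((16,))
y = jnp.ones((16,))
for n in (100, 200, 400, 800):
    t0 = time.perf_counter()
    jax.jit(make_f(n)).lower(x, y).compile()
    print(n, time.perf_counter() - t0)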
I think compile times with quadratic or worse dependence on program size should probably be considered a bug, seeing how people frequently use models with very large numbers of ops. In my case this seems to be a problem, since it is possibly making programs generated by an experimental rematerialization strategy I'm working on infeasible to compile.
What jax/jaxlib version are you using?
jax v0.4.20, jaxlib v0.4.20
Which accelerator(s) are you using?
GPU
Additional system info?
1.26.2; 3.11.6 (main, Nov 14 2023, 09:36:21) [GCC 13.2.1 20230801] uname_result(system='Linux', node='phenex', release='6.6.3-arch1-1', version='#1 SMP PREEMPT_DYNAMIC Wed, 29 Nov 2023 00:37:40 +0000', machine='x86_64')
NVIDIA GPU info