jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

JIT Compilation has cubic or quadratic asymptotic complexity on certain programs #18787

Open nshepperd opened 1 year ago

nshepperd commented 1 year ago

Description

See the attached test case. Compiling a relatively simple function such as the one below:

    def f(x, y):
        t = x * y
        for _ in range(n):
            t += x * y
        return t

jax.jit(f).lower(x, y).compile() takes time that grows roughly as the third power of n. Lowering without compiling is fast. With an x = _optimization_barrier(x) inserted in the loop, the time complexity instead appears to be quadratic, which I think suggests this is an issue with XLA (poorly optimized HLO passes)?
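
For reference, lowering and compilation can be timed separately via the jax.jit AOT API; a minimal sketch (assuming f is defined as above, with n fixed):

    import time
    import jax
    import jax.numpy as jnp

    n = 1000
    def f(x, y):
        t = x * y
        for _ in range(n):
            t += x * y
        return t

    x = y = jnp.ones([1, 1])
    t0 = time.time()
    lowered = jax.jit(f).lower(x, y)  # tracing + lowering: fast
    t1 = time.time()
    compiled = lowered.compile()      # XLA compilation: the slow part
    t2 = time.time()
    print(f'lower: {t1 - t0:.2f}s  compile: {t2 - t1:.2f}s')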

I don't think this is just an academic problem: I have previously had compiles fail to terminate at all (within the time allowed by my patience), which is unsurprising when these runtimes are extrapolated to n = 10k or more. In those cases I always just adjusted my code until it mysteriously compiled again.

I think compile times that grow quadratically or worse with program size should probably be considered a bug, given that people frequently use models with very large numbers of ops. In my case this is a practical problem: it may be making programs generated by an experimental rematerialization strategy I'm working on infeasible to compile.

### test_cubic.py
import jax
import jaxlib
import jax.numpy as jnp
from jax._src.ad_checkpoint import _optimization_barrier
import time

def test_cubic(n, barrier=False):
    # Build an unrolled chain of n elementwise ops and time its compilation;
    # the optional _optimization_barrier restricts XLA optimizations across
    # loop iterations.
    def f(x, y):
        t = x * y
        for _ in range(n):
            if barrier:
                x = _optimization_barrier(x)
            t += x * y
        return t
    x = jnp.ones([1, 1])
    y = jnp.ones([1, 1])
    jax.jit(f).lower(x, y).compile()

if __name__ == '__main__':
    print(jax.__version__, jaxlib.__version__)

    print('Cubic, w/o _optimization_barrier:')
    n = 250
    duration = 0
    while duration < 60:
        start = time.time()
        test_cubic(n)
        duration = time.time() - start
        print(f'{n}: {duration} seconds')
        n *= 2

    print('Cubic, with _optimization_barrier:')
    n = 250
    duration = 0
    while duration < 60:
        start = time.time()
        test_cubic(n, True)
        duration = time.time() - start
        print(f'{n}: {duration} seconds')
        n *= 2

# 0.4.20 0.4.20
# Cubic, w/o _optimization_barrier:
# 250: 0.7570211887359619 seconds
# 500: 1.8215487003326416 seconds
# 1000: 11.496224164962769 seconds
# 2000: 86.28091049194336 seconds
# Cubic, with _optimization_barrier:
# 250: 0.5373725891113281 seconds
# 500: 1.045823335647583 seconds
# 1000: 2.2935471534729004 seconds
# 2000: 5.252176523208618 seconds
# 4000: 13.146437883377075 seconds
# 8000: 49.839492082595825 seconds
# 16000: 190.65447235107422 seconds

What jax/jaxlib version are you using?

jax v0.4.20, jaxlib v0.4.20

Which accelerator(s) are you using?

GPU

Additional system info?

1.26.2; 3.11.6 (main, Nov 14 2023, 09:36:21) [GCC 13.2.1 20230801] uname_result(system='Linux', node='phenex', release='6.6.3-arch1-1', version='#1 SMP PREEMPT_DYNAMIC Wed, 29 Nov 2023 00:37:40 +0000', machine='x86_64')

NVIDIA GPU info

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.29.06              Driver Version: 545.29.06    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 3090        Off | 00000000:2D:00.0  On |                  N/A |
| 61%   44C    P8              32W / 350W |  19516MiB / 24576MiB |      4%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
andportnoy commented 1 year ago
for _ in range(n):
    ...

Keep in mind that pure Python loops get statically unrolled, so what XLA receives as input is a "concatenation" of n copies of the loop body above. This section of the JAX Sharp Bits page has more context: Python control flow + JIT. See also this section on JAX control flow primitives that are built on top of XLA-native constructs: Structured control flow primitives. The entire page is worth reading.
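
For comparison, a minimal sketch (illustrative; not from the original comment) of the same computation expressed with jax.lax.fori_loop, so the traced program stays a fixed size regardless of n:

    import jax
    import jax.numpy as jnp

    def f(x, y, n):
        # The body is traced once and lowered to a single XLA While op,
        # instead of being unrolled into n copies.
        return jax.lax.fori_loop(0, n, lambda _, t: t + x * y, x * y)

    x = y = jnp.ones([1, 1])
    out = jax.jit(f, static_argnums=2)(x, y, 16000)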

jakevdp commented 12 months ago

Hi @nshepperd – I think this is working more-or-less as expected. As @andportnoy mentioned, Python control flow is statically unrolled in JAX programs, and in general XLA compilation cost is roughly quadratic in the unrolled size of the program. Because of this, a key recommendation for using JAX effectively is to avoid large Python loops within transformations like JIT:

- If you're looping over operations on array values, you can often re-express your computation in terms of vmap or other vectorized expressions.
- If you're looping over steps in a pipeline, you might use fori_loop.
- If you're looping over e.g. expensive training steps (where the runtime of the loop body is much longer than the typical operation dispatch time), a useful pattern is to JIT-compile just the loop body and iterate outside JIT over calls to that function (a sketch of this pattern follows below).
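
A minimal sketch of that last pattern (train_step, params, and batch are illustrative names, not from this thread):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def train_step(params, batch):
        # Stand-in for one expensive step; compiled on the first call,
        # then reused on every subsequent iteration.
        return params - 0.01 * jnp.mean(batch) * params

    params = jnp.ones([4])
    for batch in [jnp.ones([8])] * 100:  # the Python loop stays outside jit
        params = train_step(params, batch)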

Hope that helps!

nshepperd commented 12 months ago

Yes, I'm aware of vmap/scan/etc. The loop being unrolled is intentional: it's a straightforward way to build a computation graph with a given number of nodes. I don't believe compilation time should be cubic on any program, since that can easily make large model evaluations infeasible.

nshepperd commented 12 months ago

I also object to compilation runtimes being quadratic, but I accept that improving them to O(n log n) is most likely a much less trivial undertaking.

nshepperd commented 12 months ago

Well, with a bit of stack-trace sampling I figured out why this is taking so long: GpuInstructionFusion is extremely inefficient on this case. In particular, InstructionFusion iterates over all ops (which is expected) looking for fusible ops, but the fusibility check for each op calls xla::gpu::SharedMemoryUsage, which is itself O(n^2) in the size of any HloOpcode::kFusion op it encounters. In this case the whole program can be fused (it's all elementwise ops), so that fusion op grows over time to the full size of the program.

The reason xla::gpu::SharedMemoryUsage is quadratic is that it enumerates the instructions inside the fused computation and invokes FindNonTrivialHero for each of them, which itself does a BFS over the same instructions. This seems easily fixable for this case just by memoizing the BFS within FindNonTrivialHero.

(Gpu)InstructionFusion as a whole seems like it should be doable in O(n) by computing SharedMemoryUsage more carefully and updating it incrementally when the fusion changes.
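
As a rough illustration of that incremental idea (hypothetical Python, not XLA source; all names are stand-ins):

    class FusionUsageTracker:
        # Keeps a running shared-memory total for a growing fusion, so each
        # fusibility check is O(1) instead of re-walking the whole fusion.
        def __init__(self, limit):
            self.limit = limit
            self.total = 0  # cached usage of the current fusion

        def can_fuse(self, op_usage):
            # O(1) check against the cached total, replacing an O(n) rescan.
            return self.total + op_usage <= self.limit

        def fuse(self, op_usage):
            self.total += op_usage  # update incrementally, don't recompute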

jakevdp commented 12 months ago

Thanks for digging in: for what it's worth, all of the XLA compilation-related code is maintained in https://github.com/openxla/xla. You might have some luck raising an issue there.

hawkinsp commented 11 months ago

I agree this is certainly an XLA bug. O(n^2) algorithms in compilation are not acceptable.

hawkinsp commented 11 months ago

Filed https://github.com/openxla/xla/issues/7971

cheshire commented 3 months ago

SharedMemoryUsage was specifically written with a memoization cache in mind, cf. https://github.com/openxla/xla/blob/main/xla/service/gpu/gpu_fusible.cc#L731-L735.

I wonder: is it not triggering? Or does a new fusion invalidate the cached search?

cheshire commented 3 months ago

@olegshyshkov WDYT?

olegshyshkov commented 3 months ago

We didn't use the cache for SharedMemoryUsage in PriorityFusion until recently; I fixed that in https://github.com/openxla/xla/commit/5b0e626589b08de741c14c2568cf1b942ace55ba. I can't say anything about GpuInstructionFusion; it shouldn't be used anymore.

cheshire commented 3 months ago

@nshepperd could you verify that the bug is fixed now?