google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Accidentally exponential-time analysis pass in compiler? #22385

Open kach opened 1 month ago

kach commented 1 month ago

Description

JAX is unexpectedly slow to compile this simple recursive Fibonacci function.

My understanding is that JAX should cache the JIT-compilation of fib for each fresh (static) t, so I would expect compilation to take time linear in t, even though a call to fib would take exponential time at runtime.

Indeed, at compile time I see each print statement called only once per t. Nonetheless, JAX seems to be taking time exponential in t to compile fib, and the print statements "slow down" over time.

Hence, I suspect there is an analysis pass in the compiler that is taking exponential time.

import functools
import jax

# t is a static argument, so each distinct value of t triggers (and caches)
# its own trace of fib.
@functools.partial(jax.jit, static_argnums=(0,))
def fib(t):
    print(f"Starting tracing fib({t})...")
    out = 1 if t <= 2 else fib(t - 1) + fib(t - 2)
    print(f"Finished tracing fib({t}).")
    return out

print(fib(24))

Another interesting piece of data:

print(fib.lower(10).as_text())  # this grows linearly with t
print(jax.make_jaxpr(fib, static_argnums=(0,))(10))  # this grows exponentially with t

System info (python version, jaxlib version, accelerator, etc.)

(venv) $ python --version
Python 3.12.4
(venv) $ python
Python 3.12.4 (main, Jun  6 2024, 18:26:44) [Clang 15.0.0 (clang-1500.1.0.2.5)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax; jax.print_environment_info()
jax:    0.4.30
jaxlib: 0.4.30
numpy:  2.0.0
python: 3.12.4 (main, Jun  6 2024, 18:26:44) [Clang 15.0.0 (clang-1500.1.0.2.5)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='*****', release='22.6.0', version='Darwin Kernel Version 22.6.0: Tue Nov  7 21:42:27 PST 2023', machine='arm64')

kach commented 1 month ago

(@abadams adds that it gets stuck in _jaxpr_forwarding.)

jakevdp commented 1 month ago

Thanks for the question – this is expected behavior.

For Python-side control flow (including for loops, if statements, and recursive function calls), JAX's tracing approach flattens the program and sends this flat list of instructions to the compiler. Your recursion here is essentially a very compact way of generating exponentially large programs; the best fix is to avoid the recursion so those programs never get generated, as in the sketch below.
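
For instance, here is a minimal iterative sketch (fib_loop is an illustrative name, not something from this thread): jax.lax.fori_loop lowers to a single loop primitive, so the traced program stays constant-size no matter how large t is.

import jax

@jax.jit
def fib_loop(t):
    # Iterate the recurrence (a, b) -> (b, a + b) t times. The loop is one
    # primitive in the trace rather than an unrolled tree of calls.
    def body(_, carry):
        a, b = carry
        return (b, a + b)
    a, _ = jax.lax.fori_loop(0, t, body, (0, 1))
    return a

print(fib_loop(24))  # 46368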

jakevdp commented 1 month ago

If you're looking for an accelerator-friendly way to compute Fibonacci numbers, you could use the matrix exponentiation method:

import jax.numpy as jnp

# Compute fib(n) as an entry of the n-th power of the 2x2 Fibonacci Q-matrix.
def fib(n):
  return jnp.linalg.matrix_power(jnp.array([[1, 1], [1, 0]]), n)[0, 1]

print([int(fib(i)) for i in range(10)])
# [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
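
One caveat worth noting: with JAX's default 32-bit integers this overflows past fib(46). Enabling 64-bit mode, typically done once at program start, extends the exact range up to fib(92):

import jax
jax.config.update("jax_enable_x64", True)  # use int64 by default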

kach commented 1 month ago

Hi Jake - Thanks for writing. I understand that JAX produces a flat trace for loops, conditionals, and function calls — but is that even true for calls to already @jax.jit-ed functions?

I was hoping that because I wrapped fib in jax.jit with t as a static argument, JAX would recursively (1) JIT-compile fib for t-1 and t-2, and then (2) emit function calls to the already-JITted code for fib(t-1) and fib(t-2) when compiling fib(t). That would produce an amount of code only linear in t, and indeed the lowered HLO looks exactly like that:

module @jit_fib attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main() -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = call @fib() : () -> tensor<i32>
    %1 = call @fib_0() : () -> tensor<i32>
    %2 = stablehlo.add %0, %1 : tensor<i32>
    return %2 : tensor<i32>
  }
  func.func private @fib() -> (tensor<i32> {mhlo.layout_mode = "default"}) {
    %0 = call @fib_0() : () -> tensor<i32>
    %1 = call @fib_1() : () -> tensor<i32>
    %2 = stablehlo.add %0, %1 : tensor<i32>
    return %2 : tensor<i32>
  }
  func.func private @fib_0() -> (tensor<i32> {mhlo.layout_mode = "default"}) {
    %0 = call @fib_1() : () -> tensor<i32>
    %1 = call @fib_2() : () -> tensor<i32>
    %2 = stablehlo.add %0, %1 : tensor<i32>
    return %2 : tensor<i32>
  }
  func.func private @fib_1() -> (tensor<i32> {mhlo.layout_mode = "default"}) {
    %0 = call @fib_2() : () -> tensor<i32>
    %1 = call @fib_3() : () -> tensor<i32>
    %2 = stablehlo.add %0, %1 : tensor<i32>
    return %2 : tensor<i32>
  }
  func.func private @fib_2() -> (tensor<i32> {mhlo.layout_mode = "default"}) {
    %0 = call @fib_3() : () -> tensor<i32>
    %1 = call @fib_4() : () -> tensor<i32>
    %2 = stablehlo.add %0, %1 : tensor<i32>
    return %2 : tensor<i32>
  }
  func.func private @fib_3() -> (tensor<i32> {mhlo.layout_mode = "default"}) {
    %0 = call @fib_4() : () -> tensor<i32>
    %1 = call @fib_5() : () -> tensor<i32>
    %2 = stablehlo.add %0, %1 : tensor<i32>
    return %2 : tensor<i32>
  }
  func.func private @fib_4() -> (tensor<i32> {mhlo.layout_mode = "default"}) {
    %0 = call @fib_5() : () -> tensor<i32>
    %1 = call @fib_6() : () -> tensor<i32>
    %2 = stablehlo.add %0, %1 : tensor<i32>
    return %2 : tensor<i32>
  }
  func.func private @fib_5() -> (tensor<i32> {mhlo.layout_mode = "default"}) {
    %0 = call @fib_6() : () -> tensor<i32>
    %1 = call @fib_7() : () -> tensor<i32>
    %2 = stablehlo.add %0, %1 : tensor<i32>
    return %2 : tensor<i32>
  }
  func.func private @fib_6() -> (tensor<i32> {mhlo.layout_mode = "default"}) {
    %c = stablehlo.constant dense<1> : tensor<i32>
    return %c : tensor<i32>
  }
  func.func private @fib_7() -> (tensor<i32> {mhlo.layout_mode = "default"}) {
    %c = stablehlo.constant dense<1> : tensor<i32>
    return %c : tensor<i32>
  }
}

so I'm confused about why the intermediate jaxprs explode. Does that make sense?

By the way, I'm just using fib as an example here. What I'm actually trying to do is implement dynamic programming algorithms like value iteration in JAX, via memo, a new DSL that compiles to JAX: https://github.com/kach/memo.

jakevdp commented 1 month ago

I understand that JAX produces a flat trace for loops, conditionals, and function calls — but is that even true for calls to already @jax.jit-ed functions?

Yes, as evidenced by the exponentially growing jaxpr.

so I'm confused why the intermediate JAXprs explode.

Unlike HLO, jaxprs don't have subroutines, so the logic of your function will always be laid out as an explicit linear sequence of operations. In principle, jaxprs could be generalized to allow reusable subroutines as in HLO, but that would run counter to their current design philosophy.
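
As a small demonstration (f and g are illustrative names, and this assumes a recent JAX where calls to jit-ed functions appear as pjit equations): each call site embeds the callee's jaxpr, so a function called twice is printed twice.

import jax

@jax.jit
def f(x):
    return x + 1

def g(x):
    return f(x) + f(x)

# The printed jaxpr for g contains two pjit equations, each embedding f's
# body, even if the underlying jaxpr object is shared between them.
print(jax.make_jaxpr(g)(1.0))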

Essentially, JAX is not designed to work well with deeply recursive implementations like the one you wrote here. Your best bet would be to find a different way to express the logic of your program.

hawkinsp commented 1 month ago

Hmm, but the jaxpr shouldn't be exponential in size. We cache subjaxprs based on object identity and attempt to deduplicate when printing them. It's possible this is an artifact of the jaxpr printer not doing a great job of deduplication though.

hawkinsp commented 1 month ago

In fact, I strongly suspect the jaxpr size blow-up must be an artifact of the pretty-printer, because otherwise the lowered StableHLO would be exponential in size.

kach commented 1 month ago

@jakevdp I guess I'm still confused about where the nice linear HLO comes from, if the intermediate jaxpr throws away the function-calling structure.

@hawkinsp I'm not sure it's just the pretty-printer. The main symptom I was observing was that even without pretty-printing, the compile time was growing exponentially.
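
For reference, here is a sketch of that measurement, using the fib from the original report and assuming jax.clear_caches is available (it is in recent JAX versions):

import time
import jax

for t in (16, 18, 20, 22, 24):
    jax.clear_caches()  # start each measurement with a cold jit cache
    start = time.perf_counter()
    fib.lower(t)  # trace and lower only: no pretty-printing, no XLA compilation
    print(t, time.perf_counter() - start)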

hawkinsp commented 1 month ago

Yes. The compiler inlines everything at the moment; that may improve at some point.

JAX doesn't throw away the function structure, but XLA inlines it all, which is what leads to the large compile time.