kach opened this issue 2 months ago
(@abadams adds that it gets stuck in `_jaxpr_forwarding`.)
Thanks for the question – this is expected behavior.

For Python-side control flow (including `for` loops, `if` statements, and recursive function calls), JAX's tracing approach flattens the program and sends this flat list of instructions to the compiler. Your recursion here is essentially a very compact way of generating exponentially growing programs – the best fix would be to avoid recursion, so as to avoid those exponentially growing programs.
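As a small illustration of that flattening (a sketch, not from the original thread): tracing replays Python-level loops, so the jaxpr records one equation per iteration rather than a single loop construct.

```python
import jax
import jax.numpy as jnp

def unrolled_sum(x):
    # Python-level loop: tracing replays each iteration, so the jaxpr
    # records five separate additions rather than one loop.
    total = jnp.float32(0)
    for _ in range(5):
        total = total + x
    return total

jaxpr = jax.make_jaxpr(unrolled_sum)(jnp.float32(1.0))
print(len(jaxpr.jaxpr.eqns))  # one add equation per unrolled iteration
```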
If you're looking for an accelerator-friendly way to compute Fibonacci numbers, you could use the matrix exponentiation method:
```python
import jax
import jax.numpy as jnp

def fib(n):
    return jnp.linalg.matrix_power(jnp.array([[1, 1], [1, 0]]), n)[0, 1]

print([int(fib(i)) for i in range(10)])
# [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
```
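Another accelerator-friendly option (my own sketch, not from the thread): express the same recurrence with `jax.lax.scan`, so the traced program stays constant-size regardless of `n`:

```python
import jax
import jax.numpy as jnp

def fib_scan(n):
    # One structured loop primitive instead of Python recursion: the
    # jaxpr contains a single scan, regardless of n.
    def step(carry, _):
        a, b = carry
        return (b, a + b), None
    (a, _), _ = jax.lax.scan(step, (jnp.int32(0), jnp.int32(1)), None, length=n)
    return a

print([int(fib_scan(i)) for i in range(10)])
# [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
```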
Hi Jake – thanks for writing. I understand that JAX produces a flat trace for loops, conditionals, and function calls, but is that even true for calls to already `@jax.jit`-ed functions?

I was hoping that because I wrapped `fib` in `jax.jit` with `t` as a static argument, JAX would recursively (1) JIT `fib` for `t-1` and `t-2`, and then (2) emit function calls to the already-JITted code for `fib(t-1)` and `fib(t-2)` when compiling `fib(t)`. This would produce a linear amount of code in `t`. Indeed, the lowered HLO does look like that:
```mlir
module @jit_fib attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main() -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = call @fib() : () -> tensor<i32>
    %1 = call @fib_0() : () -> tensor<i32>
    %2 = stablehlo.add %0, %1 : tensor<i32>
    return %2 : tensor<i32>
  }
  func.func private @fib() -> (tensor<i32> {mhlo.layout_mode = "default"}) {
    %0 = call @fib_0() : () -> tensor<i32>
    %1 = call @fib_1() : () -> tensor<i32>
    %2 = stablehlo.add %0, %1 : tensor<i32>
    return %2 : tensor<i32>
  }
  func.func private @fib_0() -> (tensor<i32> {mhlo.layout_mode = "default"}) {
    %0 = call @fib_1() : () -> tensor<i32>
    %1 = call @fib_2() : () -> tensor<i32>
    %2 = stablehlo.add %0, %1 : tensor<i32>
    return %2 : tensor<i32>
  }
  func.func private @fib_1() -> (tensor<i32> {mhlo.layout_mode = "default"}) {
    %0 = call @fib_2() : () -> tensor<i32>
    %1 = call @fib_3() : () -> tensor<i32>
    %2 = stablehlo.add %0, %1 : tensor<i32>
    return %2 : tensor<i32>
  }
  func.func private @fib_2() -> (tensor<i32> {mhlo.layout_mode = "default"}) {
    %0 = call @fib_3() : () -> tensor<i32>
    %1 = call @fib_4() : () -> tensor<i32>
    %2 = stablehlo.add %0, %1 : tensor<i32>
    return %2 : tensor<i32>
  }
  func.func private @fib_3() -> (tensor<i32> {mhlo.layout_mode = "default"}) {
    %0 = call @fib_4() : () -> tensor<i32>
    %1 = call @fib_5() : () -> tensor<i32>
    %2 = stablehlo.add %0, %1 : tensor<i32>
    return %2 : tensor<i32>
  }
  func.func private @fib_4() -> (tensor<i32> {mhlo.layout_mode = "default"}) {
    %0 = call @fib_5() : () -> tensor<i32>
    %1 = call @fib_6() : () -> tensor<i32>
    %2 = stablehlo.add %0, %1 : tensor<i32>
    return %2 : tensor<i32>
  }
  func.func private @fib_5() -> (tensor<i32> {mhlo.layout_mode = "default"}) {
    %0 = call @fib_6() : () -> tensor<i32>
    %1 = call @fib_7() : () -> tensor<i32>
    %2 = stablehlo.add %0, %1 : tensor<i32>
    return %2 : tensor<i32>
  }
  func.func private @fib_6() -> (tensor<i32> {mhlo.layout_mode = "default"}) {
    %c = stablehlo.constant dense<1> : tensor<i32>
    return %c : tensor<i32>
  }
  func.func private @fib_7() -> (tensor<i32> {mhlo.layout_mode = "default"}) {
    %c = stablehlo.constant dense<1> : tensor<i32>
    return %c : tensor<i32>
  }
}
```
So I'm confused about why the intermediate jaxprs explode. Does that make sense?
By the way, I'm just using `fib` as an example here. What I'm actually trying to do is implement dynamic programming algorithms like value iteration in JAX, via memo, a new DSL that compiles to JAX: https://github.com/kach/memo.
> I understand that JAX produces a flat trace for loops, conditionals, and function calls — but is that even true for calls to already `@jax.jit`-ed functions?
Yes, as evidenced by the exponentially growing jaxpr
> so I'm confused why the intermediate jaxprs explode.
Unlike HLO, jaxprs don't have subroutines, so the logic of your function will always be laid out as an explicit linear sequence of operations. In theory, jaxprs could perhaps be generalized to allow reusable subroutines as in HLO, but that would run counter to their current design philosophy.
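One way to see this flatness (my own sketch, not part of the thread): trace a function that calls an already-jitted helper twice and inspect the jaxpr. The calls appear as equations in one flat list, not as separate named subroutines.

```python
import jax
import jax.numpy as jnp

@jax.jit
def inner(x):
    return x * 2.0

def outer(x):
    return inner(x) + inner(x)

# The jaxpr is a single flat equation list: two call equations for the
# jitted helper (each carrying the inner jaxpr inline) plus one add.
jaxpr = jax.make_jaxpr(outer)(jnp.float32(1.0))
print(len(jaxpr.jaxpr.eqns))
```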
Essentially, JAX is not designed to work well with deeply recursive implementations like the one you wrote here. Your best bet would be to find a different way to express the logic of your program.
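For dynamic-programming-style logic, one non-recursive formulation (a sketch, assuming `n` is a static Python int) is to fill the table bottom-up with `jax.lax.fori_loop`:

```python
import jax
import jax.numpy as jnp

def fib_table(n):
    # Bottom-up DP: the traced program is one fori_loop, independent
    # of n, instead of an exponentially unrolled recursion.
    table = jnp.zeros(n + 2, dtype=jnp.int32).at[1].set(1)
    def body(i, t):
        return t.at[i].set(t[i - 1] + t[i - 2])
    return jax.lax.fori_loop(2, n + 1, body, table)[n]

print(int(fib_table(9)))  # 34
```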
Hmm, but the jaxpr shouldn't be exponential in size. We cache subjaxprs based on object identity and attempt to deduplicate when printing them. It's possible this is an artifact of the jaxpr printer not doing a great job of deduplication though.
In fact I strongly suspect the jaxpr size blow up must be an artifact of the pretty-printer, because otherwise the lowered stablehlo would be exponential in size.
@jakevdp I guess I'm still confused about where the nice linear HLO comes from, if the intermediate jaxpr throws away the function-calling structure.
@hawkinsp I'm not sure it's just the pretty-printer. The main symptom I was observing was that even without pretty-printing, the compile time was growing exponentially.
Yes. The compiler inlines everything at the moment. That may improve at some point.
JAX doesn't throw away the function structure, but XLA inlines, leading to a large compile time.
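This can be checked directly (my own sketch): the StableHLO text that JAX hands to XLA still contains a separate function and a `call` for the nested jit; the inlining happens later, inside XLA.

```python
import jax

@jax.jit
def inner(x):
    return x * 2.0

@jax.jit
def outer(x):
    return inner(x) + 1.0

# The pre-XLA StableHLO keeps inner as its own func.func plus a call,
# much like the fib/fib_0/... functions in the dump above.
text = outer.lower(1.0).as_text()
print(text.count("func.func"))
```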
Description

JAX is unexpectedly slow to compile this simple recursive Fibonacci function.

My understanding is that it should cache each JIT-compilation of `fib` with a fresh (static) `t`, so I would expect compilation to take linear time in `t`, even though at runtime a call to `fib` would take exponential time to compute. Indeed, at compile time I only see each print statement called once per `t`. Yet nonetheless, JAX seems to be taking exponential time in `t` to compile `fib`, and the print statements "slow down" over time. Hence, I suspect there is an analysis pass in the compiler that is taking exponential time.
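The original snippet is not reproduced here; for context, the setup described is presumably something like this hypothetical reconstruction (names and details assumed, not from the issue):

```python
from functools import partial
import jax

# Hypothetical reconstruction: t is a static argument, so each distinct
# t is traced (and compiled) once, with recursive calls going through
# the already-jitted instances for t-1 and t-2.
@partial(jax.jit, static_argnums=0)
def fib(t):
    print("tracing t =", t)  # fires once per distinct static t
    if t <= 1:
        return t
    return fib(t - 1) + fib(t - 2)

print(int(fib(10)))  # 55
```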
Another interesting piece of data:
System info (python version, jaxlib version, accelerator, etc.)