jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Seemingly odd behaviour with make_jaxpr #15736

Closed JaySandesara closed 1 year ago

JaySandesara commented 1 year ago

Hello,

I am trying to understand exactly how auto-diff works with JAX. I have an example function:

import jax
import jax.numpy as jnp

def fnc_jax(x1, x2):
    return (jnp.divide(x1, x2) - jnp.exp(x2)) * (jnp.sin(jnp.divide(x1, x2)) + jnp.divide(x1, x2) - jnp.exp(x2))

Now I do jax.make_jaxpr(fnc_jax)(1.0,1.0), which gives me the following output:

{ lambda ; a:f32[] b:f32[]. let
    c:f32[] = div a b
    d:f32[] = exp b
    e:f32[] = sub c d
    f:f32[] = div a b
    g:f32[] = sin f
    h:f32[] = div a b
    i:f32[] = add g h
    j:f32[] = exp b
    k:f32[] = sub i j
    l:f32[] = mul e k
  in (l,) }

Question 1: Why is it that c, f, and h, for example, are doing the exact same computation? Wouldn't it be better to have something like this:

def fnc_jax_alt(x1, x2):
    a = x1 / x2
    b = jnp.exp(x2)
    c = jnp.sin(a)
    d = a - b
    e = c + d
    g = d * e
    return g

which gives the same output as fnc_jax for a given input but doesn't repeat computations, as confirmed with jax.make_jaxpr(fnc_jax_alt)(1.0, 1.0). Is this a feature or a bug?

Another thing that confuses me is what I get when I do jax.make_jaxpr(jax.grad(fnc_jax))(1.0, 1.0):

{ lambda ; a:f32[] b:f32[]. let
    c:f32[] = div a b
    d:f32[] = exp b
    e:f32[] = sub c d
    f:f32[] = div a b
    g:f32[] = sin f
    h:f32[] = cos f
    i:f32[] = div a b
    j:f32[] = add g i
    k:f32[] = exp b
    l:f32[] = sub j k
    _:f32[] = mul e l
    m:f32[] = mul e 1.0
    n:f32[] = mul 1.0 l
    o:f32[] = div m b
    p:f32[] = mul m h
    q:f32[] = div p b
    r:f32[] = add_any o q
    s:f32[] = div n b
    t:f32[] = add_any r s
  in (t,) }

Question 2: Apart from the repeated computations, what exactly is going on here? It looks to me like it is doing something like forward-mode differentiation. But if I understand the documentation correctly, jax.grad performs reverse-mode differentiation by default. Am I wrong, or am I misunderstanding make_jaxpr?

If it is forward-mode differentiation, then shouldn't we have something like the following:

{ lambda ; a:f32[] b:f32[]. let
    c:f32[] = div a b
    d:f32[] = exp b
    e:f32[] = sin c
    f:f32[] = sub c d
    g:f32[] = add e f
    _:f32[] = mul f g
    h:f32[] = div 1.0 b
    i:f32[] = cos c
    j:f32[] = mul i h
    k:f32[] = sub h 0.0
    l:f32[] = add j k
    m:f32[] = mul k g
    n:f32[] = mul f l
    o:f32[] = add m n
  in (o,) }

If it is reverse-mode differentiation, how can I get a make_jaxpr for forward-mode?

Apologies in advance for my misunderstandings. I would appreciate some clarity on these questions!

Regards, Jay Sandesara

What jax/jaxlib version are you using?

No response

Which accelerator(s) are you using?

No response

Additional system info

No response

NVIDIA GPU info

No response

soraros commented 1 year ago

@JaySandesara For your Q1: In eager mode (not jitted), JAX runs fnc_jax op by op, so the computations are indeed duplicated, and jax.make_jaxpr simply reflects that fact. However, if you jit your function, common subexpression elimination (CSE) kicks in and the divide occurs only once:

print(jit(fnc_jax).lower(1., 1.).compile().as_text())
HloModule jit_fnc_jax, entry_computation_layout={(f64[],f64[])->f64[]}, allow_spmd_sharding_propagation_to_output={true}

fused_computation {
  param_0.2 = f64[] parameter(0)
  param_1.4 = f64[] parameter(1)
  exponential.0 = f64[] exponential(param_1.4)
  subtract.1 = f64[] subtract(param_0.2, exponential.0)
  sine.0 = f64[] sine(param_0.2)
  add.0 = f64[] add(sine.0, param_0.2)
  subtract.0 = f64[] subtract(add.0, exponential.0)
  ROOT multiply.0 = f64[] multiply(subtract.1, subtract.0)
}

ENTRY main.13 {
  Arg_0.1 = f64[] parameter(0), sharding={replicated}
  Arg_1.2 = f64[] parameter(1), sharding={replicated}
  divide.3 = f64[] divide(Arg_0.1, Arg_1.2)
  ROOT fusion = f64[] fusion(divide.3, Arg_1.2), kind=kLoop, calls=fused_computation
}

Assuming your JAX functions will be jitted, you generally don't need to worry about duplicated computation; you can write the code in whatever way looks cleanest or makes the most sense conceptually.
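
As a quick cross-check of my own (a sketch, not part of the original comment), you can lower the hand-deduplicated fnc_jax_alt the same way and compare the compiled programs; after CSE and fusion they should be essentially identical:

import jax

# Lower, compile, and print the HLO for the manually de-duplicated version.
print(jax.jit(fnc_jax_alt).lower(1.0, 1.0).compile().as_text())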

jakevdp commented 1 year ago

Question 1: Why is it that c, f, and h, for example, are doing the exact same computation?

The reason for this is that jaxprs do not produce optimized code; they merely produce an intermediate representation of the code you wrote. Because you computed jnp.divide(x1, x2) three times, the computation is represented three times in the jaxpr.

Never fear, though: when you JIT-compile the code, the compiler recognizes these repeated operations and will de-duplicate them, freeing you from having to worry about doing so manually. You can see this by printing the compiled HLO, though it's admittedly harder to read than the jaxpr:

print(jax.jit(fnc_jax).lower(1.0, 1.0).compile().as_text())
%fused_computation (param_0.2: f32[], param_1.4: f32[]) -> f32[] {
  %param_0.2 = f32[] parameter(0)
  %param_1.4 = f32[] parameter(1)
  %exponential.0 = f32[] exponential(f32[] %param_1.4), metadata={op_name="jit(fnc_jax)/jit(main)/exp" source_file="<ipython-input-5-923f7c7e31f1>" source_line=5}
  %subtract.1 = f32[] subtract(f32[] %param_0.2, f32[] %exponential.0), metadata={op_name="jit(fnc_jax)/jit(main)/sub" source_file="<ipython-input-5-923f7c7e31f1>" source_line=5}
  %sine.0 = f32[] sine(f32[] %param_0.2), metadata={op_name="jit(fnc_jax)/jit(main)/sin" source_file="<ipython-input-5-923f7c7e31f1>" source_line=5}
  %add.0 = f32[] add(f32[] %sine.0, f32[] %param_0.2), metadata={op_name="jit(fnc_jax)/jit(main)/add" source_file="<ipython-input-5-923f7c7e31f1>" source_line=5}
  %subtract.0 = f32[] subtract(f32[] %add.0, f32[] %exponential.0), metadata={op_name="jit(fnc_jax)/jit(main)/sub" source_file="<ipython-input-5-923f7c7e31f1>" source_line=5}
  ROOT %multiply.0 = f32[] multiply(f32[] %subtract.1, f32[] %subtract.0), metadata={op_name="jit(fnc_jax)/jit(main)/mul" source_file="<ipython-input-5-923f7c7e31f1>" source_line=5}
}

ENTRY %main.13 (Arg_0.1: f32[], Arg_1.2: f32[]) -> f32[] {
  %Arg_0.1 = f32[] parameter(0), sharding={replicated}
  %Arg_1.2 = f32[] parameter(1), sharding={replicated}
  %divide.3 = f32[] divide(f32[] %Arg_0.1, f32[] %Arg_1.2), metadata={op_name="jit(fnc_jax)/jit(main)/div" source_file="<ipython-input-5-923f7c7e31f1>" source_line=5}
  ROOT %fusion = f32[] fusion(f32[] %divide.3, f32[] %Arg_1.2), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(fnc_jax)/jit(main)/mul" source_file="<ipython-input-5-923f7c7e31f1>" source_line=5}
}

Question 2: Apart from repeated computations, what confuses me is what exactly is going on here?

jax.grad is implemented in terms of reverse-mode differentiation, so the generated jaxpr reflects the gradient computed in reverse mode. There's no exact forward-mode equivalent of jax.grad, but you can get roughly the same thing by computing the jaxpr of jax.jacfwd:

print(jax.make_jaxpr(jax.jacfwd(fnc_jax))(1.0,1.0))
{ lambda ; a:f32[] b:f32[]. let
    c:i32[1,1] = iota[dimension=0 dtype=int32 shape=(1, 1)] 
    d:i32[1,1] = add c 0
    e:i32[1,1] = iota[dimension=1 dtype=int32 shape=(1, 1)] 
    f:bool[1,1] = eq d e
    g:f32[1,1] = convert_element_type[new_dtype=float32 weak_type=False] f
    h:f32[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] g
    i:f32[1] = reshape[dimensions=None new_sizes=(1,)] h
    j:f32[1] = convert_element_type[new_dtype=float32 weak_type=True] i
    k:f32[] = div a b
    l:f32[1] = div j b
    m:f32[] = exp b
    n:f32[] = sub k m
    o:f32[] = div a b
    p:f32[1] = div j b
    q:f32[] = sin o
    r:f32[] = cos o
    s:f32[1] = mul p r
    t:f32[] = div a b
    u:f32[1] = div j b
    v:f32[] = add q t
    w:f32[1] = add s u
    x:f32[] = exp b
    y:f32[] = sub v x
    _:f32[] = mul n y
    z:f32[1] = mul l y
    ba:f32[1] = mul n w
    bb:f32[1] = add_any z ba
    bc:f32[1] = slice[limit_indices=(1,) start_indices=(0,) strides=None] bb
    bd:f32[] = reshape[dimensions=None new_sizes=()] bc
  in (bd,) }

The iota and slice stuff at the beginning comes from the fact that jacfwd is more general than grad, but the rest reflects the unoptimized forward-mode gradient.
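
If you want to see a plain forward-mode trace without that Jacobian-building machinery, one option (my own sketch, not from the original reply) is to inspect the jaxpr of jax.jvp directly, pushing a single tangent vector through the function:

import jax
import jax.numpy as jnp

def fnc_jax(x1, x2):
    return (jnp.divide(x1, x2) - jnp.exp(x2)) * (jnp.sin(jnp.divide(x1, x2)) + jnp.divide(x1, x2) - jnp.exp(x2))

# Forward-mode trace: JVP at (1.0, 1.0) with tangent direction (1.0, 0.0), i.e. d/dx1.
print(jax.make_jaxpr(lambda x1, x2: jax.jvp(fnc_jax, (x1, x2), (1.0, 0.0)))(1.0, 1.0))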

JaySandesara commented 1 year ago

Thanks a lot @jakevdp and @soraros for the clear explanations!

JaySandesara commented 1 year ago

Sorry to re-open, but I have one follow-up question: as you suggested, I performed forward-mode autodiff using jacfwd and reverse mode using jacrev.

From what I understand, jacrev should be far more efficient for computing the gradient of a scalar-valued function with a large input. However, for a toy computation like the following, I see the opposite!

Toy function (really doesn't do anything useful, just made it this way so I can pass arbitrarily large arrays):

def fn(tuple_arr):
    x = 1.0
    for param in tuple_arr:
        x *= (param**2 - param**3 - param)
    return x

Timing it:

%%timeit -r1 -n1
jax.jacrev(fn)(jnp.array(np.ones(1000))).block_until_ready()

Output:

46 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

%%timeit -r1 -n1
jax.jacfwd(fn)(jnp.array(np.ones(1000))).block_until_ready()

Output:

26.5 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

Is this something to do with the function itself? Or is it because I am using a CPU, which might be struggling with memory transfers, given that reverse mode uses more memory?

jakevdp commented 1 year ago

A good rule of thumb is that any time you're writing Python for loops over array elements, you're going to end up with a very inefficient implementation. I'd probably write your function this way, in which case you see the computational characteristics you'd expect:

import jax 
import jax.numpy as jnp

def fn(tuple_arr):
  return jnp.prod(tuple_arr ** 2 - tuple_arr ** 3 - tuple_arr)

%timeit jax.jacrev(fn)(jnp.ones(1000)).block_until_ready()
# 18 ms ± 2.06 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit jax.jacfwd(fn)(jnp.ones(1000)).block_until_ready()
# 46.2 ms ± 8.05 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Maybe something about autodiffing a 1000-step unrolled loop makes jacrev less efficient than jacfwd? I'm not sure. In any case, that's a pattern you should avoid when possible.
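
For what it's worth, if a sequential loop is genuinely needed, one way to avoid the 1000-step unrolling (again a sketch of my own, not part of the original answer) is to express the loop with jax.lax.scan, which keeps it rolled up in the traced program:

import jax
import jax.numpy as jnp

def fn_scan(arr):
    # Same running product as the Python loop, expressed as a scan so the traced
    # program contains one rolled-up loop instead of 1000 unrolled steps.
    def body(carry, param):
        return carry * (param**2 - param**3 - param), None
    result, _ = jax.lax.scan(body, jnp.ones((), arr.dtype), arr)
    return result

# Scalar output, so reverse mode (jax.grad / jacrev) is the natural fit here.
print(jax.grad(fn_scan)(jnp.ones(1000)))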

JaySandesara commented 1 year ago

Thank you for the clarification!