@JaySandesara For your Q1: in eager mode (not jitted), JAX runs fnc_jax op-by-op, so the computations are indeed duplicated, and jax.make_jaxpr simply reflects that fact. However, if you jit your function, common subexpression elimination (CSE) kicks in and the divide occurs only once:
print(jit(fnc_jax).lower(1., 1.).compile().as_text())
HloModule jit_fnc_jax, entry_computation_layout={(f64[],f64[])->f64[]}, allow_spmd_sharding_propagation_to_output={true}
fused_computation {
param_0.2 = f64[] parameter(0)
param_1.4 = f64[] parameter(1)
exponential.0 = f64[] exponential(param_1.4)
subtract.1 = f64[] subtract(param_0.2, exponential.0)
sine.0 = f64[] sine(param_0.2)
add.0 = f64[] add(sine.0, param_0.2)
subtract.0 = f64[] subtract(add.0, exponential.0)
ROOT multiply.0 = f64[] multiply(subtract.1, subtract.0)
}
ENTRY main.13 {
Arg_0.1 = f64[] parameter(0), sharding={replicated}
Arg_1.2 = f64[] parameter(1), sharding={replicated}
divide.3 = f64[] divide(Arg_0.1, Arg_1.2)
ROOT fusion = f64[] fusion(divide.3, Arg_1.2), kind=kLoop, calls=fused_computation
}
Assuming your functions in JAX will be jitted, you generally don't need to worry about duplicate computation, and you can code it the way that looks cleanest/makes most sense conceptually.
Question 1: Why is it that c, e, and h for example are doing the exact same computation?
The reason for this is that jaxprs are not optimized code; they are merely an intermediate representation of the code you wrote. Because you computed jnp.divide(x, y) three times, the computation is represented three times in the jaxpr.
Never fear, though: when you JIT-compile the code, the compiler recognizes these repeated operations and will de-duplicate them, freeing you from having to worry about doing so manually. You can see this by printing the compiled HLO, though it's admittedly harder to read than the jaxpr:
print(jax.jit(fnc_jax).lower(1.0, 1.0).compile().as_text())
%fused_computation (param_0.2: f32[], param_1.4: f32[]) -> f32[] {
%param_0.2 = f32[] parameter(0)
%param_1.4 = f32[] parameter(1)
%exponential.0 = f32[] exponential(f32[] %param_1.4), metadata={op_name="jit(fnc_jax)/jit(main)/exp" source_file="<ipython-input-5-923f7c7e31f1>" source_line=5}
%subtract.1 = f32[] subtract(f32[] %param_0.2, f32[] %exponential.0), metadata={op_name="jit(fnc_jax)/jit(main)/sub" source_file="<ipython-input-5-923f7c7e31f1>" source_line=5}
%sine.0 = f32[] sine(f32[] %param_0.2), metadata={op_name="jit(fnc_jax)/jit(main)/sin" source_file="<ipython-input-5-923f7c7e31f1>" source_line=5}
%add.0 = f32[] add(f32[] %sine.0, f32[] %param_0.2), metadata={op_name="jit(fnc_jax)/jit(main)/add" source_file="<ipython-input-5-923f7c7e31f1>" source_line=5}
%subtract.0 = f32[] subtract(f32[] %add.0, f32[] %exponential.0), metadata={op_name="jit(fnc_jax)/jit(main)/sub" source_file="<ipython-input-5-923f7c7e31f1>" source_line=5}
ROOT %multiply.0 = f32[] multiply(f32[] %subtract.1, f32[] %subtract.0), metadata={op_name="jit(fnc_jax)/jit(main)/mul" source_file="<ipython-input-5-923f7c7e31f1>" source_line=5}
}
ENTRY %main.13 (Arg_0.1: f32[], Arg_1.2: f32[]) -> f32[] {
%Arg_0.1 = f32[] parameter(0), sharding={replicated}
%Arg_1.2 = f32[] parameter(1), sharding={replicated}
%divide.3 = f32[] divide(f32[] %Arg_0.1, f32[] %Arg_1.2), metadata={op_name="jit(fnc_jax)/jit(main)/div" source_file="<ipython-input-5-923f7c7e31f1>" source_line=5}
ROOT %fusion = f32[] fusion(f32[] %divide.3, f32[] %Arg_1.2), kind=kLoop, calls=%fused_computation, metadata={op_name="jit(fnc_jax)/jit(main)/mul" source_file="<ipython-input-5-923f7c7e31f1>" source_line=5}
}
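If you want to confirm the de-duplication programmatically, here is a minimal sketch (it just counts applications of the divide op in the compiled HLO text; the exact string being searched is an assumption about the HLO printer's formatting):

hlo = jax.jit(fnc_jax).lower(1.0, 1.0).compile().as_text()
# The three x / y computations in the source should be merged into a single divide.
print(hlo.count(" divide("))  # expect 1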
Question 2: Apart from repeated computations, what confuses me is what exactly is going on here?
jax.grad is implemented in terms of reverse-mode differentiation, so the generated jaxpr reflects the gradient computed in reverse mode. There's no exact forward-mode equivalent of jax.grad, but you can get roughly the same thing by computing the jaxpr of jax.jacfwd:
print(jax.make_jaxpr(jax.jacfwd(fnc_jax))(1.0,1.0))
{ lambda ; a:f32[] b:f32[]. let
c:i32[1,1] = iota[dimension=0 dtype=int32 shape=(1, 1)]
d:i32[1,1] = add c 0
e:i32[1,1] = iota[dimension=1 dtype=int32 shape=(1, 1)]
f:bool[1,1] = eq d e
g:f32[1,1] = convert_element_type[new_dtype=float32 weak_type=False] f
h:f32[1,1] = slice[limit_indices=(1, 1) start_indices=(0, 0) strides=None] g
i:f32[1] = reshape[dimensions=None new_sizes=(1,)] h
j:f32[1] = convert_element_type[new_dtype=float32 weak_type=True] i
k:f32[] = div a b
l:f32[1] = div j b
m:f32[] = exp b
n:f32[] = sub k m
o:f32[] = div a b
p:f32[1] = div j b
q:f32[] = sin o
r:f32[] = cos o
s:f32[1] = mul p r
t:f32[] = div a b
u:f32[1] = div j b
v:f32[] = add q t
w:f32[1] = add s u
x:f32[] = exp b
y:f32[] = sub v x
_:f32[] = mul n y
z:f32[1] = mul l y
ba:f32[1] = mul n w
bb:f32[1] = add_any z ba
bc:f32[1] = slice[limit_indices=(1,) start_indices=(0,) strides=None] bb
bd:f32[] = reshape[dimensions=None new_sizes=()] bc
in (bd,) }
The iota and slice stuff at the beginning comes from the fact that jacfwd is more general than grad, but the rest reflects the unoptimized forward-mode gradient.
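If you'd rather see a forward-mode jaxpr without that machinery, one option is jax.jvp with an explicit tangent direction (a sketch, not an exact drop-in for grad; it returns the primal value together with a single directional derivative):

print(jax.make_jaxpr(lambda x, y: jax.jvp(fnc_jax, (x, y), (1.0, 0.0)))(1.0, 1.0))
# The resulting jaxpr carries a tangent value alongside each primal computation.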
Thanks a lot @jakevdp and @soraros for the clear explanations!
Sorry to re-open, but I have one follow-up question. As you suggested, I perform forward-mode autodiff using jacfwd and reverse-mode autodiff using jacrev.
From what I understand, jacrev should be far more efficient when computing the gradient of a scalar-valued function with a large input. However, for a toy computation like the following, I see the opposite!
Toy function (really doesn't do anything useful, just made it this way so I can pass arbitrarily large arrays):
import numpy as np
import jax
import jax.numpy as jnp

def fn(tuple_arr):
    # The Python loop unrolls into one traced multiply per array element.
    x = 1.0
    for param in tuple_arr:
        x *= (param**2 - param**3 - param)
    return x
Timing it:
%%timeit -r1 -n1
jax.jacrev(fn)(jnp.array(np.ones(1000))).block_until_ready()
Output:
46 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
%%timeit -r1 -n1
jax.jacfwd(fn)(jnp.array(np.ones(1000))).block_until_ready()
Output:
26.5 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
Is this something to do with the function itself? Or with the fact that I am using a CPU, which might be struggling with memory transfers, given that reverse mode uses more memory?
A good rule of thumb is that any time you're writing for loops over array elements, you're going to end up with a very inefficient implementation. I'd probably write your function this way, in which case you see the computational characteristics you'd expect:
import jax
import jax.numpy as jnp
def fn(tuple_arr):
    return jnp.prod(tuple_arr ** 2 - tuple_arr ** 3 - tuple_arr)
%timeit jax.jacrev(fn)(jnp.ones(1000)).block_until_ready()
# 18 ms ± 2.06 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit jax.jacfwd(fn)(jnp.ones(1000)).block_until_ready()
# 46.2 ms ± 8.05 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Maybe something about autodiffing a 1000-step unrolled loop makes jacrev less efficient than jacfwd? I'm not sure. In any case, that's a pattern you should avoid when possible.
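Since fn is scalar-valued, a related sketch worth noting (it simply mirrors the vectorized rewrite above) is plain jax.grad, which does the single reverse-mode pass that jacrev performs for a scalar output:

import jax
import jax.numpy as jnp

def fn(arr):
    return jnp.prod(arr ** 2 - arr ** 3 - arr)

# For a scalar-valued function, jax.grad returns the same gradient as jax.jacrev.
print(jax.grad(fn)(jnp.ones(1000)))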
Thank you for the clarification!
Hello,
I am trying to understand exactly how auto-diff works with JAX. I have an example function:
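A sketch of the function, consistent with the compiled HLO shown above (the exact original form, and hence the variable letters in the jaxpr, is an assumption):

import jax
import jax.numpy as jnp

def fnc_jax(x, y):
    # x / y appears three times and exp(y) twice, so the jaxpr repeats them
    return (x / y - jnp.exp(y)) * (jnp.sin(x / y) + x / y - jnp.exp(y))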
Now I do jax.make_jaxpr(fnc_jax)(1.0,1.0), which prints the jaxpr for the function.

Question 1: Why is it that c, e, and h, for example, are doing the exact same computation? Wouldn't a better way be if we had something like this:
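A sketch of such an alternative, fnc_jax_alt (its exact original body is an assumption; it simply computes the repeated subexpressions once, reusing the jnp import above):

def fnc_jax_alt(x, y):
    # reuse x / y and exp(y) instead of recomputing them
    z = x / y
    e = jnp.exp(y)
    return (z - e) * (jnp.sin(z) + z - e)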
which gives the same output for a given input as fnc_jax, but doesn't repeat computations (confirmed using jax.make_jaxpr(fnc_jax_alt)(1.0,1.0)). Is this a feature or a bug?

Another confusion I have is with jax.make_jaxpr(jax.grad(fnc_jax))(1.0,1.0).

Question 2: Apart from repeated computations, what confuses me is what exactly is going on here? It seems to me like it is doing something of a forward differentiation. But if I understand correctly from the documentation, jax.grad performs reverse-mode differentiation by default. Am I wrong, or am I misunderstanding make_jaxpr? If it is forward-mode differentiation, then shouldn't the jaxpr look different?
If it is reverse-mode differentiation, how can I get a make_jaxpr for forward mode?

Apologies in advance for my misunderstandings. I would appreciate some clarity on these questions!
Regards, Jay Sandesara