shoyer opened this issue 4 years ago
Hi @shoyer,
It looks like this issue has been resolved in later versions of JAX. I ran the code from this issue on Colab with JAX version 0.4.23. The JVPs of `f1` and `f2` are now evaluated in exactly the same order. The only difference is an extra `id_print` call (the `outside_call` equation) injected for `f2`.
```python
import jax
from jax.experimental import host_callback

def f1(x):
    y = x ** 2
    return y

def f2(x):
    y = x ** 2
    y = host_callback.id_print(y)
    return y

print('jvp without id_print:')
print(jax.make_jaxpr(lambda x, y: jax.jvp(f1, (x,), (y,)))(0.0, 0.0))
print('\njvp with id_print:')
print(jax.make_jaxpr(lambda x, y: jax.jvp(f2, (x,), (y,)))(0.0, 0.0))
```
Output:
```
jvp without id_print:
{ lambda ; a:f32[] b:f32[]. let
    c:f32[] = integer_pow[y=2] a
    d:f32[] = integer_pow[y=1] a
    e:f32[] = mul 2.0 d
    f:f32[] = mul b e
  in (c, f) }

jvp with id_print:
{ lambda ; a:f32[] b:f32[]. let
    c:f32[] = integer_pow[y=2] a
    d:f32[] = integer_pow[y=1] a
    e:f32[] = mul 2.0 d
    f:f32[] = mul b e
    g:f32[] = outside_call[
      arg_treedef=PyTreeDef(*)
      callback=<jax.experimental.host_callback._CallbackWrapper object at 0x7d4237b05a80>
      device_index=0
      identity=True
    ] c
  in (g, f) }
```
Since `jax.experimental.host_callback` is deprecated (#20385), I also tested with `jax.debug.print`. With it, the JVPs of `f1` and `f2` are again evaluated in the same order. Because `jax.debug.print` returns `None`, `f2` also returns `None` here. As a result, the second jaxpr has no outputs and the unused tangent is bound to `_`.
```python
import jax

def f1(x):
    y = x ** 2
    return y

def f2(x):
    y = x ** 2
    # jax.debug.print returns None, so f2 returns None here.
    y = jax.debug.print("{}", y)
    return y

print('jvp without id_print:')
print(jax.make_jaxpr(lambda x, y: jax.jvp(f1, (x,), (y,)))(0.0, 0.0))
print('\njvp with id_print:')
print(jax.make_jaxpr(lambda x, y: jax.jvp(f2, (x,), (y,)))(0.0, 0.0))
```
Output:
```
jvp without id_print:
{ lambda ; a:f32[] b:f32[]. let
    c:f32[] = integer_pow[y=2] a
    d:f32[] = integer_pow[y=1] a
    e:f32[] = mul 2.0 d
    f:f32[] = mul b e
  in (c, f) }

jvp with id_print:
{ lambda ; a:f32[] b:f32[]. let
    c:f32[] = integer_pow[y=2] a
    d:f32[] = integer_pow[y=1] a
    e:f32[] = mul 2.0 d
    _:f32[] = mul b e
    debug_callback[
      callback=<function debug_callback.<locals>._flat_callback at 0x7d4235c7f5b0>
      effect=Debug
    ] c
  in () }
```
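For a closer like-for-like comparison, one option is to call `jax.debug.print` only for its side effect and return `y` unchanged. This is a sketch and does not come from the thread. Under this approach the JVP keeps both its primal and tangent outputs, just as it does for `f1`. The names `f2_keep` and `jvp_jaxpr` are hypothetical, and the exact jaxpr printed may vary between JAX versions.

```python
import jax


def f1(x):
    # Reference function with no debug output.
    return x ** 2


def f2_keep(x):
    # Hypothetical variant of f2: call jax.debug.print only for its side
    # effect and return y itself, rather than the None that print returns.
    y = x ** 2
    jax.debug.print("y = {}", y)
    return y


def jvp_jaxpr(f):
    """Trace the JVP of a scalar function f and return its (closed) jaxpr."""
    return jax.make_jaxpr(lambda x, t: jax.jvp(f, (x,), (t,)))(0.0, 0.0)


# f2_keep's jaxpr should look like f1's plus one debug_callback equation,
# and it keeps both a primal and a tangent output.
print(jvp_jaxpr(f1))
print(jvp_jaxpr(f2_keep))

# Evaluated eagerly at x = 3 with tangent 1, both give primal 9 and
# tangent 6; f2_keep additionally prints the value of y.
print(jax.jvp(f1, (3.0,), (1.0,)))
print(jax.jvp(f2_keep, (3.0,), (1.0,)))
```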
Please find the gist for reference.
Thank you.
(Forked from https://github.com/google/jax/issues/3127)
Consider the following example:
The function `f2` is exactly the same as `f1`, except with the addition of `id_print`. Naively, I would expect these functions to be evaluated in exactly the same order, except with some extra calls to `id_tap` injected. But as we can see from the JAXprs, that isn't what happens: without `id_print`, primals are evaluated before tangents, but with `id_print`, tangents are evaluated first!
This is a perfectly valid way to calculate JVPs, of course, but it's a little worrisome for a debugging utility to change how computation happens. It's all the more worrisome because JVPs are implemented with tracers, which I would not expect to change the order of function evaluation. I can imagine this resulting in some very frustrating debugging sessions, e.g., if code crashes only during the tangent calculation.
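Instead of comparing jaxprs by eye, one way to check for this kind of reordering is to list the primitive names of the traced JVP in order and compare them with and without the debug call. This is a sketch under stated assumptions. It uses `jax.debug.print` as a stand-in for `id_print`, because `host_callback` is deprecated. The helper `jvp_primitive_order` and the function `f_debug` are hypothetical names. The result only reflects the JAX version it is run on.

```python
import jax


def f1(x):
    return x ** 2


def f_debug(x):
    # Same computation as f1, plus a debug print of the primal value
    # (jax.debug.print used here as a stand-in for id_print).
    y = x ** 2
    jax.debug.print("{}", y)
    return y


def jvp_primitive_order(f):
    """Hypothetical helper: primitive names in the traced JVP of f, in order."""
    closed = jax.make_jaxpr(lambda x, t: jax.jvp(f, (x,), (t,)))(0.0, 0.0)
    return [eqn.primitive.name for eqn in closed.jaxpr.eqns]


plain = jvp_primitive_order(f1)
with_print = jvp_primitive_order(f_debug)
print(plain)
print(with_print)

# If the debug call does not reorder the computation, dropping its equation
# leaves exactly the sequence traced for f1 (primal integer_pow first).
print([name for name in with_print if name != "debug_callback"] == plain)
```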