google / jax


Adding host_callbacks.id_tap reorders JVP evaluation #3198

Open shoyer opened 4 years ago

shoyer commented 4 years ago

(Forked from https://github.com/google/jax/issues/3127)

Consider the following example:

import jax
from jax.experimental import host_callback

def f1(x):
  y = x ** 2
  return y

def f2(x):
  y = x ** 2
  y = host_callback.id_print(y)
  return y

print('jvp without id_print:')
print(jax.make_jaxpr(lambda x, y: jax.jvp((f1), (x,), (y,)))(0.0, 0.0))

print('\njvp with id_print:')
print(jax.make_jaxpr(lambda x, y: jax.jvp((f2), (x,), (y,)))(0.0, 0.0))

The function f2 is exactly the same as f1, except for the added id_print call. Naively, I would expect these functions to be evaluated in exactly the same order, except with some extra calls to id_tap injected. But as we can see from the jaxprs, that isn't what happens:

jvp without id_print:
{ lambda  ; a b.
  let c = integer_pow[ y=2 ] a
      d = mul 2.0 a
      e = mul b d
  in (c, e) }

jvp with id_print:
{ lambda  ; a b.
  let c = mul 2.0 a
      d = mul b c
      e = integer_pow[ y=2 ] a
      f = id_tap[ arg_treedef=*
                  func=<function _print_consumer at 0x7f3ee7d7f620>
                  nr_untapped=0
                  output_stream=None
                  threshold=None ] e
      g h = id_tap[ arg_treedef=*
                    func=<function _print_consumer at 0x7f3ee7d7f620>
                    nr_untapped=1
                    output_stream=None
                    threshold=None
                    transforms=(('jvp',),) ] d f
  in (f, g) }

Without id_print, primals are evaluated before tangents. But with id_print, tangents are evaluated first!

This is a perfectly valid way to calculate JVPs, of course, but it's a little worrisome for a debugging utility to change how the computation happens. It's all the more worrisome because JVPs are implemented with tracers, which I would not expect to change the order of function evaluation. I can imagine this resulting in some very frustrating debugging sessions, e.g., if code crashes only during the tangent calculation.
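One way to compare evaluation order without reading the jaxprs by eye is to list the primitive names in the order they appear. The following is only a sketch: `primitive_order` and `jvp_f1` are hypothetical helpers that are not part of JAX, and the code relies on the `eqns` and `primitive.name` attributes of the traced jaxpr, whose exact contents can differ between JAX versions.

```python
import jax

def primitive_order(fn, *example_args):
    """Return primitive names in the order they appear in fn's jaxpr.

    Hypothetical helper for illustration; not a JAX API.
    """
    closed = jax.make_jaxpr(fn)(*example_args)
    return [eqn.primitive.name for eqn in closed.jaxpr.eqns]

def f1(x):
    return x ** 2

def jvp_f1(x, t):
    # Same JVP call as in the report, wrapped so it can be traced.
    return jax.jvp(f1, (x,), (t,))

names = primitive_order(jvp_f1, 0.0, 0.0)
print(names)
```

Running the same helper on a variant of the function that includes a debugging callback and comparing the two lists (ignoring the callback primitive itself) would show whether the primal and tangent operations kept their relative order.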

rajasekharporeddy commented 5 months ago

Hi @shoyer

Looks like this issue has been resolved in later versions of JAX. I ran the code above on Colab with JAX version 0.4.23. Both f1 and f2 are now evaluated in exactly the same order, with the extra id_print call injected for f2.

import jax
from jax.experimental import host_callback

def f1(x):
  y = x ** 2
  return y

def f2(x):
  y = x ** 2
  y = host_callback.id_print(y)
  return y

print('jvp without id_print:')
print(jax.make_jaxpr(lambda x, y: jax.jvp((f1), (x,), (y,)))(0.0, 0.0))

print('\njvp with id_print:')
print(jax.make_jaxpr(lambda x, y: jax.jvp((f2), (x,), (y,)))(0.0, 0.0))

Output:

jvp without id_print:
{ lambda ; a:f32[] b:f32[]. let
    c:f32[] = integer_pow[y=2] a
    d:f32[] = integer_pow[y=1] a
    e:f32[] = mul 2.0 d
    f:f32[] = mul b e
  in (c, f) }

jvp with id_print:
{ lambda ; a:f32[] b:f32[]. let
    c:f32[] = integer_pow[y=2] a
    d:f32[] = integer_pow[y=1] a
    e:f32[] = mul 2.0 d
    f:f32[] = mul b e
    g:f32[] = outside_call[
      arg_treedef=PyTreeDef(*)
      callback=<jax.experimental.host_callback._CallbackWrapper object at 0x7d4237b05a80>
      device_index=0
      identity=True
    ] c
  in (g, f) }

Since jax.experimental.host_callback is deprecated (#20385), I also tested with jax.debug.print, and f1 and f2 are again evaluated in the same order. Because jax.debug.print returns None and f2 assigns that result to y, f2 also returns None here.

import jax

def f1(x):
  y = x ** 2
  return y

def f2(x):
  y = x ** 2
  y = jax.debug.print("{}", y)
  return y

print('jvp without id_print:')
print(jax.make_jaxpr(lambda x, y: jax.jvp((f1), (x,), (y,)))(0.0, 0.0))

print('\njvp with id_print:')
print(jax.make_jaxpr(lambda x, y: jax.jvp((f2), (x,), (y,)))(0.0, 0.0))

Output:

jvp without id_print:
{ lambda ; a:f32[] b:f32[]. let
    c:f32[] = integer_pow[y=2] a
    d:f32[] = integer_pow[y=1] a
    e:f32[] = mul 2.0 d
    f:f32[] = mul b e
  in (c, f) }

jvp with id_print:
{ lambda ; a:f32[] b:f32[]. let
    c:f32[] = integer_pow[y=2] a
    d:f32[] = integer_pow[y=1] a
    e:f32[] = mul 2.0 d
    _:f32[] = mul b e
    debug_callback[
      callback=<function debug_callback.<locals>._flat_callback at 0x7d4235c7f5b0>
      effect=Debug
    ] c
  in () }
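Because jax.debug.print returns None, a variant of f2 that keeps its return value could call it purely for the side effect. This is a minimal sketch rather than the code from the report: the input values 3.0 and 1.0 and the "y = {}" format string are chosen only for illustration.

```python
import jax

def f1(x):
    return x ** 2

def f2(x):
    y = x ** 2
    # Call for the side effect only; do not reassign y to the None result.
    jax.debug.print("y = {}", y)
    return y

# Evaluate the JVP of both functions at x = 3.0 with tangent 1.0.
primal1, tangent1 = jax.jvp(f1, (3.0,), (1.0,))
primal2, tangent2 = jax.jvp(f2, (3.0,), (1.0,))
```

With this form, f2 returns the same primal and tangent values as f1, so the two JVPs can be compared directly.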

Please find the gist for reference.

Thank you.