jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

autodidax `EvalTrace` isn't fully evaluating? #24585

Open mwhittaker opened 4 weeks ago

mwhittaker commented 4 weeks ago

Description

I've been reading through the autodidax documentation, which is awesome! Just after the evaluation interpreter EvalTrace is introduced, there is a code snippet that shows how to use it to evaluate a user function:

https://github.com/jax-ml/jax/blob/eff6cb445b769e2b0aa0bbff0888f4d3c6713b43/docs/autodidax.py#L441-L453
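For reference, the snippet in question looks roughly like this (paraphrased; see the permalink above for the exact code):

def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

print(f(3.0))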

However, as the code is written, EvalTrace is only used to interpret the very first call to the sin function. All other operations in the user function are evaluated as regular Python. If you add a print statement to the EvalTrace.process_primitive method, you can see it is printed only once. I believe this happens because EvalTrace does not use any Tracers. As a result, the return value of sin(x) is a regular Python number, so all subsequent operations performed on it do not involve EvalTrace.
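Here is a minimal sketch of that check, assuming the autodidax notebook definitions (EvalTrace, sin, and f above) are in scope; the _logged_process_primitive wrapper and its name are mine:

# Hypothetical diagnostic: wrap EvalTrace.process_primitive to log each call.
_orig_process_primitive = EvalTrace.process_primitive

def _logged_process_primitive(self, primitive, tracers, params):
  print("process_primitive:", primitive.name)
  return _orig_process_primitive(self, primitive, tracers, params)

EvalTrace.process_primitive = _logged_process_primitive

print(f(3.0))
# "process_primitive: sin" appears exactly once: sin goes through bind, but
# its result is a plain number, so the following *, -, and + are evaluated
# as ordinary Python arithmetic, bypassing EvalTrace.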

If this is the intended behavior, please ignore this issue, but I expected the code snippet to be fully evaluated by EvalTrace.

System info (python version, jaxlib version, accelerator, etc.)

N/A

ASEM000 commented 3 weeks ago

AFAIK, if we implement an EvalTracer, EvalTrace will see the whole program.

As an example, I modified the evaluation interpreter cell in autodidax, adding an EvalTracer implementation that propagates a description of the operations applied to it.

class EvalTrace(Trace):
  # Box literals/constants in an EvalTracer, using their repr as the info string.
  pure = lift = lambda self, x: EvalTracer(self, x, f"{x}")

  def process_primitive(self, primitive, tracers, params):
    values_in, infos_in = unzip2((t.value, t.info) for t in tracers)
    impl_rule = impl_rules[primitive]
    print(f"{impl_rule.__name__}({values_in})")
    values_out, infos_out = impl_rule(values_in, infos_in, **params)
    return [EvalTracer(self, v, i) for v, i in zip(values_out, infos_out)]

class EvalTracer(Tracer):
  def __init__(self, trace, value, info):
    self._trace = trace
    self.value = value  # the concrete value
    self.info = info    # a string recording how the value was computed

  @property
  def aval(self):
    return get_aval(self.value)

trace_stack.append(MainTrace(0, EvalTrace, None))  # special bottom of the stack

# NB: in JAX, instead of a dict we attach impl rules to the Primitive instance
impl_rules = {}

def add_impl(value, info):
  x, y = value
  ix, iy = info
  return [x + y], [f"add({ix}, {iy})"]

def sin_impl(value, info):
  x, = value
  ix, = info
  return [np.sin(x)], [f"sin({ix})"]

def neg_impl(value, info):
  x, = value
  ix, = info
  return [-x], [f"neg({ix})"]

def mul_impl(value, info):
  x, y = value
  ix, iy = info
  return [x * y], [f"mul({ix}, {iy})"]

impl_rules[add_p] = add_impl
impl_rules[neg_p] = neg_impl
impl_rules[sin_p] = sin_impl
impl_rules[mul_p] = mul_impl

def eval_flat(f, values, info):
  with new_main(EvalTrace) as main:
    trace = EvalTrace(main)
    tracers_in = [EvalTracer(trace, x, i) for x, i in zip(values, info)]
    out = f(*tracers_in)
    tracer_out = full_raise(trace, out)
    value_out, info_out = tracer_out.value, tracer_out.info
  return value_out, info_out

# broadcast_impl follows the same (values, infos) convention as the rules above.
def broadcast_impl(value, info, *, shape, axes):
  x, = value
  ix, = info
  for axis in sorted(axes):
    x = np.expand_dims(x, axis)
  return [np.broadcast_to(x, shape)], [f"broadcast({ix})"]
impl_rules[broadcast_p] = broadcast_impl

and changed the subsequent cell to:

def f(x):
  y = sin(x) * 2.
  z = - y + x
  print("im in function body z=",z)
  return z

_, info = eval_flat(f, (3.,), ("x", ))
print(f"{info=}")

Prints

sin_impl([3.0])
mul_impl([0.1411200080598672, 2.0])
neg_impl([0.2822400161197344])
add_impl([-0.2822400161197344, 3.0])
im in function body z= <__main__.EvalTracer object at 0x7aee37f28880>
info='add(neg(mul(sin(x), 2.0)), x)'

LMK if this helps