Status: Open. mwhittaker opened this issue 4 weeks ago.
AFAIK, if we implement `EvalTracer`, `EvalTrace` will see the program, but …
As an example, I modified the evaluation interpreter cell in autodidax, adding an `EvalTracer` implementation that records the functions applied to it:
```python
# Relies on definitions from earlier autodidax cells: np, Trace, Tracer,
# MainTrace, trace_stack, new_main, full_raise, unzip2, get_aval, and the
# primitives add_p, neg_p, sin_p, mul_p, broadcast_p.

class EvalTrace(Trace):
  # Literals become tracers whose info is their printed value.
  pure = lift = lambda self, x: EvalTracer(self, x, f"{x}")

  def process_primitive(self, primitive, tracers, params):
    values_in, infos_in = unzip2((t.value, t.info) for t in tracers)
    impl_rule = impl_rules[primitive]
    print(f"{impl_rule.__name__}({values_in})")
    values_out, infos_out = impl_rule(values_in, infos_in, **params)
    # Wrap every output so later operations also dispatch to this trace.
    return [EvalTracer(self, x, i) for x, i in zip(values_out, infos_out)]

class EvalTracer(Tracer):
  def __init__(self, trace, value, info):
    self._trace = trace
    self.value = value  # concrete value
    self.info = info    # string describing how the value was computed

  @property
  def aval(self):
    return get_aval(self.value)

trace_stack.append(MainTrace(0, EvalTrace, None))  # special bottom of the stack

# NB: in JAX, instead of a dict we attach impl rules to the Primitive instance
impl_rules = {}

# Every impl rule takes (values, infos) and returns (values_out, infos_out).
def add_impl(values, infos):
  x, y = values
  ix, iy = infos
  return [x + y], [f"add({ix}, {iy})"]

def sin_impl(values, infos):
  x, = values
  ix, = infos
  return [np.sin(x)], [f"sin({ix})"]

def neg_impl(values, infos):
  x, = values
  ix, = infos
  return [-x], [f"neg({ix})"]

def mul_impl(values, infos):
  x, y = values
  ix, iy = infos
  return [x * y], [f"mul({ix}, {iy})"]

# broadcast_impl follows the same (values, infos) convention as the others,
# since process_primitive calls every rule as impl_rule(values, infos, **params).
def broadcast_impl(values, infos, *, shape, axes):
  x, = values
  ix, = infos
  for axis in sorted(axes):
    x = np.expand_dims(x, axis)
  return [np.broadcast_to(x, shape)], [f"broadcast({ix})"]

impl_rules[add_p] = add_impl
impl_rules[neg_p] = neg_impl
impl_rules[sin_p] = sin_impl
impl_rules[mul_p] = mul_impl
impl_rules[broadcast_p] = broadcast_impl

def eval_flat(f, values, infos):
  with new_main(EvalTrace) as main:
    trace = EvalTrace(main)
    tracers_in = [EvalTracer(trace, x, i) for x, i in zip(values, infos)]
    out = f(*tracers_in)
    tracer_out = full_raise(trace, out)
    return tracer_out.value, tracer_out.info
```
and changed the subsequent cell to:
```python
def f(x):
  y = sin(x) * 2.
  z = - y + x
  print("im in function body z=", z)
  return z

_, info = eval_flat(f, (3.,), ("x",))
print(f"{info=}")
```
This prints:
```
sin_impl([3.0])
mul_impl([0.1411200080598672, 2.0])
neg_impl([0.2822400161197344])
add_impl([-0.2822400161197344, 3.0])
im in function body z= <__main__.EvalTracer object at 0x7aee37f28880>
info='add(neg(mul(sin(x), 2.0)), x)'
```
LMK if this helps
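For anyone without the autodidax notebook at hand, here is a minimal, self-contained sketch of the same idea. It is a toy, not autodidax's code: it does not use `Trace`, `bind`, or the trace stack, and the names `Tracer`, `apply`, `lift`, and `IMPLS` are hypothetical. Python operator overloading stands in for autodidax's dispatch. It illustrates what the modified `EvalTracer` relies on: because every intermediate result is wrapped in a tracer, each later operation is routed back through the interpreter, and the info strings compose into an expression.

```python
import numpy as np

class Tracer:
  """Hypothetical toy tracer: a concrete value plus a description of it."""
  def __init__(self, value, info):
    self.value = value
    self.info = info

  # Operator overloads send every arithmetic op back through `apply`,
  # so results stay Tracers instead of decaying to plain numbers.
  def __add__(self, other): return apply("add", self, other)
  def __radd__(self, other): return apply("add", other, self)
  def __mul__(self, other): return apply("mul", self, other)
  def __rmul__(self, other): return apply("mul", other, self)
  def __neg__(self): return apply("neg", self)

IMPLS = {"add": np.add, "mul": np.multiply, "neg": np.negative, "sin": np.sin}

def lift(x):
  # Constants become tracers whose info is their literal repr.
  return x if isinstance(x, Tracer) else Tracer(x, f"{x}")

def apply(name, *args):
  tracers = [lift(a) for a in args]
  value = IMPLS[name](*(t.value for t in tracers))
  info = f"{name}({', '.join(t.info for t in tracers)})"
  return Tracer(value, info)

def sin(x):
  return apply("sin", x)

def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

out = f(Tracer(3.0, "x"))
print(out.info)  # add(neg(mul(sin(x), 2.0)), x)
```

The resulting `info` string matches the one printed by the autodidax version above.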
Description
I've been reading through the autodidax documentation, which is awesome! Just after the evaluation interpreter named `EvalTrace` is introduced, there is a code snippet that shows how to use the evaluation interpreter to evaluate a user function: https://github.com/jax-ml/jax/blob/eff6cb445b769e2b0aa0bbff0888f4d3c6713b43/docs/autodidax.py#L441-L453

However, as the code is written, `EvalTrace` is only used to interpret the very first call to the `sin` function. All other operations in the user function are evaluated as regular Python. If you add a print statement to the `EvalTrace.process_primitive` method, you can see that it is printed only once. I believe this happens because `EvalTrace` does not use any `Tracer`s. As a result, the return value of `sin(x)` is a regular Python number, so none of the subsequent operations performed on it involve `EvalTrace`.

If this is the intended behavior, please ignore this issue, but I expected the code snippet to be fully evaluated by `EvalTrace`.

System info (python version, jaxlib version, accelerator, etc.)
N/A
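The dispatch behavior described in this issue can be reproduced outside autodidax with a toy interpreter. This is a hypothetical sketch, not autodidax's code: `interpret` plays the role of `EvalTrace.process_primitive`, `Boxed` stands in for a `Tracer`, and Python operator overloading stands in for how `bind` finds the trace. When the interpreter returns bare NumPy values, only the first `sin` call passes through it; when it wraps every output, all four primitives do.

```python
import numpy as np

# Hypothetical toy interpreter; none of these names come from autodidax.
IMPLS = {"sin": np.sin, "mul": np.multiply, "neg": np.negative, "add": np.add}
calls = []            # primitives that actually reached the interpreter
WRAP_OUTPUTS = False  # whether `interpret` boxes its results

class Boxed:
  """Stand-in for a Tracer: operators on it dispatch back to `interpret`."""
  def __init__(self, value):
    self.value = value
  def __add__(self, other): return interpret("add", self, other)
  def __radd__(self, other): return interpret("add", other, self)
  def __mul__(self, other): return interpret("mul", self, other)
  def __rmul__(self, other): return interpret("mul", other, self)
  def __neg__(self): return interpret("neg", self)

def unbox(x):
  return x.value if isinstance(x, Boxed) else x

def interpret(name, *args):
  calls.append(name)
  out = IMPLS[name](*(unbox(a) for a in args))
  # A bare result means later Python operators on it are plain NumPy
  # arithmetic and never come back through `interpret`.
  return Boxed(out) if WRAP_OUTPUTS else out

def sin(x):
  return interpret("sin", x)

def f(x):
  y = sin(x) * 2.
  z = - y + x
  return z

def run(fn, x, wrap_outputs):
  """Evaluate fn(x); return (result, primitives seen by `interpret`)."""
  global WRAP_OUTPUTS
  WRAP_OUTPUTS = wrap_outputs
  calls.clear()
  out = unbox(fn(x))
  return out, list(calls)

value_plain, calls_plain = run(f, 3.0, wrap_outputs=False)
value_boxed, calls_boxed = run(f, 3.0, wrap_outputs=True)
print(calls_plain)  # ['sin']
print(calls_boxed)  # ['sin', 'mul', 'neg', 'add']
```

With `wrap_outputs=False` this mirrors the snippet the issue links to; the `EvalTracer` approach in the comment corresponds to `wrap_outputs=True`. Both runs compute the same number; only the path through the interpreter differs.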