jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

Support for stateful traces #1663

Open · juliuskunze opened this issue 4 years ago

juliuskunze commented 4 years ago

As I understand it, implementing function transformations in final style should be preferred over initial style whenever possible, because the latter (while more general) usually results in more complex code. I found that the scope of what's possible with final-style transformations can be expanded by attaching mutable state to the master trace:

with new_master(InitTrace) as master:
    master.state = InitTraceState(rng)

which is then accessible from within the trace class:

rng = self.master.state.next_rng()
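
For concreteness, here is a minimal sketch of the state container (InitTraceState and next_rng are my JAXnet names, not JAX API; the Trace subclass itself is elided):

from jax import random

class InitTraceState:
    # Mutable state attached to the master trace.
    def __init__(self, rng):
        self.rng = rng

    def next_rng(self):
        # Hand out a fresh subkey on every call, keeping the other half.
        self.rng, subkey = random.split(self.rng)
        return subkey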

This approach allowed me to massively simplify JAXnet's implementation: I got rid of all custom stateful stacks / context managers, meaning that JAXnet now handles state solely through (this extended version of) JAX's tracing machinery. Another potential use case would be fastar, where both firstpass and fastpass could be converted from initial to final style, again simplifying the code by getting rid of two near-copies of a jaxpr interpreter.

Before I try to rewrite fastar in this way: Is this pattern unintended for some reason? If not, shouldn't JAX provide a streamlined, less hacky way to have a stateful MasterTrace? That is, something like

with new_master(InitTrace, state=InitTraceState(rng)) as master:

and adding

@property
def state(self):
    return self.master.state

into the Trace base class. To me, it looks like this would encourage writing simpler function transformation code in many cases.
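
A rough sketch of the change, assuming new_master keeps roughly its current shape in jax/core.py (I'm reproducing the stack bookkeeping from memory, so details may be off):

from contextlib import contextmanager

@contextmanager
def new_master(trace_type, state=None, bottom=False):
    level = trace_stack.next_level(bottom)
    master = MasterTrace(level, trace_type)
    master.state = state  # the only new line: attach user-supplied state
    trace_stack.push(master, bottom)
    try:
        yield master
    finally:
        trace_stack.pop(bottom)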

jekbradbury commented 4 years ago

I think you're mixing up the "initial" and "final" terms here: "initial style" means operating on jaxprs, while "final style" means a Tracer subclass. Initial style is easier to reason about and more general, but final style has the advantage of a potentially better debugging experience for users and the ability to support more dynamic code. It's not clear which of them results in simpler or more complex code in general; it's certainly possible to end up with complicated implementations either way.

My TagTracer was originally written as a stateful trace, like you describe. Unfortunately, attaching state to the Trace class is not enough to propagate it correctly across subtrace boundaries. Instead, you have to make sure that any parts of the trace state that could themselves contain tracers (from traces nested inside yours) are pulled out of the Trace class and passed as arguments and return values of your transformed function, so that they can be wrapped and unwrapped by the other linear_util transformations stacked with yours. You can see my commit making this change here; I also had to worry about maintaining dict object identity across subtrace boundaries so that I could keep the trace-global nature of my state. You can check whether you're susceptible to the same bug by trying your_trace(jit(f)).
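
As a smoke test (your_trace and record_tag below are stand-ins for your transform and for a primitive that mutates trace state, not real APIs):

from jax import jit

def f(x):
    # record_tag: hypothetical primitive that stashes its argument in the
    # trace's mutable state as a side effect, then returns it unchanged.
    return x * record_tag(x)

your_trace(f)(3.)       # fine: the state only ever holds this trace's tracers
your_trace(jit(f))(3.)  # can leak jit's tracers into the outer trace's state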

At the end of the day I'm not sure that hanging the state on the Trace ended up being all that much easier than squeezing it along dataflow or writing an initial-style interpreter, but I do agree that the outcome is easier to reason about than dataflow squeezing and supports more code than an initial-style transform.

I'd also recommend seeing if the tagging infra can be useful for either JAXnet or fastar: it should cover many cases where you can think of your task as extracting some subgraph of a larger computation and then transforming it. I think I'm reasonably happy with using it for neural net initialization (although with an interface that's somewhat different from JAXnet's). An even better choice than initial or final style for implementing a new transformation is to compose it out of other ones 🙂.
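
To give a flavor of what I mean (illustrative names only, not the exact interface from #1660):

def loss(params, x):
    h = x @ params['w']
    h = tag(h, name='hidden')  # mark an intermediate value of interest
    return (h ** 2).sum()

# collect: hypothetical transform that runs the function and extracts
# everything tagged along the way.
value, tagged = collect(loss)(params, x)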

juliuskunze commented 4 years ago

My bad, I've now swapped the terms! By the way, do you know of any references introducing or using these terms? Why are they named that way?

Indeed, parametrized(jit(fun)) is currently an issue; thank you for those pointers!

I can clearly see (and am super excited about!) how the tagging infrastructure will be useful for the first three use cases you mentioned in https://github.com/google/jax/pull/1660 (logging, intermediate gradients, probabilistic programming). However, I am skeptical about writing a combined init-and-apply function with tagged weights, as explored in https://github.com/google/jax/pull/1341: imagine the amount of boilerplate code this would require for random key splitting and distribution in a complex model (e.g. WaveNet or PixelCNN). In my view, being able to automate this aspect in a functional way using JAX's transformation engine is a gift, and it would be a waste not to make use of it! My goal is to build the best possible API for learning parametrized functions. From this perspective, I struggle to see any advantage of explicitly splitting keys (which has obvious downsides), especially because the order/structure of how keys are split shouldn't matter conceptually.
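
To make the boilerplate concrete, explicit key plumbing looks like this even for a tiny model (init_conv and init_dense are made-up per-layer initializers):

from jax import random

def init(rng):
    # Every level of the model hierarchy has to split and thread keys by hand:
    rng_conv1, rng_conv2, rng_dense = random.split(rng, 3)
    return dict(conv1=init_conv(rng_conv1),
                conv2=init_conv(rng_conv2),
                dense=init_dense(rng_dense))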

Update: Random key handling can be automated using tagging transformations alone; see https://github.com/google/jax/pull/1341#issuecomment-553439580.

"An even better choice than initial or final style for implementing a new transformation is to compose it out of other ones" - totally agree! I have the hunch that the tagging transformations will power a lot of good things in the future, some of which we haven't even thought of yet. 🙂

shoyer commented 4 years ago

Another use case for stateful traces would be a tracer for in-place assignment, i.e., supporting x[i] = y. In some cases this might result in considerably more readable code than what we get with explicit index_update calls.
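
For example, with the current jax.ops API:

import jax.ops

# Today: a functional update that returns a new array.
x = jax.ops.index_update(x, jax.ops.index[i], y)

# With a stateful trace, one could hypothetically write the familiar
x[i] = y
# and have it trace to the same functional update under the hood.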