Status: Open. juliuskunze opened this issue 5 years ago.
I think you're mixing up the "initial" and "final" terms here: "initial style" means operating on jaxprs, while "final style" means a `Tracer` subclass. Initial style is easier to reason about and more general, but final style has the advantage of a potentially better debugging experience for users and the ability to support more dynamic code. It's not clear which of them results in simpler or more complex code in general; it's certainly possible to end up with complicated implementations either way.
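To make the distinction concrete, here is a toy, pure-Python sketch (not JAX's actual machinery) of the same forward-mode derivative written both ways. The final-style version is a tracer-like class whose overloaded operators run while the user's function executes; the initial-style version first represents the program as data (a list of equations, loosely analogous to a jaxpr) and then interprets it. The names `DualTracer`, `jvp_final`, `program` and `jvp_initial` are hypothetical.

```python
# Final style: the transformation lives in a tracer-like class whose overloaded
# operators run while the user's Python function executes.
class DualTracer:
    def __init__(self, primal, tangent):
        self.primal, self.tangent = primal, tangent

    def __add__(self, other):
        other = lift(other)
        return DualTracer(self.primal + other.primal, self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        other = lift(other)
        return DualTracer(self.primal * other.primal,
                          self.tangent * other.primal + self.primal * other.tangent)

    __rmul__ = __mul__


def lift(v):
    # Constants have a zero tangent.
    return v if isinstance(v, DualTracer) else DualTracer(v, 0.0)


def jvp_final(f, x, t):
    out = f(DualTracer(x, t))
    return out.primal, out.tangent


# Initial style: the program is data, and the transformation is an interpreter.
# Equations for f(x) = x * x + 3 * x as (op, output variable, (input, input)).
program = [
    ("mul", "a", ("x", "x")),
    ("mul", "b", (3.0, "x")),
    ("add", "out", ("a", "b")),
]


def jvp_initial(program, x, t):
    primal, tangent = {"x": x}, {"x": t}

    def p(v):  # primal value of a variable or literal
        return primal[v] if isinstance(v, str) else v

    def d(v):  # tangent value of a variable or literal (literals have zero tangent)
        return tangent[v] if isinstance(v, str) else 0.0

    for op, outvar, (a, b) in program:
        if op == "add":
            primal[outvar] = p(a) + p(b)
            tangent[outvar] = d(a) + d(b)
        elif op == "mul":
            primal[outvar] = p(a) * p(b)
            tangent[outvar] = d(a) * p(b) + p(a) * d(b)
    return primal["out"], tangent["out"]


def f(x):
    return x * x + 3 * x


# Both styles compute f(2) = 10 and f'(2) = 2 * 2 + 3 = 7.
print(jvp_final(f, 2.0, 1.0))          # (10.0, 7.0)
print(jvp_initial(program, 2.0, 1.0))  # (10.0, 7.0)
```

The final-style version never sees the program as a whole, which is what lets it run arbitrary Python control flow; the initial-style version sees everything up front, which makes whole-program transformations easier to express.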
My `TagTracer` was written as a stateful trace, like you describe. Unfortunately, attaching the trace state to the `Trace` class is not enough to propagate it in a way that crosses subtrace boundaries. Instead, you have to make sure that any parts of the trace state that could themselves contain tracers (from traces nested inside your trace) are pulled out of the `Trace` class and passed as arguments and return values of your transformed function, so that they can be wrapped and unwrapped by other `linear_util` transformations stacked with yours. You can see my commit making this change here. I also had to worry about maintaining dict object identity across subtrace boundaries so that I could keep the trace-global nature of my state. You can check whether you're susceptible to the same bug by trying `your_trace(jit(f))`.
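As a loose, pure-Python analogy (it does not use JAX, and `Staged` / `toy_jit` are hypothetical names), consider a staging transformation that runs the Python function once on symbolic inputs and afterwards only replays the recorded computation. It sees nothing but what flows through inputs and outputs, so state kept on the side is updated only once, while state threaded through arguments and return values is updated on every call:

```python
class Staged:
    """Symbolic value recorded while staging out a function (toy, hypothetical)."""

    def __init__(self, fn):
        self.fn = fn  # maps the tuple of concrete call arguments to a value

    def __add__(self, other):
        other = lift(other)
        return Staged(lambda args: self.fn(args) + other.fn(args))

    __radd__ = __add__

    def __mul__(self, other):
        other = lift(other)
        return Staged(lambda args: self.fn(args) * other.fn(args))

    __rmul__ = __mul__


def lift(v):
    # Constants become Staged values that ignore the call arguments.
    return v if isinstance(v, Staged) else Staged(lambda args: v)


def toy_jit(f):
    """Run f once on symbolic inputs, then replay the recorded computation."""
    staged_out = None

    def wrapped(*args):
        nonlocal staged_out
        if staged_out is None:
            inputs = [Staged(lambda a, i=i: a[i]) for i in range(len(args))]
            staged_out = f(*inputs)  # Python side effects in f happen only here
        if isinstance(staged_out, tuple):
            return tuple(lift(o).fn(args) for o in staged_out)
        return lift(staged_out).fn(args)

    return wrapped


# State kept on the side: invisible to the staged computation.
log = []

def f_side_channel(x):
    log.append("step")
    return x * 2

g = toy_jit(f_side_channel)
for x in (1, 2, 3):
    g(x)
print(log)  # ['step']: updated once, not three times

# State threaded through arguments and return values: replayed on every call.
def f_threaded(count, x):
    return count + 1, x * 2

h = toy_jit(f_threaded)
count = 0
for x in (1, 2, 3):
    count, y = h(count, x)
print(count, y)  # 3 6
```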
At the end of the day I'm not sure that hanging the state on the `Trace` ended up being all that much easier than squeezing it along dataflow or writing an initial-style interpreter, but I do agree that the outcome is easier to reason about than dataflow squeezing and supports more code than an initial-style transform.
I'd also recommend seeing if the tagging infra can be useful for either jaxnet or fastar—it should cover many cases where you can think about your task as extracting some subgraph of a larger computation and then transforming it; I think I'm reasonably happy with using it for neural net initialization (although with an interface that's somewhat different from jaxnet's). An even better choice than initial or final style for implementing a new transformation is to compose it out of other ones 🙂.
My bad, I have now swapped the terms! Btw, do you know of any references introducing or using these terms? Why are they called that?
Indeed, `parametrized(jit(fun))` is currently an issue; thank you for those pointers!
I can clearly see (and am super excited about!) how the tagging infrastructure will be useful for the first three use cases you mentioned in https://github.com/google/jax/pull/1660 (logging, intermediate gradients, probabilistic programming). However, I am skeptical about writing a combined init-and-apply function with tagged weights, as explored in https://github.com/google/jax/pull/1341: imagine the amount of boilerplate code this would require for random key splitting and distribution in a complex model (e.g. WaveNet or PixelCNN). In my view, being able to automate this aspect in a functional way using JAX's transformation engine is a gift, and it would be a waste not to make use of it! My goal is to build the best possible API for learning parametrized functions. From this perspective, I struggle to see any advantage in explicitly splitting keys (which has obvious downsides), especially because the order/structure of how keys are split shouldn't matter conceptually.
Update: Random key handling can be automated using only tagging transformations, see https://github.com/google/jax/pull/1341#issuecomment-553439580.
"An even better choice than initial or final style for implementing a new transformation is to compose it out of other ones" - totally agree! I have the hunch that the tagging transformations will power a lot of good things in the future, some of which we haven't even thought of yet. 🙂
Another use case for stateful traces would be a tracer for in-place assignment, i.e., supporting `x[i] = y`. In some cases this might result in considerably more readable code than what we get with explicit `index_update` calls.
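A toy sketch of the idea in plain Python, with tuples standing in for immutable arrays. Here `index_update` is a hypothetical helper that only mimics a functional update, and `InPlaceView` / `functionalize` are illustrative names, not a JAX API. A real version would be a tracer, but the bookkeeping would be similar: each assignment rebinds the wrapper's state to a freshly built value instead of mutating anything.

```python
def index_update(x, i, y):
    """Hypothetical functional update: returns a new tuple and leaves x untouched."""
    return x[:i] + (y,) + x[i + 1:]


class InPlaceView:
    """Toy wrapper that accepts `x[i] = y` but records it as a functional update."""

    def __init__(self, value):
        self.value = value

    def __getitem__(self, i):
        return self.value[i]

    def __setitem__(self, i, y):
        # Rebind to a fresh tuple instead of mutating in place.
        self.value = index_update(self.value, i, y)


def functionalize(f):
    """Turn a function that 'mutates' its argument into one returning a new value."""
    def wrapped(x):
        view = InPlaceView(x)
        f(view)
        return view.value
    return wrapped


def body(x):
    x[0] = x[1] + 1
    x[2] = 0


original = (1, 2, 3)
print(functionalize(body)(original))  # (3, 2, 0)
print(original)                       # (1, 2, 3): the input is unchanged
```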
As I understand it, implementing function transformations final-style should be preferred over initial-style whenever possible because the latter (while being more general) usually results in more complex code. I found that the scope of what's possible with final-style transformations can be expanded by attaching mutable state to the master trace, which is then accessible from within the trace class.
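A self-contained toy of hanging mutable state on a master trace and reading it from the trace class, using simplified, hypothetical classes rather than JAX's internals. In this toy a fresh `Trace` object is built for every operation (loosely mirroring how JAX at the time re-instantiated `Trace` objects from their master), so attributes set on a `Trace` instance would be lost, while state attached to the master persists:

```python
class MasterTrace:
    # Hypothetical, simplified stand-in: one instance per active transformation.
    def __init__(self, trace_type):
        self.trace_type = trace_type


class Trace:
    def __init__(self, master):
        self.master = master


def _val(x):
    return x.val if isinstance(x, CountTracer) else x


class CountTracer:
    def __init__(self, master, val):
        self.master, self.val = master, val

    def __mul__(self, other):
        # A fresh Trace per operation: attributes set on it would not survive.
        return self.master.trace_type(self.master).process_mul(self, other)

    __rmul__ = __mul__

    def __add__(self, other):
        return self.master.trace_type(self.master).process_add(self, other)

    __radd__ = __add__


class CountTrace(Trace):
    def process_mul(self, a, b):
        self.master.state["muls"] += 1  # state hung on the master persists
        return CountTracer(self.master, _val(a) * _val(b))

    def process_add(self, a, b):
        return CountTracer(self.master, _val(a) + _val(b))


def count_muls(f, x):
    master = MasterTrace(CountTrace)
    master.state = {"muls": 0}  # ad hoc: attach mutable state to the master
    out = f(CountTracer(master, x))
    return out.val, master.state["muls"]


print(count_muls(lambda x: x * x + 3 * x, 2.0))  # (10.0, 2)
```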
This approach allowed me to massively simplify JAXnet's implementation by:
- converting initialization (`_init_transform`) from initial- to final-style,
- simplifying application (`_apply_transform`) by freeing it from "squeezing" state along data flow trajectories, and
- getting rid of all custom stateful stacks / context managers, meaning that JAXnet handles state solely through (this extended version of) JAX's tracing mechanisms.

Another potential use case would be fastar, where both `firstpass` and `fastpass` could be converted from initial- to final-style, again simplifying code by getting rid of two near-copies of a jaxpr interpreter.

Before I try to rewrite fastar in this way: Is this unintended for some reason? If not, shouldn't JAX provide a streamlined, less hacky way to have a stateful `MasterTrace`, i.e. by building the state into `MasterTrace` itself and adding access to it in the `Trace` base class? To me, it looks like this would encourage writing simpler function transformation code in many cases.
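A minimal sketch of what such a built-in state slot could look like, again with simplified, hypothetical `MasterTrace` / `Trace` classes rather than JAX's real ones (the `state` attribute and the `CollectTrace.record` method are illustrative assumptions):

```python
class MasterTrace:
    """Simplified, hypothetical master trace with a built-in mutable state slot."""

    def __init__(self, level, trace_type):
        self.level = level
        self.trace_type = trace_type
        self.state = {}  # shared by every Trace instantiated from this master


class Trace:
    """Simplified, hypothetical Trace base class."""

    def __init__(self, master, sublevel):
        self.master = master
        self.level = master.level
        self.sublevel = sublevel

    @property
    def state(self):
        # Subclasses read and write self.state instead of reaching into the master.
        return self.master.state


class CollectTrace(Trace):
    def record(self, name, value):
        self.state.setdefault(name, []).append(value)
        return value


master = MasterTrace(level=0, trace_type=CollectTrace)
# Two separate Trace objects built from the same master see the same state.
master.trace_type(master, sublevel=0).record("h", 1.0)
master.trace_type(master, sublevel=1).record("h", 2.0)
print(master.state)  # {'h': [1.0, 2.0]}
```

The design choice here is that the state's lifetime matches the master trace's, so transformation authors don't have to decide where to hang it; whether state that may contain inner tracers also needs to be threaded through inputs and outputs is a separate question.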