I see this is still the execute-and-cache approach. Any consideration of utilizing the Python AST?
@tqchen do you mean intercepting the Python interpretation process?
I guess that would be too much work. It would require coding into the CPython interpreter (and would not work for PyPy users).
Eventually the user interface will be just the `minpy.jit()` annotation, even in our execute-and-cache approach. I wonder whether the user interface could be any better if we used the AST. Anyway, we rely on user annotations to optimize subgraphs (at least for now).
Implementation-wise, we add this indirection at the C API level; I figure it would not impact performance at all.
We can always do something that mixes the two flavours.
I mean a pre-compilation phase, which analyses the AST and transforms some of the fixed content into a symbolic graph, cached per input size. This will reduce the cost of dynamic recording by pre-merging the static parts once.
Specifically, the `jit` decorator can take a Python function that contains imperative fragments and extract them into blocks, i.e., perform a Python-to-Python transformation that keeps Python code for the control flow but pre-compiles basic blocks into graphs. Wherever there is no control flow, the ndarray fragments are changed into symbolic construction; the parts with control flow can be handled with the runtime recording approach. For example,
```python
@jit
def my_func(a, b, c):
    return (a + b) + c
```
will be directly transformed into:
```python
class Fragment(object):
    def _symbol(self):
        a = mx.sym.Variable('a')
        b = mx.sym.Variable('b')
        c = mx.sym.Variable('c')
        return (a + b) + c

    def __call__(self, a, b, c):
        if self.in_cache(a.shape, b.shape, c.shape):
            return self.executor.run(a, b, c)
        else:
            return do_jit(a, b, c)
```
For a more complicated case with control flow:

```python
@jit
def my_func(a, b, c):
    if a < b:
        return (a + b) + c
    else:
        return a - b
```
This will become:

```python
class FragmentA(object):
    def _symbol1(self):
        a = mx.sym.Variable('a')
        b = mx.sym.Variable('b')
        c = mx.sym.Variable('c')
        return (a + b) + c

    def _symbol2(self):
        a = mx.sym.Variable('a')
        b = mx.sym.Variable('b')
        return a - b

    def frag1(self, a, b, c):
        if self.in_cache(a.shape, b.shape, c.shape):
            return self.executor1.run(a, b, c)
        else:
            return do_jit(a, b, c)

    def frag2(self, a, b):
        if self.in_cache(a.shape, b.shape):
            return self.executor2.run(a, b)
        else:
            return do_jit(a, b)

    def __call__(self, a, b, c):
        if a < b:
            return self.frag1(a, b, c)
        else:
            return self.frag2(a, b)
```
Ideally, the fragment cache logic should live in C++.
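As a rough illustration of what the `in_cache`/`do_jit` pair could amount to, here is a minimal Python sketch of a shape-keyed executor cache; `compile_to_executor` is a hypothetical stand-in for the NNVM bind-and-optimize step:

```python
class ShapeKeyedFragment(object):
    """Sketch: cache one compiled executor per tuple of input shapes."""

    def __init__(self, symbol):
        self.symbol = symbol   # pre-built symbolic graph for this fragment
        self.executors = {}    # (shape, shape, ...) -> compiled executor

    def __call__(self, *arrays):
        key = tuple(arr.shape for arr in arrays)
        if key not in self.executors:
            # First encounter of these shapes: bind and optimize once.
            self.executors[key] = compile_to_executor(self.symbol, key)
        return self.executors[key].run(*arrays)
```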
Great suggestion! This can be viewed as a preprocessing pass over the Python code. We can assume the AST analysis will generate symbol segments that are connected by control flow. For now, we can first try to fix the interface: given symbols and ndarrays, how to record/cache the sequence and use NNVM to generate efficient executors. If you agree with this, we can first put the interface on the minpy branch and, once it is fixed, go ahead with the implementation.
We need input from you, especially people from the MXNet team. We have been discussing this for a while already, so the background might not be clear to everyone.
Our goal
Integrate JIT and autograd with the MXNet/MinPy Python interface. Autograd integration is already under way. We are proposing a uniform interface for both JIT and autograd.
For JIT, we cache user operations and evaluate them lazily (only when the user requests a print or `asnumpy`). By doing this, we can optimize the computing sequence and cache it for future use. It functions as a layer between Python/user code and NNVM/engine code.

As an example, the user might have code in a tight loop. The graph structure generated in the loop is the same between iterations. In the first iteration, we optimize this computing sequence so that in future iterations we can use the optimized computing sequence to do the calculation on different data. We need a way to detect and cache graph structure; that is the intention of this proposal.
The boundary of JIT is defined by the context of `minpy.jit` and strict evaluations (print or `asnumpy`). Operations between boundaries are sent as a whole to NNVM for graph optimization. Computing sequences are cached so that each distinct structure is optimized only once. For example, the user might write something like the following.
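A minimal sketch, assuming `minpy.jit` can be used as a context manager (as the boundary description above suggests) and using illustrative variable names:

```python
# Illustrative only: x, y, w, z are ndarrays of matching shape.
with minpy.jit():
    for i in range(100):
        out = (x + y) * w / z  # three element-wise ops: + * /
    print(out)  # strict evaluation: flushes and runs the recorded sequence
```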
In this case, the three element-wise operations could be merged into one. The first time we encounter this code, we send the computing sequence `+ * /` to NNVM for optimization. The second time, we look up our cache and run the optimized computing sequence instead.

There are many more corner cases, including those where JIT interacts with autograd. Please refer to this gist for a proof-of-concept written in Python.
Implementation proposal
We intend to write the code in the C API directly, alongside the NDArray functions and methods.
The header file is here. We need to intercept `MXImperativeInvoke`. Instead of calling the underlying functions directly, we place them in our sequence buffer. By storing the function and its arguments, we ensure the involved arrays remain properly referenced and are not freed prematurely. At a later stage, when a JIT boundary is encountered, we flush the sequence buffer and push it to the engine/NNVM; this is how we achieve lazy evaluation. A similar approach applies to autograd operations: when gradient sequences are calculated, we push them into the JIT queue so they can also be optimized (a sketch of this buffering behavior follows below). A sample (incomplete) implementation is here.

@mli @tqchen @piiswrong @jermainewang @HrWangChengdu @ZihengJiang @lryta @sneakerkg @zzhang-cn
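A rough Python sketch of the intended buffering behavior; the real code would live in the C API around `MXImperativeInvoke`, and `run_optimized` is a hypothetical stand-in for the NNVM/engine entry point:

```python
class JITSequenceBuffer(object):
    """Sketch: record imperative calls and flush them lazily at a JIT boundary."""

    def __init__(self):
        self.pending = []  # recorded (op, inputs, outputs) triples

    def record(self, op, inputs, outputs):
        # Holding references to the arrays keeps them alive until
        # the deferred sequence actually runs.
        self.pending.append((op, inputs, outputs))

    def flush(self):
        # Called at a JIT boundary (print / asnumpy): hand the whole
        # recorded sequence to NNVM/engine for optimization and execution.
        run_optimized(self.pending)
        self.pending = []
```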