apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
https://mxnet.apache.org
Apache License 2.0

MinPy next step prototype #5566

Closed: hotpxl closed this issue 7 years ago

hotpxl commented 7 years ago

Need input from you guys! Especially people from the MXNet team. We have been discussing this for a while already, so the background might not be clear to you.

Our goal

Integrate JIT and autograd with the MXNet/MinPy Python interface. Autograd integration is already under way. We are proposing a unified interface for both JIT and autograd.

For JIT, we cache user operations and evaluate them lazily (only when the user requests a print or asnumpy). This lets us optimize the computing sequence and cache the optimized version for future use. The JIT acts as a layer between Python/user code and NNVM/engine code.
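A toy model of the lazy-evaluation layer, as a sketch only: LazyArray, SequenceBuffer and BUFFER are hypothetical names, plain Python lists stand in for NDArrays, and asnumpy here just returns a list. Arithmetic is recorded instead of executed, and the whole buffer is drained at the strict-evaluation point.

```python
import operator


class SequenceBuffer:
    """Hypothetical buffer of recorded, not-yet-executed operations."""

    def __init__(self):
        self.pending = []  # (fn, lhs, scalar, out) in program order

    def flush(self):
        # Execute everything recorded so far, in order, then clear the buffer.
        for fn, lhs, scalar, out in self.pending:
            out._data = [fn(x, scalar) for x in lhs._data]
        self.pending.clear()


BUFFER = SequenceBuffer()


class LazyArray:
    """Toy array whose arithmetic is recorded instead of executed."""

    def __init__(self, data=None):
        self._data = data  # None until a flush computes it

    def _record(self, fn, scalar):
        out = LazyArray()
        BUFFER.pending.append((fn, self, scalar, out))
        return out

    def __add__(self, scalar):
        return self._record(operator.add, scalar)

    def __mul__(self, scalar):
        return self._record(operator.mul, scalar)

    def __truediv__(self, scalar):
        return self._record(operator.truediv, scalar)

    def asnumpy(self):
        # Strict evaluation point: drain the buffer, then return the values.
        BUFFER.flush()
        return list(self._data)


a = LazyArray([1.0, 2.0, 3.0])
b = (a + 3) * 4             # recorded, not computed
print(len(BUFFER.pending))  # 2
print(b.asnumpy())          # [16.0, 20.0, 24.0]
print(len(BUFFER.pending))  # 0
```

Because a flush runs the ops in recording order, each op's input has already been computed by the time it executes.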

As an example, a user might have code in a tight loop where the graph structure generated in each iteration is the same. In the first iteration, we optimize this computing sequence so that later iterations can reuse the optimized sequence on different data. This requires a way to detect and cache graph structure, which is the intention of this proposal.

for _ in range(iterations):
    with minpy.jit():
        ...  # code in loop

The JIT boundary is defined by the minpy.jit context and by strict evaluations (print or asnumpy). Operations between boundaries are sent as a whole to NNVM for graph optimization. Computing sequences are cached, so each distinct structure is optimized only once.
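The "optimize each distinct structure once" rule can be sketched as a dictionary keyed by structure (op names plus input shapes) rather than by data. StructureCache is a hypothetical name, and the optimize callback is a stand-in for the NNVM pass, not a real MXNet API.

```python
class StructureCache:
    """Hypothetical cache of optimized computing sequences, keyed by structure."""

    def __init__(self, optimize):
        self._optimize = optimize  # stand-in for the NNVM optimization pass
        self._cache = {}
        self.misses = 0

    @staticmethod
    def signature(ops, shapes):
        # Structure only: op names and input shapes, never the array values.
        return (tuple(ops), tuple(shapes))

    def lookup(self, ops, shapes):
        key = self.signature(ops, shapes)
        if key not in self._cache:
            self.misses += 1
            self._cache[key] = self._optimize(ops)
        return self._cache[key]


cache = StructureCache(optimize=lambda ops: "optimized(" + " ".join(ops) + ")")
for step in range(10):  # e.g. ten iterations of the same loop body
    plan = cache.lookup(["+", "*", "/"], [(3,)])
print(cache.misses)  # 1
print(plan)          # optimized(+ * /)
```

A new input shape or a different op sequence produces a new key, so it triggers exactly one more optimization.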

For example, a user might write

with minpy.jit():
    a = a + 3
    a = a * 4
    a = a / 100

In this case, three element-wise operations could be merged into one. The first time we encounter this code, we send the computing sequence + * / to NNVM for optimization. The second time, we look up our cache and run the optimized computing sequence instead.
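To illustrate what "merged into one" buys, here is a sketch that builds a single fused kernel from the recorded (op, scalar) sequence and compares it with running the ops one by one. fuse and run_unfused are hypothetical helpers on Python lists, not what NNVM actually generates; the point is that the fused version makes one pass and allocates no intermediate arrays.

```python
import operator

ELEMWISE = {"+": operator.add, "*": operator.mul, "/": operator.truediv}


def run_unfused(seq, data):
    # One full pass (and one temporary list) per recorded op.
    for op, scalar in seq:
        data = [ELEMWISE[op](x, scalar) for x in data]
    return data


def fuse(seq):
    # Merge the recorded (op, scalar) steps into a single kernel.
    steps = [(ELEMWISE[op], scalar) for op, scalar in seq]

    def kernel(data):
        out = []
        for x in data:             # a single pass over the input
            for fn, scalar in steps:
                x = fn(x, scalar)  # intermediates stay in a local variable
            out.append(x)
        return out

    return kernel


seq = [("+", 3), ("*", 4), ("/", 100)]  # a = a + 3; a = a * 4; a = a / 100
a = [1.0, 2.0, 7.0]
print(fuse(seq)(a))         # [0.16, 0.2, 0.4]
print(run_unfused(seq, a))  # [0.16, 0.2, 0.4]
```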

There are many more corner cases, including those where JIT interacts with autograd. Please refer to this gist for a proof of concept written in Python.

Implementation proposal

We intend to write the code directly in the C API, alongside the NDArray functions and methods.

The header file is here. We need to intercept MXImperativeInvoke. Instead of calling the underlying functions directly, we place them in our sequence buffer. Because the buffer stores both the function and its arguments, the involved arrays stay referenced and are not freed prematurely. Later, when a JIT boundary is encountered, we flush the sequence buffer and push it to the engine/NNVM. This is how we achieve lazy evaluation. A similar approach applies to autograd operations: when gradient sequences are calculated, we push them into the JIT queue so they can be optimized as well. A sample (incomplete) implementation is here.
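The interception can be modeled in Python as a sketch; the real code would live in C++ behind MXImperativeInvoke. JitQueue, Array and engine_invoke are hypothetical stand-ins. The weakref probe shows the key property: once the user drops a handle, the queued op still keeps the array alive until the flush.

```python
import weakref


class Array:
    """Minimal stand-in for an NDArray handle (hypothetical)."""

    def __init__(self, data):
        self.data = data


class JitQueue:
    """Hypothetical model of the sequence buffer behind MXImperativeInvoke."""

    def __init__(self, engine_invoke):
        self._engine_invoke = engine_invoke  # stand-in for pushing to engine/NNVM
        self._buffer = []                    # (op, inputs, outputs) in program order
        self.enabled = False                 # True inside a JIT region

    def invoke(self, op, inputs, outputs):
        if self.enabled:
            # Storing the arrays in the buffer holds references to them,
            # so they cannot be freed before the flush runs.
            self._buffer.append((op, tuple(inputs), tuple(outputs)))
        else:
            self._engine_invoke(op, inputs, outputs)

    def push_gradient_ops(self, grad_ops):
        # Ops produced by autograd join the same queue and get optimized too.
        self._buffer.extend(grad_ops)

    def flush(self):
        # Called at a JIT boundary (end of the jit context, print, asnumpy).
        for op, inputs, outputs in self._buffer:
            self._engine_invoke(op, inputs, outputs)
        self._buffer.clear()


def engine_invoke(op, inputs, outputs):
    # Toy "engine": element-wise binary ops on Python lists.
    fn = {"add": lambda x, y: x + y, "mul": lambda x, y: x * y}[op]
    outputs[0].data = [fn(x, y) for x, y in zip(inputs[0].data, inputs[1].data)]


q = JitQueue(engine_invoke)
q.enabled = True
a, b, c = Array([1, 2]), Array([3, 4]), Array(None)
probe = weakref.ref(a)
q.invoke("add", [a, b], [c])
del a                       # user code drops its handle ...
print(probe() is not None)  # True: the queued op still references the array
q.flush()
print(c.data)               # [4, 6]
```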

@mli @tqchen @piiswrong @jermainewang @HrWangChengdu @ZihengJiang @lryta @sneakerkg @zzhang-cn

tqchen commented 7 years ago

I see this is still an execute-and-cache approach. Have you considered using the Python AST?

hotpxl commented 7 years ago

@tqchen do you mean intercepting the Python interpreter?

I guess that would be too much work. It would require modifying the CPython interpreter (and would not work for PyPy users).

Eventually the user interface will be just the minpy.jit() annotation, even in our execute-and-cache approach. I'm not sure an AST-based approach would give a better user interface. Either way, we rely on user annotations to mark the subgraphs to optimize (at least for now).

Implementation-wise, we add this indirection at the C API level. I don't expect it to affect performance.

tqchen commented 7 years ago

We can always do something that mixes the two flavours.

I mean a pre-compilation phase that analyses the AST and transforms the fixed parts of the code into a symbolic graph, cached per fixed input size. This reduces the cost of dynamic recording, because the static parts are pre-merged once.

Specifically, the jit decorator can take a Python function that contains imperative fragments and split it into blocks, i.e., do a Python-to-Python transformation that keeps the Python code for control flow but pre-compiles basic blocks into graphs. This turns NDArray fragments that contain no control flow into symbolic construction. The parts with control flow can be handled with the runtime recording approach.
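The analysis step of such a pre-compilation phase can be sketched with Python's standard ast module: parse the function and split its top-level body into runs of straight-line statements (candidates for a symbolic graph) and statements with nested bodies that stay as Python. split_basic_blocks is a hypothetical helper; a real pass would also recurse into branches and handle break/continue.

```python
import ast

# Statements with nested bodies are kept as Python (assumption; a real pass
# would also need to look inside them and handle break/continue, etc.).
CONTROL_FLOW = (ast.If, ast.For, ast.While, ast.With, ast.Try)


def split_basic_blocks(source):
    """Split the top-level body of the first function in `source` into
    ("static", stmts) runs without control flow and ("control", [stmt])
    entries that must remain Python code."""
    func = ast.parse(source).body[0]
    segments, current = [], []
    for stmt in func.body:
        if isinstance(stmt, CONTROL_FLOW):
            if current:
                segments.append(("static", current))
                current = []
            segments.append(("control", [stmt]))
        else:
            current.append(stmt)
    if current:
        segments.append(("static", current))
    return segments


src = """def my_func(a, b, c):
    d = a + b
    e = d * c
    if a < b:
        e = e + 1
    return e - a
"""
summary = [(kind, [s.lineno for s in stmts]) for kind, stmts in split_basic_blocks(src)]
print(summary)  # [('static', [2, 3]), ('control', [4]), ('static', [6])]
```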

@jit
def my_func(a, b, c):
    return (a+b) + c

will be transformed directly into

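import mxnet as mx

class Fragment(object):
    # Generated from my_func; in_cache, executor and do_jit are placeholders.
    def _symbol(self):
        a = mx.sym.Variable('a')
        b = mx.sym.Variable('b')
        c = mx.sym.Variable('c')
        return (a + b) + c

    def __call__(self, a, b, c):
        if self.in_cache(a.shape, b.shape, c.shape):
            return self.executor.run(a, b, c)
        else:
            return do_jit(a, b, c)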
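A runnable sketch of the "cache for fixed input size" behaviour, assuming 1-D Python lists in place of NDArrays (so a shape is just a length). ShapeKeyedFragment and build_add3 are hypothetical; build_add3 plays the role of binding the (a + b) + c symbol for one set of shapes.

```python
class ShapeKeyedFragment:
    """Hypothetical stand-in for a generated Fragment class."""

    def __init__(self, build_executor):
        self._build = build_executor  # called once per new input-shape tuple
        self._executors = {}
        self.builds = 0

    def __call__(self, *arrays):
        key = tuple(len(x) for x in arrays)  # 1-D lists: shape == (len,)
        executor = self._executors.get(key)
        if executor is None:                 # cache miss: the "do_jit" path
            self.builds += 1
            executor = self._build(key)
            self._executors[key] = executor
        return executor(*arrays)


def build_add3(shapes):
    # Would bind the symbol (a + b) + c for these shapes; here a plain loop.
    n = shapes[0]

    def run(a, b, c):
        return [(a[i] + b[i]) + c[i] for i in range(n)]

    return run


my_func = ShapeKeyedFragment(build_add3)
print(my_func([1, 2], [3, 4], [5, 6]))  # [9, 12]
print(my_func([0, 0], [1, 1], [2, 2]))  # [3, 3]  (same shapes, cache hit)
print(my_func([1], [1], [1]))           # [3]     (new shapes, rebuilt)
print(my_func.builds)                   # 2
```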

For a more complicated case with control flow:

@jit
def my_func(a, b, c):
    if a < b:
        return (a+b) + c
    else:
        return a-b

This will become

class FragmentA(object):
    # in_cache, executor1/executor2 and do_jit are placeholders.
    def _symbol1(self):
        a = mx.sym.Variable('a')
        b = mx.sym.Variable('b')
        c = mx.sym.Variable('c')
        return (a + b) + c

    def _symbol2(self):
        a = mx.sym.Variable('a')
        b = mx.sym.Variable('b')
        return a - b

    def frag1(self, a, b, c):
        if self.in_cache('frag1', a.shape, b.shape, c.shape):
            return self.executor1.run(a, b, c)
        else:
            return do_jit(a, b, c)

    def frag2(self, a, b):
        # Only a and b are inputs to this branch.
        if self.in_cache('frag2', a.shape, b.shape):
            return self.executor2.run(a, b)
        else:
            return do_jit(a, b)

    def __call__(self, a, b, c):
        if a < b:
            return self.frag1(a, b, c)
        else:
            return self.frag2(a, b)
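A runnable sketch of the same split, where the control flow stays in Python and each branch body is cached separately under a (fragment name, input shapes) key. BranchFragments and the two kernels are hypothetical stand-ins for the bound executors, lists stand in for NDArrays, and, as an assumption, the condition is reduced to a scalar (sum(a) < sum(b)) because an element-wise a < b has no single truth value.

```python
def _kernel_add3(a, b, c):
    # Stand-in for the executor bound from the symbol (a + b) + c.
    return [(x + y) + z for x, y, z in zip(a, b, c)]


def _kernel_sub(a, b):
    # Stand-in for the executor bound from the symbol a - b.
    return [x - y for x, y in zip(a, b)]


class BranchFragments:
    def __init__(self):
        self._executors = {}  # (fragment name, *input lengths) -> executor
        self.builds = []      # keys that missed the cache, in order

    def _lookup(self, name, arrays, kernel):
        key = (name,) + tuple(len(x) for x in arrays)
        if key not in self._executors:
            self.builds.append(key)  # "do_jit" would compile here
            self._executors[key] = kernel
        return self._executors[key](*arrays)

    def frag1(self, a, b, c):
        return self._lookup("frag1", (a, b, c), _kernel_add3)

    def frag2(self, a, b):
        return self._lookup("frag2", (a, b), _kernel_sub)

    def __call__(self, a, b, c):
        # Control flow stays in Python; only the branch bodies are cached.
        if sum(a) < sum(b):  # assumption: scalar condition
            return self.frag1(a, b, c)
        return self.frag2(a, b)


fn = BranchFragments()
print(fn([1, 1], [2, 2], [3, 3]))  # [6, 6]  (frag1)
print(fn([5, 5], [1, 1], [0, 0]))  # [4, 4]  (frag2)
print(fn([0, 1], [1, 1], [1, 0]))  # [2, 2]  (frag1, cache hit)
print(fn.builds)                   # [('frag1', 2, 2, 2), ('frag2', 2, 2)]
```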

Ideally, the fragment cache logic should live in C++.

jermainewang commented 7 years ago

Great suggestion! This can be viewed as preprocessing of Python code. We can assume the AST analysis will generate symbol segments connected by control flow. For now, we can first try to fix the interface: given symbols and NDArrays, how to record and cache sequences and use NNVM to generate efficient executors. If you guys agree, we can first put the interface on the minpy branch, and once it is settled, we can go ahead with the implementation.