apache / tvm

Open deep learning compiler stack for cpu, gpu and specialized accelerators
https://tvm.apache.org/
Apache License 2.0

[RFC] NNVMv2 IR - Relay #1673

Closed jroesch closed 5 years ago

jroesch commented 6 years ago

[RFC]: Relay, a new high-level IR for TVM

Relay is a new high-level intermediate representation (IR) intended to act as v2.0 of NNVM.

Motivation

Computation graphs are a powerful program representation as demonstrated by the first generation of DL frameworks. Most popular frameworks have employed computation graphs as their input, intermediate representation, and execution data structure.

However, as workloads continue to evolve, the design of our high-level IRs needs to evolve to better support the needs of developers and users.

Graph-level challenges such as control flow and sub-graphs have become necessary features to natively support and optimize.

The tight coupling between runtime representation and compile-time representation has limited flexibility and frustrated developers; Relay will decouple the representations.

Finally we believe the high level must be designed in tandem with the low level IR, allowing for the two layers to communicate during compilation to achieve optimal performance.

Design

The first version of NNVM set out to solve some of these challenges, and we view Relay as a second-generation IR designed specifically for integration into the TVM stack as the input layer. Our goal is to focus on TVM as our primary backend, easing development and maintenance for both TVM developers and current NNVM users, as well as enabling new features.

In order to address the challenges presented above we designed Relay to build on the things computation graphs are good at (pure, dataflow, compositional), and improve on the things they struggle with (control flow, subgraph, runtime/compilation distinction).

Core IR

Relay is a typed pure functional IR, with a few basic features such as functions, if-then-else control flow, recursion, operator and function calls, and variable binding.

We have iterated on Relay's design over the past 8 months. This version represents the culmination of our experiments. This PR does not contain all the pieces of the previous version; instead, we focus on introducing the core IR, its associated data structures, and a few integral passes.

The core IR is defined in just a few files:

Typing

All Relay programs are typed, similar to more conventional languages such as C++. A type system allows us to statically (i.e., at compile time) distinguish between different sorts of values. This means we know whether an expression will evaluate to a tensor, a function (e.g., (float32, float32) -> float32), or a tuple (e.g., (float32, int32)). Furthermore, our type system has the ability to be shape generic (i.e., polymorphism, templating).

Type inference and checking take the place of shape inference in traditional computation-graph-style IRs.
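To make the "types replace shape inference" point concrete, here is a minimal sketch in plain Python. It is not Relay's implementation: `TensorType` and `matmul_type` are hypothetical names chosen for illustration, and the rule only covers a 2-D matrix multiply.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class TensorType:
    """Hypothetical tensor type: a static shape plus an element dtype."""
    shape: Tuple[int, ...]
    dtype: str


def matmul_type(lhs: TensorType, rhs: TensorType) -> TensorType:
    """Typing rule for a 2-D matrix multiply: (n, k) x (k, m) -> (n, m).

    Checking the rule and computing the output shape happen in one step,
    which is the sense in which type checking subsumes shape inference.
    """
    if len(lhs.shape) != 2 or len(rhs.shape) != 2:
        raise TypeError("matmul expects rank-2 tensors")
    if lhs.shape[1] != rhs.shape[0]:
        raise TypeError(f"inner dimensions differ: {lhs.shape} vs {rhs.shape}")
    if lhs.dtype != rhs.dtype:
        raise TypeError(f"dtype mismatch: {lhs.dtype} vs {rhs.dtype}")
    return TensorType((lhs.shape[0], rhs.shape[1]), lhs.dtype)


result = matmul_type(TensorType((4, 8), "float32"), TensorType((8, 2), "float32"))
```

An ill-shaped program is rejected at this stage, before any kernel is built or run.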

This PR implements type inference and checking for Relay. The code can be found in src/tvm/relay/pass/type_infer.cc, with relevant helper utilities in src/tvm/relay/pass.

Control Flow

Relay adds a notion of control flow to the IR, in the form of a simple if (cond) { true_branch } else { false_branch }. Relay requires that the condition evaluates to a single boolean value, which controls which branch is taken. if is an expression in Relay, meaning the result of the entire expression is the result of the branch taken.
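The "if is an expression" semantics can be illustrated with a toy AST and evaluator. This is only a sketch in plain Python; the node names (`Const`, `Var`, `If`) and the `evaluate` helper are assumptions made for illustration and do not mirror the definitions in include/tvm/relay/expr.h.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Const:
    value: object


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class If:
    cond: "Expr"
    true_branch: "Expr"
    false_branch: "Expr"


Expr = Union[Const, Var, If]


def evaluate(expr, env):
    """Evaluate a toy expression; the value of an If is the value of the branch taken."""
    if isinstance(expr, Const):
        return expr.value
    if isinstance(expr, Var):
        return env[expr.name]
    if isinstance(expr, If):
        cond = evaluate(expr.cond, env)
        if not isinstance(cond, bool):
            raise TypeError("the condition must be a single boolean value")
        branch = expr.true_branch if cond else expr.false_branch
        return evaluate(branch, env)
    raise TypeError(f"unknown expression: {expr!r}")


program = If(Var("flag"), Const(1.0), Const(-1.0))
taken_true = evaluate(program, {"flag": True})
taken_false = evaluate(program, {"flag": False})
```

The evaluator is only there to show the semantics; the RFC itself treats the IR as a compile-time representation.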

We introduce this to add a formal way to distinguish between data flow and control flow without having to conflate the two in the representation. Because we separate the control signal, we can easily batch a program without affecting control flow.

The definition of control flow can be found in include/tvm/relay/expr.h.

Abstraction

Relay supports the definition of functions, which can be used to represent "sub-graphs" (i.e., chunks of reusable computation).

Relay functions are like traditional functions: they consist of a set of parameters (i.e., placeholders) and a body, which is a chunk of computation involving the parameters (i.e., a sub-graph). We can build a full network/model by composing functions together.

Compilation

The Relay IR is designed as a compile time representation of models. The new features are exposed only in Relay's abstract syntax tree, and used for compile time program manipulation. We do not intend to use Relay's IR as a data structure for serious interpretation or execution.

Runtime

These new features increase the expressivity of the current computation model, and one may ask how to execute programs using these features with the existing runtime. Our goal is to introduce Relay as the compiler representation in this PR, and reuse the existing runtime maintaining compatibility on both the frontend and backend. We anticipate a new version of the runtime having native support for Relay's new constructs in the future.

TVM Co-design

We made an effort to model Relay's implementation after TVM and reuse much of the existing infrastructure in order to provide better compatibility between TOPI operators and Relay programs. One big design decision is reusing the TVM node system to expose the Relay language to Python in the style of TVM. Users who are familiar with TVM's expression language should feel comfortable working with the Relay AST's definition in C++ and Python. We also share representations for many data structures; for example, tensor containers (i.e., tvm::runtime::NDArray) and generic attributes are shared between Relay and TVM.

Transitioning from NNVM

We plan on adding a guide for transitioning programs from NNVM to Relay. This is one of the remaining work items before releasing the Relay Alpha. The goal is for users to be able to use the Relay operators and builder API to construct Relay programs, and we will follow up with a compatibility layer to make transitioning from NNVM smooth.

For an implementation see #1672 which implements this bit.

masahi commented 6 years ago

Hi @jroesch, looks cool. I am hoping that this is TVM's answer to TensorFlow's somewhat awkward tf.while_loop and other control flow constructs.

Regarding this sentence:

"Finally we believe the high level must be designed in tandem with the low level IR, allowing for the two layers to communicate during compilation to achieve optimal performance."

Does HalideIR correspond to the "low level IR"? This sounds very interesting.

junrushao commented 6 years ago

We know that the goal of designing a new IR is to enable potential optimizations, so could you be more specific about what kinds of optimization Relay is planning to support?

tqchen commented 6 years ago

@junrushao1994 can you also elaborate on your use-cases and how easy/hard it is to bring this to the current proposal?

tqchen commented 6 years ago

This is one of the major design changes, and we would love to have participation from the community to review and improve the proposal @dmlc/tvm-team

tqchen commented 6 years ago

Thanks @jroesch for the proposal. I am going to elaborate on some of my takes on it. Note that no design is perfect, and that is why we need everyone's help to work together to evolve the IR.

Specific Technical Points

Some Possible Points to Discuss

These are things that came to mind; feel free to add more.

junrushao commented 6 years ago

I would say the design itself is perfect: it addresses almost all the problems that it targets. Yes, this is the first systematic approach to address the lack of Turing-completeness in deep learning frameworks, rather than quick hacks like TensorFlow's while_loop. Also, the implementation is elegant and I love it.

So please allow me to take the liberty of talking about some concerns that might be outside the scope of the current design. Briefly, I will comment on the following aspects.

junrushao commented 6 years ago

Part 1: Frontend

Sometimes I prefer to think that worse is better, and guess that it might be kind of restrictive to ask deep learning practitioners to do "the right thing". After all, we could not assume each user has a PhD degree in PL. Here are several things we might need to consider in the future.

Container types

I would love to discuss the necessity of having containers like List<T> and Dict<K, V>. This concern arises from some deep learning models in NLP, for example self-attention, which is used in sequence generation tasks.

def self_attentive_generator(inputs, states):
  # pseudocode: concat, self_attention, DecoderStep and some_condition are placeholders
  outputs = []
  while True:
    prev_outputs = concat(outputs)
    context_vector = self_attention(prev_outputs, inputs, states)
    step_output, states = DecoderStep(states, prev_outputs)
    outputs.append(step_output)
    if some_condition:
      break
  outputs = concat(outputs)
  return outputs, states

We already have Tuple<T_1, ...> in Relay, which is great. We could definitely ask users to convert this to an FP style so that nothing has side effects. Performance would not be affected much, but would we be able to reduce the memory footprint in this case?
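For reference, here is a rough sketch of what the side-effect-free form of such a loop could look like, written in plain Python with recursion and an immutable tuple in place of the mutated list. The names `generate` and `toy_step` are hypothetical stand-ins (the toy step is not a real decoder), and note that `outputs + (step_output,)` copies the tuple on every step, which is exactly the kind of footprint question raised above.

```python
def generate(step, should_stop, state, outputs=()):
    """Pure version of the generation loop.

    Instead of appending to a list, each call builds a new tuple and passes
    it, together with the new state, to the next recursive call.
    """
    step_output, new_state = step(state, outputs)
    new_outputs = outputs + (step_output,)  # new tuple; the old one is untouched
    if should_stop(new_outputs):
        return new_outputs, new_state
    return generate(step, should_stop, new_state, new_outputs)


def toy_step(state, prev_outputs):
    """Toy stand-in for one decoder step: returns (step_output, new_state)."""
    return state + len(prev_outputs), state + 1


outputs, final_state = generate(toy_step, lambda outs: len(outs) >= 4, state=10)
```

Whether a compiler can turn the repeated tuple construction back into an in-place buffer is the open question here; the sketch only shows the shape of the program it would have to analyze.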

Supporting Dict<K, V> is a somewhat weird requirement, but it seems to me that states in RNNs are often represented using a dict (need confirmation from @szha). I guess we could probably replace this with something like namedtuple.

Other weird use cases may include Fast WaveNet, PixelCNN++, and beam search in decoding, many of which require users to write a normal program to manipulate containers. That is easier for common users to write in an imperative style than in FP. I am a big fan of FP languages like Haskell, but I am a little bit worried about 1) whether the market likes this style of programming, and 2) whether this IR could support such optimizations.

Side effects, e.g. I/O and random monad

For example, in dropout, we should definitely introduce something with randomness.

This is just a reminder that these are things we should take into consideration.

Whether to incorporate context into the type system

I am also wondering if ctx could be put into the type system. This is kind of co-design with TVM.

Primitives related to distributed system

This is kind of off topic, but I am personally interested in seeing brilliant ideas about how we could handle these situations.

A. SyncBN: This introduces allreduce in the forward pass. Of course, such an ugly thing as forward/backward will never exist in Relay, which is great, but it seems that we should introduce synchronization primitives in the design.

B. Timeout: It is common practice on an edge device to store a smaller DL model offline, which is used to produce a coarse result; in the meantime, an online model on remote servers computes some part of the result and sends it back. In case the online model fails to send the fine-grained result in time, the edge device uses the offline result. This is another kind of side effect. Should we handle this in a deep learning framework, or leave it to others?

tqchen commented 6 years ago

+1 on making things accessible (worse is better); this is exactly what we should push for.

Making most parts functional makes differentiation easy and allows us to build things around it. It is important to be able to support mutation in some of the outer loops, where differentiation is not needed. Having such a clear distinction is important.

The List/Dict programming style is more like a multi-stage program, where the graph is staged in the data structure before we call backward. While it is definitely possible to do so via imperative autodiff, it is an interesting question whether we can desugar this into some form of functions. Note that this is different from mutation (because differentiation is necessary).

Distributed sync primitives and timeouts can likely be implemented via special core operators (like add and sub), and there should be no problem handling this: making things more side-effect free will make the distributed parts easier.

jroesch commented 6 years ago

Frontend

@tqchen and @junrushao1994 in some of our other experiments we have worked on rewriting imperative Python code into the IR directly. This should allow users to write imperative-looking code which is translated into a form that is easy to optimize and perform automatic differentiation on. In general we could expose an API which looks destructive to the user but is actually pure under the covers.

Side effects, e.g. I/O and random monad

I would like to track effects as an operator attribute so we can use this information during optimization. Roughly we will respect effects when transforming programs, for example no reordering over I/O or stateful computations.
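As a rough illustration of "effects as an operator attribute", the sketch below tags each operator with an effect kind and uses it to decide whether two adjacent operators may be swapped. Everything here is hypothetical (the `Op` class, the `effect` values, the `can_swap` rule), and the rule assumes the two operators have no data dependency on each other.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Op:
    name: str
    effect: str = "pure"  # hypothetical attribute: "pure", "io", or "state"


def can_swap(first: Op, second: Op) -> bool:
    """Two adjacent, data-independent ops may be swapped unless both have effects.

    Swapping two effectful ops could change the observable order of I/O or
    state updates, so the pass must leave them in place.
    """
    return first.effect == "pure" or second.effect == "pure"


add = Op("add")
multiply = Op("multiply")
log_tensor = Op("log_tensor", effect="io")
dropout_rng = Op("dropout_rng", effect="state")

pure_pair_ok = can_swap(add, multiply)
mixed_pair_ok = can_swap(log_tensor, add)
effect_pair_ok = can_swap(log_tensor, dropout_rng)
```

A stricter policy (never moving anything across an effectful op) would also be consistent with the comment above; the attribute is what makes either policy expressible.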

zheng-da commented 6 years ago

@tqchen, as you mentioned the transformer, I'm curious how Relay is going to handle this kind of workload, where shapes change at runtime and can't be statically inferred. Is there a plan to support it in TVM and Relay?

junrushao commented 6 years ago

@zheng-da in the standard way.

tqchen commented 6 years ago

TVM supports variable-length input via symbolic variables, so in theory we could build an op that takes an input of shape (n, 128), where n is a symbolic variable. Relay also adopts this in its type system, which allows it to handle cases with a fixed number of dimensions but a symbolic shape. How to do generic code generation is another question that we can follow up on, but the IR itself can handle shape inference of this kind.

junrushao commented 6 years ago

@jroesch I see. That looks cool.

zheng-da commented 6 years ago

Do I understand it correctly?

Suppose I have an operator that increases the size of all dimensions of the output by 1 compared with the input, i.e., if the input shape is (x, y), the output shape will be (x+1, y+1). Now if one of the input dimensions is unknown, e.g., (n, 128), as you mentioned, the output shape will be inferred as (n+1, 129). In other words, the system can infer the size of all known dimensions and capture the relation of the unknown dimensions.
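The (n, 128) -> (n+1, 129) inference described here can be sketched with a tiny symbolic dimension type. This is a plain-Python illustration, not how TVM or Relay represent symbolic shapes; `Sym` and `grow_shape` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Sym:
    """A symbolic dimension of the form `name + offset`."""
    name: str
    offset: int = 0

    def __add__(self, k: int) -> "Sym":
        return Sym(self.name, self.offset + k)

    def __str__(self) -> str:
        return self.name if self.offset == 0 else f"{self.name}{self.offset:+d}"


def grow_shape(shape):
    """Shape rule for the operator above: every dimension grows by 1."""
    return tuple(dim + 1 for dim in shape)


out_shape = grow_shape((Sym("n"), 128))
printed = tuple(str(dim) for dim in out_shape)
```

Known dimensions are computed outright (129), while the unknown one keeps its relation to the input (n+1).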

Now, I have this operator op described as above (it increases the size of all dimensions by 1) and use it in a while loop as follows:

import mxnet as mx

i = 0
data = mx.nd.random.uniform(shape=(10, 10))
out = mx.nd.zeros((10, 10))
# `op` is the operator described above: it grows every dimension by 1
while sum(data[i]) > 1:
    out = op(out)
    i = i + 1

Can the shape of out still be inferred, maybe expressed in a symbolic way?

yzhliu commented 6 years ago

+1 for incorporating context or target into the type system, so that it can directly support heterogeneous runtime.

Shall we provide an approach to convert RelayIR to graph representation (if it can)? I'm thinking about passing subset to accelerators like TensorRT.

junrushao commented 6 years ago

@zheng-da If we convert your code to a functional one, yes, this is called shape generic?

zheng-da commented 6 years ago

@junrushao1994 I don't know what it means. Could you show me how the shape of out is expressed? Also, how is the shape used after it's inferred?

junrushao commented 6 years ago

@zheng-da Sorry for confusing you. There are two steps. The first step is to convert the code to a purely functional one, which means you use pattern matching + recursion to replace the loop. The second step is to look at the function; then you will see that the function that represents the loop body is generic.
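A sketch of the first step, using plain Python lists instead of mxnet arrays and tracking only the shape of out. The names `grow` and `loop` are hypothetical, and the bounds check on i is an addition that the original loop does not have.

```python
def grow(shape):
    """Shape-level stand-in for `op`: every dimension grows by 1."""
    return tuple(dim + 1 for dim in shape)


def loop(i, rows, out_shape):
    """Recursive form of `while sum(data[i]) > 1: out = op(out); i = i + 1`.

    The loop-carried values (i and the shape of out) become parameters,
    and each iteration becomes one recursive call.
    """
    if i < len(rows) and sum(rows[i]) > 1:
        return loop(i + 1, rows, grow(out_shape))
    return i, out_shape


rows = [[1.0, 0.5], [0.9, 0.3], [0.2, 0.1]]
iterations, final_shape = loop(0, rows, (10, 10))
```

Written this way, the body is visibly generic in the shape it receives: it maps (a, b) to (a+1, b+1) for any a and b. How many times it is applied still depends on the data, so the final shape is only known at runtime.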

junrushao commented 6 years ago

@zheng-da It is called MGU, the most general unifier (@jroesch correct me if it is not). The shape of out is called IncompleteType somewhere in the code. (I briefly glanced through the code, but didn't remember the exact name.) A simple union-find set perfectly solves this problem.
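A minimal sketch of the union-find idea, assuming a simplified setting where a type is either an unknown (a string starting with "?") or an atomic concrete type such as "float32". This is an illustration only; it does not handle compound types and is not Relay's type solver.

```python
def is_var(type_name: str) -> bool:
    """Convention for this sketch: unknown types are written '?a', '?b', ..."""
    return type_name.startswith("?")


class Unifier:
    def __init__(self):
        self.parent = {}

    def find(self, t):
        """Return the representative of t's equivalence class."""
        self.parent.setdefault(t, t)
        while self.parent[t] != t:
            self.parent[t] = self.parent[self.parent[t]]  # path halving
            t = self.parent[t]
        return t

    def unify(self, a, b):
        """Record that a and b must be the same type."""
        root_a, root_b = self.find(a), self.find(b)
        if root_a == root_b:
            return
        if is_var(root_a):
            self.parent[root_a] = root_b  # a concrete root, if any, stays the root
        elif is_var(root_b):
            self.parent[root_b] = root_a
        else:
            raise TypeError(f"cannot unify {root_a} with {root_b}")


solver = Unifier()
solver.unify("?a", "?b")
solver.unify("?b", "float32")
resolved = solver.find("?a")
```

After the two constraints are recorded, "?a" resolves to "float32" even though it was never unified with it directly.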

tqchen commented 6 years ago

In my understanding, type inference is designed to understand things at compile time, so in the case of random data and dimension expansion, it is impossible to decide the final dimension, and inference will likely return something like an incomplete type, or use a shape-of node that defers things to runtime. The fixed-dimension symbolic shape case is the most common one where we can still benefit from such static info.

zheng-da commented 6 years ago

Agreed. The fixed-dimension symbolic shape is very useful. I think mxnet can greatly benefit from it. Could you point me to the code in TVM that does it?

I think my original question was whether there is a plan of supporting the case that the shape really can't be inferred in TVM and relay? For example, in mxnet, I'm thinking of doing something like this: https://github.com/apache/incubator-mxnet/pull/12400. Of course, my solution is hacky.

tqchen commented 6 years ago

Supporting runtime dynamic shapes is really a problem for the runtime, since in terms of the high-level IR, we can always introduce a type (incomplete or dynamic) that indicates it can be anything. We will then need to add a JIT in the runtime to build and run the stages when we have more information.

junrushao commented 6 years ago

@tqchen Don't think it is a big deal for the runtime if we support only PackedFunc wrapping libs like cuDNN. Many passes could be represented using a sparsely conditional constant prop, or other very mature compiler techniques.

However, if we want auto-tuning in such scenario, it could be cutting edge research.

tqchen commented 6 years ago

To follow up on @yzhliu 's comment on whether ctx should be part of type system.

We don't have to enforce everything as part of the type in order to do such optimization. Context assignments (or machine assignments) in a distributed setting can also be represented as column metadata (like in NNVM). We have quite a lot of cases like this: alternative data layouts, distributed machine assignments, etc.

The possible pros/cons of the type system vs the additional metadata are

So in the case of shape vs context

Because of this reason, we can argue that context is preferred not as part of the type, but more like a metadata of say function or a call.

tqchen commented 6 years ago

@masahi I think here the low-level IR refers to the tensor expression part of TVM, including autoTVM, topi, and compute primitives.

masahi commented 6 years ago

@tqchen thanks, makes sense. Those are in turn based on HalideIR, so in some sense HalideIR is the foundation for everything.

zheng-da commented 6 years ago

@tqchen Another question is how to integrate with some backward libraries in Relay. Maybe this isn't really a Relay question, but it's something we need to consider after TVM moves to Relay. I suppose Relay is good at pattern matching. Is it easy to take out the matched pattern and put it somewhere (maybe in an operator) to invoke TensorRT? What do you think about supporting stateful operators, both from the perspective of Relay and TVM? Having a stateful operator may make it easier for us to integrate with TensorRT.

junrushao commented 6 years ago

@zheng-da Can you be more specific about “good at pattern matching”?

zheng-da commented 6 years ago

@junrushao1994 actually, I don't know. I guess Relay should be able to do pattern matching. One example of pattern matching is to find a set of operators that can be fused in TVM.

masahi commented 6 years ago

@zheng-da NNVM can already do operator fusion. TVM supports cuDNN offload out of the box. This tutorial may be helpful.

yzhliu commented 6 years ago

I think Da is asking a similar question to mine. Since TensorRT eats a graph, if we extract a subset of Relay IR and pass it to TensorRT for acceleration, then an intermediate graph representation is required.

@zheng-da I don't think it is a problem to find operators that meet a specific requirement - we can still traverse the Relay AST.
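As a sketch of what "traverse the AST and pick out what the external library supports" could look like, here is a toy call tree in plain Python. The `Call` node and `collect_supported` helper are hypothetical and unrelated to Relay's actual visitor classes; the set of supported operator names is also made up for the example.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Call:
    """Toy call node: an operator name applied to argument expressions."""
    op: str
    args: List[object] = field(default_factory=list)


def collect_supported(expr, supported):
    """Post-order walk returning the names of calls whose op is in `supported`."""
    found = []

    def visit(node):
        if isinstance(node, Call):
            for arg in node.args:
                visit(arg)
            if node.op in supported:
                found.append(node.op)

    visit(expr)
    return found


network = Call("softmax", [Call("relu", [Call("conv2d", ["x", "w"])])])
offloadable = collect_supported(network, supported={"conv2d", "relu"})
```

A real pass would additionally have to group the matched calls into a connected region and emit it in whatever graph format the external library consumes, which is the converter question raised here.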

junrushao commented 6 years ago

@zheng-da I got what you mean. I was confused because “pattern matching” typically refers to conditional statements in FP.

I think @zheng-da is right. This is not a Relay question. I would say it is pretty trivial to do by traversing the IR and extracting whatever TensorRT supports, as @yzhliu suggests. However, there are two things we should keep an eye on: 1) I am not a big fan of stateful operators (and I believe nobody is). We should try to separate their state out as arguments, and return the new state after computation. 2) We should be careful with infinite recursion.
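Point 1) can be illustrated with a small state-passing operator in plain Python. The running-mean example and the name `running_mean` are chosen purely for illustration; the point is only the calling convention, where state goes in as an argument and the new state comes back as a result.

```python
def running_mean(state, x):
    """Pure running mean: takes the old state, returns (new_state, output).

    The state is a (count, mean) pair. Nothing is mutated, so calling this
    twice with the same arguments always gives the same result.
    """
    count, mean = state
    new_count = count + 1
    new_mean = mean + (x - mean) / new_count
    return (new_count, new_mean), new_mean


state = (0, 0.0)
state, first = running_mean(state, 2.0)
state, second = running_mean(state, 4.0)
```

The caller threads `state` through explicitly, so the compiler sees an ordinary dataflow dependency instead of a hidden side effect.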

junrushao commented 6 years ago

@yzhliu I think what you are trying to say is “we need an IR converter”. Yes, that makes a lot of sense. Ideally, there should be a bridge IR to which every DL framework converts themselves, and from which the low-level lib provider writes a pass to convert to their own IR. But such thing does not seem to exist yet (or ONNX?).

We could definitely make it more systematic, but it does not seem that necessary for now, because only a very small number of low-level libs consume a graph.

As for integrating operator-level libraries like cuDNN, this is never a problem in TVM...

junrushao commented 6 years ago

@tqchen @jroesch Let's talk about introducing data structures like lists, dict. (I re-structured the stuff I wrote yesterday)

Solution A: use confluently persistent data structures. (immutable) There is theoretically no difficulty in implementing functional data structures like a functional List or functional Map. But we may have two concerns:

1) Overhead for being functional. For example, I often write the functional Map using a merge/split Treap, which maintains an expected O(log n) time & space complexity, but my C++ implementation is likely to be 4x - 10x slower than non-functional ones. If we have a good runtime, the latency should be somewhat hidden, but I am not sure.

2) If we introduce immutable lists / dicts, the semantics of list and map would change. But of course, we could force users to use things like tvm.immutable_list, and declare "we don't want you guys to use Python list/map".

Solution B: let's leave it dirty. (mutable) There are also two concerns:

1) In my opinion, we should try to avoid any tracing-based autodiff, otherwise we lose both the elegance of our design and part of the meaning of being functional.

2) It leaks side effects everywhere. To avoid this, it is possible to slice the graph. But that is a relatively bad idea, especially inside a loop (see the self-attention example), because it discourages a lot of global optimizations.
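For readers unfamiliar with persistent data structures, here is a minimal sketch of an immutable singly linked list in plain Python (`Cons`, `push`, and `to_list` are hypothetical names). Pushing onto the list never modifies existing cells, so several versions can share the same tail.

```python
from typing import NamedTuple, Optional


class Cons(NamedTuple):
    """One immutable list cell; None stands for the empty list."""
    head: object
    tail: Optional["Cons"]


def push(lst, value):
    """Return a new list with `value` in front; `lst` itself is unchanged."""
    return Cons(value, lst)


def to_list(lst):
    """Copy the cells into a Python list, front to back."""
    items = []
    while lst is not None:
        items.append(lst.head)
        lst = lst.tail
    return items


base = push(push(None, 1), 2)
version_a = push(base, 3)
version_b = push(base, 4)
shared = version_a.tail is version_b.tail
```

Both versions reuse the cells of `base`, so each push costs O(1) time and space. Structures with cheap updates in the middle, such as the treap-based map mentioned in this comment, need more machinery and pay the overhead discussed in concern 1).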

junrushao commented 6 years ago

Part II: Optimization

I feel that this part is tightly entangled with runtime though...

Terms

Optimization for deep learning kernels (vectorized) seems totally different from that for scalar operations (e.g. +, *, /, ->, etc). I am not an expert, but please allow me to define two terms to distinguish these two kinds of optimization:

Note that kernel optimization is a superset of scalar optimization by definition, but here let's assume kernel optimization refers to the part that is not scalar.

Memory for high-level, speed for low-level

As far as I could tell,

However, by prioritizing instructions, optimizing for speed is also possible when control flow exists or in the distributed setting.

Memory I: Liveness analysis, memory preallocation and sharing

One good thing about a purely functional program is that liveness analysis is pretty trivial. I would prefer just to add a tag DEAD(some-memory-chunk) to inform the runtime.
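A sketch of why this is simple for a pure, straight-line program: every value is defined once, so a value is dead right after its last use. The code below is plain Python with a hypothetical program encoding (a list of (output, op, inputs) triples) and a hypothetical helper name; it computes, for each instruction, which values could be tagged DEAD after it.

```python
def dead_after(program, results):
    """For each instruction, list the values whose last use is that instruction.

    `program` is a list of (output_name, op_name, input_names) triples in
    execution order; `results` are the names that must outlive the program.
    Values that are never used at all are ignored in this sketch.
    """
    last_use = {}
    for index, (_output, _op, inputs) in enumerate(program):
        for name in inputs:
            last_use[name] = index  # later uses overwrite earlier ones

    dead = [[] for _ in program]
    for name, index in last_use.items():
        if name not in results:
            dead[index].append(name)
    return dead


program = [
    ("t0", "conv", ["x", "w"]),
    ("t1", "relu", ["t0"]),
    ("t2", "add", ["t1", "x"]),
]
tags = dead_after(program, results={"t2"})
```

In the example, x stays live until the final add because it is reused there, while w can be released right after the conv. With control flow or recursion the analysis has to be done per branch or per call.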

There are several situations we need to consider:

1) An ordinary function that takes no other functions as arguments, with no infinite recursion and no pattern matching: memory preallocation is directly doable once all shapes are known.

2) An ordinary function that takes no other functions as arguments, with no infinite recursion, but with pattern matching: this creates an exponential number of combinations of memory footprint. We could build a bin estimator that statistically analyzes the memory footprint.

3) A higher-order function, or a function that could trigger self/mutual recursion: this is hard, so let's not do such optimization for this kind of function.

Also, memory sharing is trivial across functions.

Memory II: Host memory as $L_\infty$ cache

This is inspired by @tqchen's paper [1], but in a less elegant way. I also prefer to leave this to the runtime.

Memory III: Detecting contiguous memory allocation

The NLP community has raised a lot of concern about detecting contiguous memory allocation. Again, I would love to take self-attention as the example, which is widely used in generating sequences of better quality in machine translation, question answering, etc.

def self_attentive_generator(inputs, states):
  # pseudocode: concat, self_attention, DecoderStep and some_condition are placeholders
  outputs = []
  while True:
    prev_outputs = concat(outputs)
    context_vector = self_attention(prev_outputs, inputs, states)
    step_output, states = DecoderStep(states, prev_outputs)
    outputs.append(step_output)
    if some_condition:
      break
  outputs = concat(outputs)
  return outputs, states

At each step, we have a concat, which produces a new chunk of memory. Unfortunately, this memory could not be released, because we probably are going to back-propagate through it. This causes quadratic memory consumption.
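The quadratic growth can be checked with a little arithmetic. Assuming every step output has the same size and no concat buffer can be freed, step t (counting from 0) allocates a buffer holding t outputs, so T steps keep 0 + 1 + ... + (T-1) = T(T-1)/2 outputs alive. The helper below is a hypothetical plain-Python check of that count.

```python
def concat_elements_kept(num_steps, step_size=1):
    """Total elements held by all `prev_outputs` buffers if none can be freed.

    Step t concatenates the t outputs produced so far, each of `step_size`
    elements, into a fresh buffer.
    """
    return sum(t * step_size for t in range(num_steps))


kept_naive = concat_elements_kept(100)
kept_single_buffer = 100  # one preallocated buffer that every step writes into
```

For 100 steps that is 4950 step outputs kept alive, versus 100 if all steps wrote into one contiguous buffer and each prev_outputs were a view of its prefix.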

I think it is worth mentioning that this is not a corner case. It seems to me to be a trend: many practitioners now know it works, and they want to add this seemingly free lunch to their sentence generation models.

It is possible to optimize this, as long as we can detect this memory allocation pattern. Fortunately, yes, it is doable in this IR.

Memory IV: (hard case) optimize shape-related operations following Numpy

This optimization is pretty useful, but not easy to implement under the current IR.

For example, np.swapaxes is applied to a rank-n tensor, and subsequently a reshape is called. NumPy will check whether the reshape can be done as a zero-copy view. However, it seems that we did not expose such a thing in the current level of IR, because it requires ndarray.strides.
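To show why strides are needed for this decision, here is a pure-Python sketch (no NumPy required) that tracks shape and strides through a swapaxes and then applies a conservative test: if the result is still laid out in C order, any reshape can be a view. The helper names are hypothetical, strides are counted in elements rather than bytes, and NumPy's real check is more permissive than this one.

```python
def c_strides(shape):
    """Element strides of a freshly allocated C-order array of this shape."""
    strides, step = [], 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def swapaxes(shape, strides, a, b):
    """swapaxes only permutes metadata; no data is moved."""
    shape, strides = list(shape), list(strides)
    shape[a], shape[b] = shape[b], shape[a]
    strides[a], strides[b] = strides[b], strides[a]
    return tuple(shape), tuple(strides)


def is_c_contiguous(shape, strides):
    """True if the view is laid out exactly like a C-order array (size-1 dims ignored)."""
    expected = 1
    for dim, stride in zip(reversed(shape), reversed(strides)):
        if dim != 1 and stride != expected:
            return False
        expected *= dim
    return True


shape = (2, 3, 4)
swapped_shape, swapped_strides = swapaxes(shape, c_strides(shape), 0, 1)
needs_copy = not is_c_contiguous(swapped_shape, swapped_strides)

unit_shape = (1, 3, 4)
unit_swapped = swapaxes(unit_shape, c_strides(unit_shape), 0, 1)
unit_needs_copy = not is_c_contiguous(*unit_swapped)
```

Swapping the first two axes of a (2, 3, 4) array breaks C order, so a general reshape needs a copy; swapping a size-1 axis does not. Without strides in the IR, the compiler cannot tell these two cases apart.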

@tqchen As I suggested in my proposal, we should expose everything in some level of IR.

Speed I: LLVM/GCC does much better than any single person

Scalars are currently represented using rank-0 tensors, which are unified under the current IR. However, it remains a question to me whether we want to compute scalars like a + b by launching a kernel to some GPU stream or MKLDNN stream and then waiting for a callback.

I would propose to do program slicing in the lower-level IR to distinguish scalar and kernel operations, then hand scalar operations to LLVM RTC for optimization.

Speed II: Decoupling scalar, memory and kernel operation

This is similar to the previous section: in a lower-level runtime, we would decouple these three things into 3 program counters and do speculative execution over some PCs, in order to let kernel operations launch continuously. In this case, the launching gap will be fully eliminated.

Speed III: Register allocation

I don't think it is doable in the current version of TVM, but it could bring significant speed improvement when viable. For example, fixing the weights of an RNN cell into some registers and making sure they aren't spilled.

This optimization is also related to pipelining instructions layer by layer in inference. Let's leave it to future work...

[1] Chen, T., Xu, B., Zhang, C., & Guestrin, C. (2016). Training deep nets with sublinear memory cost. arXiv preprint arXiv:1604.06174.

MarisaKirisame commented 6 years ago

@junrushao1994 I am the main designer/implementer of relay's automatic differentiation system.

In general, doing non-tracing-based reverse mode automatic differentiation on arbitrary lambdas is extremely hard. There is only one paper (Reverse Mode AD in a functional framework) that does it; it works by traversing the reflected program, is complicated, and is untyped.

We might be able to type it, but it would bring a huge source of complexity, and optimizing on reflection will not be easier than optimizing a trace. So, we actually use a tracing-based approach, which is very similar to (Demystifying Differentiable Programming), except that we do not use continuations, only mutation.

IMO, as there are already effects everywhere (random, IO in reinforcement learning, mutation in NLP, distributed training, etc.), the problem is less 'whether there should be effects or not' and more 'how should we capture effects? A monad, an Eff-like effect system, or not at all (as in OCaml/SML), only in static analysis?' I do agree that it is a problem in its own right, but I think some notion of effect is inevitable.

Back to your particular problem, I think there is a 'best of both worlds' solution:

Introduce a type Ref a. It means a pointer pointing to an a, whose content can change. The pointer cannot change what it points to, though (albeit that can be achieved with Ref(Ref a)).

There are 3 functions on Ref: MkRef : a -> Ref a, GetRef : Ref a -> a, SetRef : Ref a -> a -> (), and possibly updateRef : Ref a -> (a -> a) -> (), which is atomic.

Introduce effectless list/dict.

Translate Python lists into Ref(List a).

In the compiler, add a special hook for Ref(List a), and use a custom mutable data structure. We can also change List a to a mutable one (in the compiler) if it is not being shared.
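The Ref proposal can be sketched in plain Python as a single mutable cell holding an immutable value. The snake_case names below mirror MkRef, GetRef, SetRef and updateRef from the comment but are otherwise hypothetical, a tuple stands in for the effectless list, and update_ref here is not actually atomic.

```python
class Ref:
    """A mutable cell; the value stored inside it is itself immutable."""

    def __init__(self, value):
        self._value = value


def mk_ref(value):           # MkRef : a -> Ref a
    return Ref(value)


def get_ref(ref):            # GetRef : Ref a -> a
    return ref._value


def set_ref(ref, value):     # SetRef : Ref a -> a -> ()
    ref._value = value


def update_ref(ref, fn):     # updateRef : Ref a -> (a -> a) -> ()
    ref._value = fn(ref._value)  # not atomic in this sketch


# A Python list translated to Ref(List a), with a tuple as the effectless list.
outputs = mk_ref(())
snapshot = get_ref(outputs)
update_ref(outputs, lambda items: items + (1,))
update_ref(outputs, lambda items: items + (2,))
current = get_ref(outputs)
```

All mutation is confined to the cell, so an earlier snapshot of the list is never changed behind the reader's back. That is what leaves room for the compiler to swap in a mutable structure when it can prove the list is not shared.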

I think this addresses (1, 2) in solution A, and the previous paragraph addresses (1) in solution B.
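
The Ref proposal can be sketched in Python as follows. This only illustrates the interface and is not Relay code: the snake_case names mirror MkRef, GetRef, SetRef and updateRef, a tuple stands in for the effectless list, and update_ref here is not atomic (a real implementation would need synchronization).

```python
from typing import Callable, Generic, Tuple, TypeVar

A = TypeVar("A")


class Ref(Generic[A]):
    """A mutable cell: its identity is fixed, only its content changes."""

    def __init__(self, value: A) -> None:
        self._value = value


def mk_ref(value: A) -> Ref[A]:
    return Ref(value)


def get_ref(ref: Ref[A]) -> A:
    return ref._value


def set_ref(ref: Ref[A], value: A) -> None:
    ref._value = value


def update_ref(ref: Ref[A], fn: Callable[[A], A]) -> None:
    # Read-modify-write in one call; NOT atomic in this sketch.
    ref._value = fn(ref._value)


def append(xs: Tuple[int, ...], x: int) -> Tuple[int, ...]:
    """Effect-free append: returns a new tuple, leaves xs untouched."""
    return xs + (x,)


# A Python list modelled as Ref(List a): the cell mutates, the lists do not.
log: Ref[Tuple[int, ...]] = mk_ref(())
snapshot = get_ref(log)  # an immutable view taken before the updates
update_ref(log, lambda xs: append(xs, 1))
update_ref(log, lambda xs: append(xs, 2))
```

All mutation is confined to the Ref cell, so any code that only sees the immutable list values stays pure, which is what lets the compiler decide where a mutable representation is safe.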

Let's talk about (2) in solution B. I do agree that references hinder optimization. However, so does reflection - which is the only other way to do reverse-mode differentiation on higher-order functions. I also postulate that with constant folding, the references can be optimized away when the control flow is known. They will only exist at the boundary of unknown function calls. If some variables are only used locally, never leak outside, and their usage does not vary with the control flow, they should not generate a Ref.

Of course, this is only a postulate at this point, but we also have a first-order reverse-mode automatic differentiation algorithm implemented, with no Wengert list at runtime. The downside is that it does not work with control flow. We can always add special cases to make sure no Ref is generated there, to achieve better speed.

Also, IMHO we are looking too far into the future. AFAIK no one knows how references, data structures, tensors, and AD will play together when we try to compile efficient code for GPUs. I think we should hold off on such design decisions until a much later phase, when we have a clearer picture.

Did I address your question?

junrushao commented 6 years ago

@tqchen @jroesch I guess I have finished the optimization part. I believe most of the optimization techniques for deep learning workloads that I could come up with can somehow be covered in Relay IR, thanks to the purity of FP. Could you give some feedback about this?

I am really interested in how you guys could implement a low-level runtime environment, because I didn't see the design in the Relay paper. Could you guys kindly share some information with me?

tqchen commented 6 years ago

With respect to optimizations, we should try to push most of them into the compiler and keep the runtime lightweight.

Scalar slicing and fusion. This can likely be done in Relay already, by slicing out the rank-0 tensors and generating a fused function to compute them. Only one memory store is necessary, and they act similarly to shape inference.

Reshape opt is not as important. In-place memory optimization usually matters less as long as there is memory reuse, because the memory used before the reshape can be reused in the next stage. Compact memory is much better for speed optimizations.

Concat, directly add to slice

This has something to do with custom calling conventions. When building a Relay function, there are multiple ways to handle calls. The tensor space can be caller-allocated or callee-allocated, and we can specify whether there is a fused op for the return value. For example, we can support a customized calling convention such as "invoke this function and add the result to a preallocated array". Combining compile-time analysis with customized calling conventions can likely solve this problem.
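
A rough Python sketch of the calling conventions mentioned above, using plain lists as stand-ins for tensors. The function names are hypothetical and only illustrate the idea; they are not TVM or Relay APIs.

```python
from typing import List


def scale_callee_alloc(x: List[float], k: float) -> List[float]:
    """Callee-allocated: the function creates and returns its own output buffer."""
    return [k * v for v in x]


def scale_caller_alloc(x: List[float], k: float, out: List[float], offset: int = 0) -> None:
    """Caller-allocated: write into out[offset:], e.g. a slice of a concat buffer."""
    for i, v in enumerate(x):
        out[offset + i] = k * v


def scale_accumulate(x: List[float], k: float, out: List[float]) -> None:
    """Fused return op: add the result to a preallocated array."""
    for i, v in enumerate(x):
        out[i] += k * v


a, b = [1.0, 2.0], [3.0]

# concat(scale(a), scale(b)) without intermediate buffers or a copy:
# each call writes directly into its slice of the final buffer.
concat_buf = [0.0] * (len(a) + len(b))
scale_caller_alloc(a, 2.0, concat_buf, offset=0)
scale_caller_alloc(b, 2.0, concat_buf, offset=len(a))

# The naive version allocates two temporaries and then copies them.
naive = scale_callee_alloc(a, 2.0) + scale_callee_alloc(b, 2.0)

acc = [1.0, 1.0]
scale_accumulate(a, 2.0, acc)
```

Which convention is legal for a given call site is what the compile-time analysis would have to decide, for instance by proving that the destination slice is not read before it is fully written.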

junrushao commented 6 years ago

@MarisaKirisame Thank you so much for your very helpful comments! It does address most of my questions.

tqchen commented 6 years ago

One thing to keep in mind is that we need to co-design things instead of thinking only in terms of the high-level IR. For example, function slicing likely depends on what the low-level code generator and the hardware have to offer. So most low-level information needs to be registered at the high level, and we need to embed metadata to reflect certain information.

zhiics commented 6 years ago

@tqchen I also totally agree that we should give most of the optimization work to the compiler and keep the runtime light. But I also have a concern about the optimization passes, which is probably more related to NNVM. There are already many passes, and the number is expected to keep growing. As the number of passes grows, I think it would be beneficial to have a more systematic way to manage them.

For example, I am wondering whether it makes sense to introduce something like a "PassManager" (as in LLVM) to maintain the passes. A PassManager may provide the following things.

  1. It can expose some APIs for users to implement their own opt passes at different levels, e.g. optOnGraph(Graph&& g) and optOnTensor(Tensor&& t) for optimizations on graphs and tensors respectively.
  2. The PassManager might be able to help maintain or tune the sequence of different opt passes.
  3. It could preserve some analysis information for loops and/or tensors, although I am not exactly sure what information would be necessary. This information can enable more optimizations.
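
The three points above could look roughly like the following Python sketch. Everything here is hypothetical and made up for illustration (the PassManager class, its methods, and the toy passes); the graph is reduced to a linear chain of operator names to keep the example short.

```python
from typing import Callable, Dict, List

Graph = List[str]  # toy stand-in: a linear chain of operator names
Pass = Callable[[Graph], Graph]


class PassManager:
    """Hypothetical manager: registers passes, fixes their order, keeps analysis info."""

    def __init__(self) -> None:
        self._passes: Dict[str, Pass] = {}
        self._order: List[str] = []
        self.analysis: Dict[str, int] = {}  # op count recorded after each pass

    def register(self, name: str, fn: Pass) -> None:
        self._passes[name] = fn
        self._order.append(name)

    def set_order(self, order: List[str]) -> None:
        unknown = [n for n in order if n not in self._passes]
        if unknown:
            raise KeyError(f"unregistered passes: {unknown}")
        self._order = list(order)

    def run(self, graph: Graph) -> Graph:
        for name in self._order:
            graph = self._passes[name](graph)
            self.analysis[name] = len(graph)
        return graph


def remove_identity(graph: Graph) -> Graph:
    return [op for op in graph if op != "identity"]


def fuse_conv_relu(graph: Graph) -> Graph:
    out: Graph = []
    i = 0
    while i < len(graph):
        if graph[i] == "conv2d" and i + 1 < len(graph) and graph[i + 1] == "relu":
            out.append("conv2d_relu")
            i += 2
        else:
            out.append(graph[i])
            i += 1
    return out


pm = PassManager()
pm.register("remove_identity", remove_identity)
pm.register("fuse_conv_relu", fuse_conv_relu)

graph = ["conv2d", "identity", "relu", "dense"]
optimized = pm.run(graph)
```

Running the same two passes in the opposite order leaves conv2d and relu unfused, because the identity op still sits between them when the fusion pass runs. That ordering sensitivity is the reason for point 2.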

zheng-da commented 6 years ago

@junrushao1994 I don't know many external libraries for deep learning, but I can name a few that require graph-level integration: TensorRT and nGraph. As far as I know, both Intel and Nvidia are developing more graph-level libraries. I think it's pretty common, and I think we should think about this problem. As for stateful operators, I wonder what the other option is. If we separate the state and pass it as an input argument, the data structure (it can contain any arbitrary data required by the external libraries) might be pretty complex. It doesn't seem to me that Relay can handle this kind of data structure.

On Thu, Aug 30, 2018 at 3:39 AM Junru Shao notifications@github.com wrote:

@yzhliu I think what you are saying is "we need an IR converter". Yes, that makes a lot of sense. Ideally, there should be a bridge IR: every DL framework converts its IR to the bridge IR, and each low-level library provider writes a pass from the bridge IR to its own IR. But such a thing does not seem to exist.

We could definitely make it more systematic, but it does not seem that necessary for now, because only very few low-level libraries consume a graph.

junrushao commented 6 years ago

@zheng-da these are just trivial engineering choices, and could you name a specific thing that relay could not handle? @tqchen could you comment on this?

kevinthesun commented 6 years ago

This is a good chance to look at the data layout system. I think @yzhliu is currently working on refactoring layout in TVM: https://discuss.tvm.ai/t/datalayout-structure/80

To enable graph-level optimization, every operator will require layout information. Maybe we can consider adding it to the Relay type system.

zheng-da commented 6 years ago

@junrushao1994 When I look at the type system (in the Relay paper), it supports base type, shape, Tensor, function, type, reference, and tuple. Do you suggest representing the data structures of arbitrary external libraries with the Relay type system? For example, MKLDNN requires a data structure like mkldnn::memory::primitive_desc. It's a class that contains a std::shared_ptr. It's probably doable to store this data structure in Relay, but it might be more convenient to support something like OpaqueType for arbitrary operator states.

The other problem is that these external libraries may change the state after each invocation. However, we don't know whether they really change the state or how they change it. Therefore, the operator can't be purely functional. Does Relay need to deal with it?
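
One way to picture the OpaqueType idea together with the "pass the state as an input argument" option is the Python sketch below. It is purely illustrative: Opaque and external_op are made-up names, the dict payload stands in for whatever an external library keeps, and the sketch assumes the state can be replaced by a fresh value on each call. A library that mutates its state in place behind the scenes would not fit this model without extra guarantees.

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass(frozen=True)
class Opaque:
    """A value the compiler passes around but never inspects."""

    payload: Any


def external_op(state: Opaque, x: float) -> Tuple[Opaque, float]:
    """Stateful operator made explicit: takes the old state, returns the new one."""
    calls = state.payload["calls"] + 1
    new_state = Opaque({"calls": calls})
    return new_state, 2.0 * x


s0 = Opaque({"calls": 0})
s1, y1 = external_op(s0, 1.5)
s2, y2 = external_op(s1, 2.0)
```

From the IR's point of view external_op is then an ordinary function from (state, input) to (state, output), and the type system only needs to know that the state has some opaque type.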

junrushao commented 6 years ago

@zheng-da Relay has some notion of tracking effects, so why not put this arbitrary state inside something like a PackedFunc?

Update: as @MarisaKirisame mentioned, I am wrong. Please just ignore this reply.

junrushao commented 6 years ago

@kevinthesun Hey Yao, could you kindly share more thoughts about what information you think must be put into the type system? It will be very helpful!

MarisaKirisame commented 6 years ago

@junrushao1994 We haven't settled on any effect system yet. We would love to know what the actual use cases are!

junrushao commented 6 years ago

@MarisaKirisame This is the concern raised from colleagues working on external library integration, as mentioned by @zheng-da.

kevinthesun commented 6 years ago

@junrushao1994 For NNVM/Relay, the layout information is mainly used to insert layout transformation ops when necessary. Currently this is achieved by the FCorrectLayout attribute, which works like an "InferLayout" attribute. We might want to preserve the latest valid layout of each op, so that we can easily fall back to it when the layout chosen by a new pass is illegal for some ops. The logic should be similar to the current NNVM implementation, but we might be able to manage it better in Relay.
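
As a rough illustration of inserting layout transformation ops, here is a toy Python sketch. It does not use FCorrectLayout or any NNVM/Relay API; the function name, the op list format, and the assumption that every op preserves the layout of its input are all made up for the example.

```python
from typing import List, Optional, Tuple

# (op name, layout the op requires, or None if it accepts any layout)
Op = Tuple[str, Optional[str]]


def insert_layout_transforms(ops: List[Op], input_layout: str) -> List[str]:
    """Walk a chain of ops and insert a transform wherever the layout does not match."""
    result: List[str] = []
    current = input_layout
    for name, wants in ops:
        if wants is not None and wants != current:
            result.append(f"layout_transform({current}->{wants})")
            current = wants
        # Assumption of this sketch: the op's output keeps the layout of its input.
        result.append(name)
    return result


chain: List[Op] = [("conv2d", "NCHW16c"), ("relu", None), ("softmax", "NCHW")]
rewritten = insert_layout_transforms(chain, input_layout="NCHW")
```

Here relu is layout-agnostic and simply inherits NCHW16c from conv2d, so only two transforms are inserted. Tracking the current layout per op is also what would make a fallback to the last valid layout possible.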