tlc-pack / relax


[DISCUSS] Relax Pass Infrastructure #71

Open sunggg opened 2 years ago

sunggg commented 2 years ago

As we briefly discussed during the open development meeting, I think it would be great to start brainstorming about what we want to enable in Relax. I sketched some of my thoughts on existing approaches, so feel free to add comments if you have your own thoughts or any questions. I will put them together and bring them to our next discussion meeting so that we can address these issues in our initial design.

Motivation

Recent studies demonstrate that feedback from low-level information can substantially help various high-level decisions. For example, TASO searches for the best form of the computation graph by exploring various graph rewriting rules (e.g., layout transformation) with such feedback. Also, Collage finds the most efficient multi-backend execution strategy in a feedback-directed fashion (related discussion: https://github.com/tlc-pack/relax/issues/46). There are various studies regarding flag-level tuning approaches as well.

However, the conventional pass infrastructure is designed around the idea of progressive lowering and cannot provide seamless integration of these tuning approaches. This prevents the adoption of various tuning approaches across different abstraction layers and their joint optimization opportunities (e.g., TASO+Collage).

Thus, we want to open up new opportunities by offering natural integration of tuning approaches in the new pass infrastructure design. Please note that we will separate these optimization passes from the build (related discussion: https://github.com/tlc-pack/relax/issues/49).

Existing Approaches

comaniac commented 2 years ago

We are actually facing a similar issue that may require joint optimization between passes in training. For example, AutoCast (AMP) may choose to insert a new cast op or reuse an already inserted one. Reusing cast ops could maintain the execution order, but it means the output of one cast op is used by more than two subsequent ops, which prevents fusion from happening. In addition, we also have an issue about which backend should be used for each op. As a result, we are thinking about an infra for joint optimization, including AutoCast, Fusion, DialectOpLowering, and Rematerialization.

Regarding Google XTAT, some thoughts about your comments:

  1. I agree that per-node configuration seems not flexible enough. As mentioned in Section II.D, using XTAT as a subroutine could work around this limitation, but it also means the pass developer has to consider a larger scope than just the pass itself. However, an advantage of node configuration I can think of is enabling partial tuning, which may save lots of time.
  2. I'm actually not against sequential pass application, because some passes are implemented with an assumption that the IR has to be XX in advance, where XX can be "fused", "simplified", "canonicalized", "dead code eliminated", etc. IMHO, as long as we have the tuning process (i.e., joint optimization) so that early applied passes can still make good decisions for later passes, fixing the order seems not an issue and could largely simplify the design (for both developing an infra and passes).
  3. This echoes the previous two points about the developer experience, and this is the most important part to me.

A vague approach in my mind is having a configuration similar to AutoTVM. Specifically, developers could use the tuning APIs to declare tunable parameters in a pass, and the tuning infra would be able to collect them and figure out the best combination. Pros: developers can optionally add a dimension of tuning configurations without worrying about anything else. Cons: 1) the tuning infra basically has no idea what it is tuning, so it is hard to improve tuning efficiency on the search-algorithm side; 2) the configuration scope is per pass, making the pass the smallest granularity.
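
A rough sketch of the idea, with made-up names (`tune.knob`, `SomeFusionPass`) purely for illustration:

# Hypothetical AutoTVM-style tunable parameter declared inside a pass.
# `tune.knob` and `SomeFusionPass` do not exist; they only sketch the shape of the API.
@ir.transform.module_pass(opt_level=1)
class MyTunableFusion:
    def transform_module(self, mod, ctx):
        # The tuning infra collects this knob and searches over its candidate values.
        max_depth = tune.knob("fusion_max_depth", candidates=[2, 4, 8])
        return SomeFusionPass(max_depth=max_depth)(mod)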

Just my two cents. The approach I mentioned is definitely not optimal, but hopefully this could let others chime in with better ideas.

sunggg commented 2 years ago

@comaniac, thank you for your input; I totally agree with your thoughts. I expect there will be even more exciting opportunities in training, since more interesting operations and data come into play, as in the examples you gave. Regarding sequential pass application, I also think it should be one of the fundamental principles in our new infra. However, it would be great if we could still allow some opportunity for "true" joint optimization. I have an idea in mind, but let me polish it and bring it to our discussion along with others' inputs in this thread.

sunggg commented 2 years ago

Hi, all! Hope you are all doing well. Since our last discussion, I've worked on identifying more concrete challenges in the current pass infra and drafted an initial design to address them. Let me start with a summary of the challenges I found.

Major Challenges in Current Pass Infra

Thus, we want to address these challenges with a new pass infrastructure design.

Terminology

We define two kinds of optimization or analysis passes: heuristic passes and tuning passes.

Design Goal

The new pass infrastructure design aims to provide a flexible and composable optimization pipeline with the following goals.

Design Principles

We propose the following design principles:

These principles will enable the following. (H: heuristic pass, T: tuning pass, eval_pass: passes for candidate evaluation)
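
For illustration, a pipeline written in this notation could look like the following sketch (H1/H2/T1/T2 are placeholder passes, not concrete proposals):

# Hypothetical pipeline using the H/T notation above
# H1, H2: heuristic passes; T1, T2: tuning passes
seq = tvm.transform.Sequential([
    H1(),                        # plain heuristic pass
    T1(eval_passes=[H2()]),      # tuning pass; each candidate is evaluated after applying H2
    T2(),                        # independent tuning pass applied sequentially (no joint search)
])
optimized_mod = seq(mod)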

Synergy with Relax

The new pass infrastructure will have unique synergy with Relax. Since Relax aims to express different abstraction layers in an IRModule and offer a universal build for such an IRModule, it will open up various opportunities to explore, such as:

Class Design

Please note that this is pseudocode. For the actual implementation, we may extend the existing data structures and functionalities where possible. For example, given that MetaSchedule provides great fundamental APIs, such as the builder/runner, database API, cost model, etc. for generic tuning methods, @junrushao1994, @zxybazh and I are planning to discuss how we can extend it for the tuning pass with generic tuning methods beyond kernel-level tuning. Also, there are currently different functions depending on whether you are handling a Function or an IRModule; their handling is omitted for simplicity.

Pass Class

# Base pass class
class Pass():
    # Specify dependent passes
    def __init__(self, required=[]):
        self.required = required

# Pass context class
class PassContext():
    def __init__(
         self,
         # ... include current fields ...
         target, # this is necessary for evaluation in tuning pass.  
     ):
        # ...

# Base class for heuristic pass
# It will look similar to current pass design.
class HeuristicPass(Pass):
    # Actual implementation for optimizations/analysis
    def transform_module(
           self, 
           mod: IRModule, 
           ctx: PassContext
         )->IRModule:
        # ... contents ...

# Base class for tuning pass
class TuningPass(Pass):
    def __init__(self, eval_passes, eval_metric, measure_option):
        super().__init__()
        # Passes for evaluation to enable joint-optimization
        self.eval_passes = eval_passes
        # Evaluation criteria for candidates (e.g., execution time)
        self.eval_metric = eval_metric
        # Measurement option
        self.measure_option = measure_option

    # Use metaschedule for evaluation
    def evaluate(self, ctx, candidates):
        target = ctx.config["target"]
        # Evaluation
        scoreboard = {}
        for candidate in candidates:
            # Apply pass group before build
            seq = tvm.transform.Sequential(self.eval_passes)     
            candidate = seq(candidate)

            # Leverage metaschedule builder/runner to get score
            score = ...
            scoreboard[candidate] = score
        return scoreboard

    # Different tuning methods may have different cost models. Can we extend MetaSchedule?
    @staticmethod
    def query_cost_model(candidates):
        pass

    # Can we extend the MetaSchedule database?
    @staticmethod
    def update_database(...):
        pass

    @staticmethod
    def select_best_candidate(scoreboard):
        # ... select the best candidate depending on the eval_metric ...
        return best_candidate

    # Actual implementation for optimizations/analysis
    # This will have a feedback loop repeating the following steps:
    # (1) candidate generation
    # (2) candidate evaluation: this will call the evaluate() method
    # (3) pick the best candidate and reflect the feedback
    def transform_module(
           self,
           mod: IRModule,
           ctx: PassContext
         ) -> IRModule:
        # ... contents ...

# Some useful APIs
# Sanity check for IRModule
def validate(mod: IRModule) -> bool:
  # ... validation logic ...

# Dependency check for a sequence of passes
def validate(seq: List[Pass]) -> bool:
  # ... validation logic ...

# Extract a certain part of the graph of interest
# This will be useful for subgraph benchmarking
def extract_subgraph(mod: Expr) -> Expr:
  # ... extraction mechanics ...

[TBD] Data structure for the communication with the build system.

This requires a discussion. Look at D1 for details. This will define the exploration space for optimization passes. (e.g., can we explore fusion decisions?)

[TBD] Pass sequence registration interface.

This requires a discussion. Look at D5 for details.

Developer PoV

Developers can design their own custom passes and perform any IRModule -> IRModule transformation. As an example, we can design simple mock tuning passes that decide whether to apply a certain heuristic pass based on low-level feedback.

# This mock tuning pass fuses parallel matmuls
@ir.transform.module_pass(opt_level=1)
class TuningCombineParallelMatmul(TuningPass):
    def __init__(self, eval_passes = []):
        super().__init__(eval_passes)

    def transform_module(
                  self, 
                  mod: IRModule, 
                  ctx: PassContext)->IRModule:
        # Candidate generation
        new_mod = transform.CombineParallelMatmul()(mod)
        # Two candidates: do you want to enable it or disable it?
        candidate_pool = [mod, new_mod]
        scoreboard = self.evaluate(ctx, candidate_pool)
        best_perf, best_mod = self.select_best_candidate(scoreboard)
        return best_mod

# This mock tuning pass makes layout transform decisions
@ir.transform.module_pass(opt_level=1)
class TuningLayout(TuningPass):
    def __init__(self, eval_passes = []):
        super().__init__(eval_passes)

    def transform_module(self, mod: IRModule, ctx: PassContext)->IRModule:
        # Candidate generation
        new_mod = transform.LayoutTransform()(mod)
        # Two candidates: do you want to enable it or disable it?
        candidate_pool = [mod, new_mod]
        scoreboard = self.evaluate(ctx, candidate_pool)
        best_perf, best_mod = self.select_best_candidate(scoreboard)
        return best_mod

Depending on what developers want, they can run each pass separately, sequentially, or in a joint-optimization fashion.

# Run TuningLayout pass only  
custom_pass = TuningLayout()
optimized_mod = custom_pass(mod)

# Run TuningLayout and TuningCombineParallelMatmul sequentially
# You can also change the order easily
seq = [ TuningLayout(), TuningCombineParallelMatmul() ]
custom_pipeline = tvm.transform.Sequential(seq)
optimized_mod = custom_pipeline(mod)

# Run joint-optimization
seq = [ TuningLayout(eval_passes = [TuningCombineParallelMatmul()]) ]
custom_pipeline = tvm.transform.Sequential(seq)
optimized_mod = custom_pipeline(mod)

# Later, you can generate executable by using relax build system.
lib = relax.vm.build(optimized_mod)

Developers can also design tuning passes for more interesting optimization decisions if the build system supports them (e.g., fusion decisions, lowering decisions, graph rewriting). Like the previous examples, they can also easily test out joint optimization.

# TASO-like tuning pass
@ir.transform.module_pass(opt_level=1)
class TuningGraphRewriting(TuningPass):
    def __init__(self, eval_passes = []):
        super().__init__(eval_passes)

    def transform_module(
         self,
         mod: IRModule,
         ctx: PassContext) -> IRModule:
        # Some tuning approaches repeat the search within their tuning budget
        budget = ...
        while budget > 0:
            # ... some analysis ...
            # (1) Generate candidates
            candidates = get_promising_rewritten_graph(mod)
            # (2) Evaluate candidates with other passes and a minimal build
            scoreboard = self.evaluate(ctx, candidates)
            # ...
            # (3) Reflect the feedback
            best_perf, best_mod = self.select_best_candidate(scoreboard)
            # ... generate the next promising candidates based on the current feedback ...
        return best_mod

# Collage-like tuning pass
@ir.transform.module_pass(opt_level=1)
class TuningBackendPlacement(TuningPass):
    def __init__(self, eval_passes = []):
        super().__init__(eval_passes)

    def transform_module(
         self,
         mod: IRModule,
         ctx: PassContext) -> IRModule:
        # ... some analysis ...
        for node in post_order_dfs(mod):
            # ... some analysis ...
            # (1) Generate candidates
            candidates = get_available_backend_candidates(node)
            # (2) Evaluate candidates with other passes and a minimal build
            scoreboard = self.evaluate(ctx, candidates)
            # ...
            # (3) Reflect the feedback
            best_perf, best_candidate = self.select_best_candidate(scoreboard)
            annotate(node, best_candidate)
            # ...
        return optimized_mod

If you are interested, you can also play around with its prototype in the Relay world here: link

Discussion Points

Any feedback or thoughts would be greatly appreciated! Since this might be a bold design, I would like to identify potential issues early.

sunggg commented 2 years ago

Thanks for the valuable input today! I would definitely appreciate more feedback if you have any. Feel free to leave more :)

Some of the feedback during the meeting:

  1. Use program feature for pass application - link
  2. Dependency checker
  3. Some tuning passes may require their own parameters - how can we systematically pass them?
  4. Better design to explore the order of phases - can we make it more intuitive and organic?
  5. Investigate potentially related work for the tuning pass - MLIR? XLA?

ZihengJiang commented 2 years ago

Hi @sunggg , thanks for the great proposal. I have a few questions:

sunggg commented 2 years ago

Hi, @ZihengJiang

  • For the tuning pass with eval passes usage T1(eval_passes=[H3, H4]), will the H3 and H4 happen before or after T1?

eval_passes are applied during candidate evaluation: each candidate generated by the tuning pass is evaluated after applying the given passes. example
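
To sketch the semantics with placeholder passes (following the H/T notation above): for T1(eval_passes=[H3, H4]), H3 and H4 run on every candidate that T1 generates, before the candidate is built and measured, and T1 then keeps the best one.

# Pseudocode for what T1(eval_passes=[H3(), H4()]) does internally
# (helper names here are placeholders, not part of the proposed API)
candidates = generate_candidates_for_T1(mod)   # T1 proposes variants of `mod`
for cand in candidates:
    cand = H3()(cand)        # eval passes are applied to each candidate ...
    cand = H4()(cand)        # ... so their effect is reflected in the measurement
    measure(cand)            # minimal build + run for feedback
best_mod = pick_best()       # T1 commits the best-performing variant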

  • If we want to do joint-optimization with several tunable passes, each tunable pass has its own eval passes, how can we represent this with current API?

Currently, I'm thinking of allowing each tuning pass to have its own eval passes, and I'm trying to find potential issues. If you find any corner case, that would be very helpful. The current design specifies the eval passes at pass invocation. If this is not what you asked for, would you elaborate a little further?

class TuningPass(Pass):
   def __init__(self, eval_passes, eval_metric, measure_option):
        super().__init__()
        # Passes for evaluation to enable joint-optimization
        self.eval_passes = eval_passes
hypercubestart commented 2 years ago

hi @sunggg , thanks for the great work! Some random ideas/thoughts about pass order infrastructure:

sunggg commented 2 years ago

@hypercubestart, thank you for your input!

  • In Relay, the quantization workflow is split into 3 separate passes: QuantizeAnnotate, QuantizeCalibrate, QuantizeRealize. For QuantizeCalibrate, we need to represent required pre-passes (annotate) and post-passes (realize), and in addition the passes must be run one after the other. This might be challenging to represent in the current architecture unless all three quantize passes are combined into a single pass.

Yes, great point. I think I may have missed this in the design above, but I'm planning to add required_passes to the base Pass class and make sure the given sequence satisfies such constraints. As we briefly discussed during the meeting, I'm considering providing this information at pass instantiation, like we do for eval_passes, for more flexible configuration, rather than hard-coding it like we do in the current infra (e.g., the same optimization may need different analysis passes depending on whether it is a static or dynamic model). I will make this clear in the next draft.
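
A rough sketch of what specifying this at instantiation could look like (the constructor arguments below are hypothetical, just following the eval_passes style):

# Hypothetical: declare required pre-passes per instance instead of hard-coding them
calibrate = QuantizeCalibrate(
    required=[QuantizeAnnotate()],     # must run before this pass
    eval_passes=[QuantizeRealize()],   # applied when evaluating candidates
)
# A sequence validator can then check that the `required` constraints are satisfied
seq = tvm.transform.Sequential([QuantizeAnnotate(), calibrate, QuantizeRealize()])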

  • also follow-up to @ZihengJiang 's question about joint-optimization. If I had two passes LayoutTransform (TIR)/LayoutRewrite(Graph), and wanted to joint-optimize, does this require rewriting a completely new tunable pass? If so, how would this interact with passes that may expect to require LayoutTransform but not the new pass?

Theoretically, if you want to joint-optimize two independent passes, you would just include one pass in the eval passes of the other, e.g., seq = [T1(eval_passes=[T2])]. If you are 100% sure that those two passes don't need to be separated, you may consider designing a pass that combines both. However, even if you introduce this new tuning pass, you don't necessarily need to remove the existing passes, because other passes may depend on them, as you described. And of course, since this may introduce some overlap between passes, I think we need to carefully think about what passes we want to provide by default while providing full customization support as an infrastructure.

And I think your example brings up an excellent point. If two passes do similar jobs at different abstraction layers, like LayoutTransform (TIR) / LayoutRewrite (graph), the later one may revert the earlier one's decision. I believe this is an open question, so I added it as discussion point D3 in my proposal.

slyubomirsky commented 2 years ago

A brief thought I brought up in our meeting but wanted to leave in writing: for the management of passes and when they are applicable, we may want some automatic way of specifying which program features each pass is designed to handle and of checking for them automatically, like the seldom-used feature flags in Relay. Feature flags like those (but coupled with automatic enforcement) could allow detecting certain kinds of bugs and incompatibilities in advance. (Alternatively, we could adopt the norm that all passes are expected to support all program features, which would also require testing them against all program features to be certain of that.)
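
For example, something along these lines (the feature names and the supported_features field are hypothetical, just to show the shape of the check):

# Hypothetical feature-flag enforcement, inspired by Relay's feature flags
class MyFusionPass(Pass):
    # Declare which program features this pass is designed to handle
    supported_features = {"first_order", "static_shape"}

def check_pass_applicable(pass_, mod):
    used = detect_features(mod)   # analyze the module, analogous to relay.analysis.detect_feature
    unsupported = used - pass_.supported_features
    if unsupported:
        raise ValueError(f"{type(pass_).__name__} does not support: {unsupported}")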

sunggg commented 2 years ago

Hi, all. I'd like to discuss the first formal version of the Tuning Pass design. Once we reach an agreement, I'd like to start working on the implementation.

Backgrounds

Today’s tuning methods

What you tune

How you tune

Depending on what you accept as input and how you generate candidates

Search methods - different tuning methods may favor different strategies

Depending on whether the search can stop at any time or not, we can categorize search methods into two kinds.

Goal

Design Overview

Fundamentally, tuning is a feedback-directed search that repeats three primitives: (1) candidate generation, (2) candidate evaluation, and (3) updating the search state. Therefore, as a bare minimum, a tuning pass must provide developers a convenient API for these three tuning primitives. Note that a tuning pass often wants to account for the effect of other optimization passes in (2), so the API design should consider this. Besides, there can be optional primitives such as a cost model or a prediction model (which takes an IRModule and predicts the optimized IRModule, potentially in a data-driven manner; ApplyHistoryBest is a representative example, and AutoTVM has a scoring heuristic to find the closest workload in the database).
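
As a sketch, the skeleton of such a tuning pass is roughly the following loop (names are placeholders; the concrete API is defined below):

# Skeleton of a tuning pass built from the three primitives
def transform_module(self, mod, ctx):
    state = init_search_state(mod)
    while not state.done():
        candidates = generate_candidates(state)   # (1) propose transformed IRModules
        scores = evaluate(ctx, candidates)        # (2) build + run (or query a cost model)
        state.update(candidates, scores)          # (3) feed the results back
    return state.best_module()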

(Figure: trade-off between heuristic and tuning approaches)

To ease the development complexity, each tuning pass would generate its search space based on its input IRModule; each pass does not need to worry about what other passes are doing, as long as each pass is guaranteed to produce a valid IRModule. Later, a user can customize how to apply these tuning passes when defining a pass pipeline.

# Say the search space of TuningPass1/2 is s1/s2, respectively.

# 1. Two sequential tuning passes: the search space grows additively
Sequential([TuningPass1(), TuningPass2()])
# -> Search space: s1+s2

# 2. Joint optimization: the search space grows combinatorially
Sequential([TuningPass1(eval_passes=[TuningPass2()])])
# -> Search space: s1*s2

# Since joint optimization expands the search space rapidly,
# it is highly recommended to apply two tuning passes sequentially
# if they are orthogonal to each other

Since a tuning pass can nest other tuning passes for joint optimization, the trace of a sequence of transformations may not be easy to track. This would make the behavior hard to understand and debug. Thus, inspired by MetaSchedule, we introduce Instruction and Trace.

def convert_conv2d_NHWC(mod):
    new_mod = ConvertLayout({"nn.conv2d": ["NHWC", "default"]})(mod)
    return new_mod

def convert_conv2d_NCHW(mod):
    new_mod = ConvertLayout({"nn.conv2d": ["NCHW", "default"]})(mod)
    return new_mod

def noapply(mod):
    return mod

choices = {
    "convert_conv2d_NHWC": Choice(convert_conv2d_NHWC),
    "convert_conv2d_NCHW": Choice(convert_conv2d_NCHW),
    "NoApply": Choice(noapply),
}
knob = Instruction("LayoutTransform", choices)

Printing a Trace shows the applied instructions and decisions, for example:

Trace length: 2
[1] MockTuningInst1: choice3
[2] MockTuningInst2: choice1

I want to clarify that the goal of the tuning pass is NOT to replace heuristic passes. Each approach has its own unique strengths, and they can help each other evolve.

(Figure: relationship between heuristic and tuning passes)

Conceptually, a tuning pass and a heuristic pass differ only in the decision-making process while considering the same set of transformations (e.g., once a fusion decision is made by either a tuning or a heuristic method, the process of applying the decision is the same). Thus, tuning passes and heuristic passes would likely share many API functions. Our design also aims to provide such common functionality by maximizing code reuse.

(Figure: pass design concept)

Tuning API Design


# Tuning API provides important primitives to implement a tuning method
# Current design integrates MetaSchedule builder/runner and database
# By default, it would use the database for a pass pipeline
# However, it can also manage its own database if necessary
# Users can define eval_passes to consider the interaction with other passes

# Classes
# A choice defines a valid transformation
# Instruction will consider each choice for its candidate generation
# To reduce the search space, each choice may be considered in a probabilistic manner
class Choice:
     def __init__(self, func: Callable, constr=None, args=None):
         self.func = func       # transformation func 
                                # it allows the feedback loop to cover finer
                                # granularity of candidate tuning
                                # (e.g., subgraph tuning, see Collage example below)
         self.constr = constr   # constraints e.g., condition on tensor shape
         self.args = args       # arguments for func

class Instruction:
     def __init__(
         self, name: str, choices: Union[List[Choice], Dict[str, Choice], Dict[int, Choice]]
     ):
         self.name = name
         self.choices = choices

     # Check if a decision is valid
     def verify(self, decision: Union[str, int]) -> bool:
         if isinstance(self.choices, dict):
             return decision in self.choices
         elif isinstance(self.choices, list):
             return decision < len(self.choices)
         else:
             raise Exception("Invalid type for choices")

     # Get a choice for a decision
     def get_choice(self, decision: Union[str, int]) -> Choice:
         assert self.verify(decision)
         return self.choices[decision]

     # Apply a decision to an input IRModule
     def apply(self, mod: IRModule, decision: Union[str, int]) -> IRModule:
         assert self.verify(decision)
         return self.choices[decision].func(mod)

     def __str__(self) -> str:
         msg = f"{self.name} (# of choices: {len(self.choices)})\n"
         if isinstance(self.choices, dict):
             for name, choice in self.choices.items():
                 msg += f"  - {name}: {choice}\n"
         elif isinstance(self.choices, list):
             for idx, choice in enumerate(self.choices):
                 msg += f"  - {idx}: {choice}\n"
         else:
             raise Exception("Invalid type for choices")
         return msg

 # Trace maintains a sequence of instructions and their decisions.
 # It maintains the input/output IRModule and its performance
 class Trace:
     def __init__(
                    self, 
                    in_mod: IRModule, 
                    trace: List[Tuple[Instruction, Union[str, int]]] = []
                  ):
         self.in_mod = in_mod
         self.trace = trace
         self.out_mod = self.apply(in_mod, trace)
         self.perf = None

     def verify(self):
         for (knob, decision) in self.trace:
             if not knob.verify(decision):
                 return False
         return True

     # Apply a certain trace to the input IRModule
     def apply(self, in_mod: IRModule, trace: List[Tuple[Instruction, Union[str, int]]]) -> IRModule:
         out_mod = copy.deepcopy(in_mod)
         for knob, decision in trace:
             if not knob.verify(decision):
                 raise Exception("Illegal decision in the trace")
             out_mod = knob.apply(out_mod, decision)
         self.perf = None
         return out_mod

     # Add a pair of instruction and its decision to the current trace
     def add(self, knob: Instruction, decision: Union[str, int]) -> None:
         self.out_mod = knob.apply(self.out_mod, decision)
         self.trace.append((knob, decision))
         self.perf = None

     def __str__(self) -> str:
         msg = f"Trace length: {len(self.trace)}\n"
         for idx, (knob, decision) in enumerate(self.trace):
             msg += f"[{idx+1}] {knob.name}: {decision}\n"
         return msg
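
# Illustrative usage sketch (not part of the API): build an Instruction from the
# `choices` dict in the earlier layout example, record a decision on a Trace, and inspect it.
#   knob = Instruction("LayoutTransform", choices)
#   trace = Trace(mod)
#   trace.add(knob, "convert_conv2d_NHWC")   # applies the choice and records the decision
#   print(trace)                             # -> "Trace length: 1 ..."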

# Helper functions
# Generate the search space for a given trace by using the registered choices
# To reduce the search space, it may expand each choice in a probabilistic manner
# A developer can introduce smarter search strategies, like multi-armed bandit
def generate_candidates(inst, trace: Trace, ctx: PassContext, eval_passes: List[Pass] = None) -> List[Trace]:
    candidates = list()
    for decision in inst.choices.keys():
        choice = inst.choices[decision]
        # Generate a new candidate when the choice's constraint is satisfied
        # (a choice without a constraint is always considered)
        if choice.constr is None or choice.constr(trace.out_mod):
            new_trace = copy.deepcopy(trace)
            new_trace.add(inst, decision)
            candidates.append(new_trace)
    # Expand candidates by using eval passes if available
    if eval_passes:
        candidates = consider_eval_passes(candidates, ctx, eval_passes)
    return candidates

# Expands traces generated by current tuning pass with its eval passes
def consider_eval_passes(
   seeds: List[Trace], ctx: PassContext, eval_passes: List[Pass] = None
) -> List[Trace]:
   candidates = list(seeds)
   num = len(candidates)
   for i in range(num):
       trace = candidates.pop(0)
       for eval_pass in eval_passes:
            # For a heuristic pass, we create a knob with a single choice for tracking
           if isinstance(eval_pass, HeuristicPass):
               knob = Instruction(f"{eval_pass.name}", [Choice(eval_pass)])
               trace.add(knob, 0)
           # Tuning pass expands candidates by visiting its evaluation passes in dfs
           else:
               trace = eval_pass()(trace, ctx)

       candidates.append(trace)
   return candidates

# Evaluates each candidate with the MetaSchedule Runner/Builder
# Its performance can be stored in the MetaSchedule Database
def evaluate(ctx, candidates: List[Trace], database=None, eval_config=None):
    # These targets will be retrieved from the ctx
    target_str, target_host, device_id = (
        ctx.config["target"],
        ctx.config["target_host"],
        ctx.config["device_id"],
    )
    target = tvm.target.Target(target_str)
    device = tvm.device(target_str, device_id)

    num_evals = 0
    # Evaluation
    for candidate in candidates:
        # Skip candidates that have already been measured
        if candidate.perf is not None:
            continue
        num_evals += 1
        mod = candidate.out_mod

        # Build the candidate with the MetaSchedule builder
        builder = LocalBuilder(f_build= ... )
        (builder_result,) = builder.build([BuilderInput(mod, target)])

        assert builder_result.artifact_path is not None
        assert builder_result.error_msg is None

        runner_input = RunnerInput(
            builder_result.artifact_path,
            target_str,
            [],  # ArgInfo
        )

        # Run the candidate with the MetaSchedule runner
        runner = LocalRunner(
            timeout_sec=100,
            evaluator_config=eval_config,
            f_run_evaluator=eval_func,
        )

        (runner_future,) = runner.run([runner_input])
        runner_result = runner_future.result()

        assert runner_result.error_msg is None
        perfs = []
        for result in runner_result.run_secs:
            if isinstance(result, tvm.tir.FloatImm):
                result = result.value
            assert isinstance(result, float)
            assert result >= 0.0
            perfs.append(result)

        # ...
        candidate.perf = tuple([np.mean(perfs), np.std(perfs)])

        # Optionally record the measurement in the MetaSchedule database
        if database is not None:
            workload = database.commit_workload(mod)
            record = TuningRecord(
                candidate.trace,
                perfs,
                workload,
                target,
                [],
            )
            database.commit_tuning_record(record)

# Choose the best trace
def select_best_candidate(traces):
    best_perf, best_trace = sys.maxsize, None
    for candidate in traces:
        (avg, std) = candidate.perf
        # Select the best one
        if best_perf > avg:
            best_perf = avg
            best_trace = candidate
    return best_trace

# Return trace wrapper if necessary
def get_trace(in_):
     if isinstance(in_, Trace):
         return in_
     if isinstance(in_, IRModule):
         return Trace(in_)
     elif isinstance(in_, Expr):
         return Trace(tvm.IRModule.from_expr(in_))
     #...
     else:
         raise Exception("Invalid input type for pass")

# Extracts matching subgraph for subgraph-level tuning
def extract_subgraph(mod, pattern):
    # ...

# [Optional] a cost model that estimates the performance of a trace
def query_cost_model(cost_model, trace:Trace)->float:
     assert 0, "Need to implement"

# [Optional] a prediction model that predicts the optimized IRModule
# This can be done by heuristic like AutoTVM 
# or data-driven approach like ApplyHistoryBest in MetaSchedule
def predict(mod: IRModule, ctx) -> IRModule:
     assert 0, "Need to implement"

Example

Setup

from TuningAPI import (
        Choice, 
        Trace, 
        Instruction, 
        generate_candidates, 
        consider_eval_passes,
        evaluate, 
        select_best_candidate
)

Simple switching decision

class TuningParallelConv2dPass(Pass):
     def __init__(self, eval_passes=[], required=[], database=None):
         super().__init__(
             "TuneCombineParallelConv2D",
             required=required,
         )
         self.eval_passes=eval_passes
         self.database=database

     def tune(self, trace, ctx):
         def apply(mod):
             new_mod = InferType()(mod)
             new_mod = CombineParallelConv2D(min_num_branches=2)(new_mod)
             return new_mod

         def noapply(mod):
             return mod

         choices = {"On": Choice(apply), "Off": Choice(noapply)}
         # Tuning pass manages a set of transformation functions
         inst = Instruction("InstructionTuningParallelConv2D", choices)
         candidates = generate_candidates(inst, trace, ctx, self.eval_passes)
         evaluate(ctx, candidates, self.database)
         best_trace = select_best_candidate(candidates)
         return best_trace

     def transform_module(self, mod: Union[IRModule, Trace], ctx: PassContext) -> IRModule:  
         best_trace = self.tune(get_trace(mod), ctx)
         return best_trace.out_mod

Layout Transformation

class TuningLayoutPass(Pass):
     def __init__(self, eval_passes=[], required=[], database=None):
         super().__init__(
             "TuneLayout",
             required=required,
         )
         self.eval_passes=eval_passes
         self.database=database
         self.num_evals = 0

     def tune(self, trace, ctx):
        def convert_conv2d_NHWC(mod):
            new_mod = ConvertLayout({"nn.conv2d": ["NHWC", ...]})(mod)
            return new_mod

        def convert_conv2d_NCHW(mod):
            new_mod = ConvertLayout({"nn.conv2d": ["NCHW", ...]})(mod)
            return new_mod

        def noapply(mod):
           return mod

        choices = {
           "convert_conv2d_NHWC": Choice(convert_conv2d_NHWC),
           "convert_conv2d_NCHW": Choice(convert_conv2d_NCHW),
           "NoApply": Choice(noapply),
         }
         inst = Instruction("InstructionTuningLayout", choices)
         candidates = generate_candidates(inst, trace, ctx, self.eval_passes)
         evaluate(ctx, candidates, self.database)
         best_trace = select_best_candidate(candidates)
         return best_trace

     def transform_module(self, mod: Union[IRModule, Trace], ctx: PassContext) -> IRModule:
         best_trace = self.tune(get_trace(mod), ctx)
         return best_trace.out_mod

TASO

class TASO(Pass):
     def __init__(self, eval_passes={}, required=[], database=None):
         super().__init__(
             "TASO",
             required=required,
         )
         self.eval_passes=eval_passes
         self.database=database
         self.num_evals = 0

     def tune(self, trace, ctx):
         q = PriorityQueue()
         q.push(trace.in_mod)
         best_trace, best_perf = None, 1e100
         while not q.empty():
             g = q.pop()
             # Every (substitution, layout) pair forms one rewriting choice
             choices = {}
             for s in get_available_substitutions():
                 for l in get_available_layouts(g, s):
                     choices[f"{s.name}_{l.name}"] = Choice(get_rewriting_func(s, l))
             inst = Instruction("tune_rewriting", choices)
             candidates = generate_candidates(inst, trace, ctx, self.eval_passes)
             evaluate(ctx, candidates, self.database)
             best_cand = select_best_candidate(candidates)
             if best_cand.perf[0] < best_perf:
                 best_trace, best_perf = best_cand, best_cand.perf[0]
             # Keep the most promising candidates for the next iteration
             next_population = get_top_alpha(candidates, best_perf, alpha)
             q.push(next_population)

         return best_trace

     def transform_module(self, mod: Union[IRModule, Trace], ctx: PassContext) -> IRModule:
         best_trace = self.tune(get_trace(mod), ctx)
         return best_trace.out_mod

Collage

class Collage(Pass):
    def __init__(self, eval_passes={}, required=[], database=None):
        super().__init__(
            "Collage",
            required=required,
        )
        self.eval_passes = eval_passes
        self.database = database
        self.num_evals = 0

    def tune(self, trace, ctx):
        # The whole backend-placement search is wrapped in a single transformation
        # function so that it can be recorded as one Instruction in the trace.
        def func(mod):
            g = build_graph(mod)
            q = FrontierQueue()  # priority queue sorted by node depth
            q.push(g.get_root())
            while not q.empty():
                f = q.pop()
                expr = f.get_expr()
                choices = {}
                new_frontiers = []
                for backend, pattern in get_available_backend():
                    if pattern.match(expr):
                        choices[backend.name] = Choice(
                            get_extract_and_annotate_func(pattern, backend)
                        )
                        new_frontiers.append(get_new_frontiers(f, pattern))
                inst = Instruction("tune_backend", choices)
                new_trace = Trace(tvm.IRModule.from_expr(expr))
                candidates = generate_candidates(inst, new_trace, ctx)
                # You can manually expand candidates by using eval_passes for more control
                new_candidates = []
                for candidate in candidates:
                    backend_name = candidate.trace[-1][1]
                    # Depending on the backend, apply different passes
                    eval_pass = self.eval_passes[backend_name]
                    cands = consider_eval_passes([candidate], ctx, [eval_pass])
                    new_candidates.extend(cands)
                evaluate(ctx, new_candidates, self.database)
                best_trace = select_best_candidate(new_candidates)
                update_placement(expr, best_trace)
                q.push(new_frontiers)
            return apply_best_placement(mod)

        # Record the whole subgraph-level search as a single-choice instruction
        inst = Instruction("Collage", {"Subgraph-tuning": Choice(func)})
        new_trace = copy.deepcopy(trace)
        new_trace.add(inst, "Subgraph-tuning")
        return new_trace

    def transform_module(self, mod: Union[IRModule, Trace], ctx: PassContext) -> IRModule:
        best_trace = self.tune(get_trace(mod), ctx)
        return best_trace.out_mod

Pipeline customization

# 1. Apply single tuning pass
custom_pipeline = TuningParallelConv2dPass()     
with PassContext(config=config):
    mod = custom_pipeline(mod)

assert TuningPass.total_num_evals == 2

custom_pipeline = TuningLayoutPass()     
with PassContext(config=config):
    mod = custom_pipeline(mod)

assert TuningPass.total_num_evals == 3

# Heuristic pass won't increase the search space
custom_pipeline = TuningLayoutPass(eval_passes=[MyHeuristicPass()])     
with PassContext(config=config):
    mod = custom_pipeline(mod)

assert TuningPass.total_num_evals == 3

# 2. Apply two tuning passes in sequential
# This is useful when we know two tuning passes are orthogonal to each other
# (we don't always want combinatorial search space with joint-optimization)
custom_pipeline = Sequential([TuningParallelConv2dPass(), TuningLayoutPass()])
with PassContext(config=config):
    mod = custom_pipeline(mod)

assert TuningPass.total_num_evals == 2+3

custom_pipeline = Sequential([TuningLayoutPass(), TuningParallelConv2dPass()])
with PassContext(config=config):
    mod = custom_pipeline(mod)

assert TuningPass.total_num_evals == 3+2

# 3. Joint-optimization
custom_pipeline = TuningParallelConv2dPass(eval_passes=[TuningLayoutPass()])  
with PassContext(config=config):
    mod = custom_pipeline(mod)
assert TuningPass.total_num_evals == 2*3

custom_pipeline = TuningLayoutPass(eval_passes=[TuningParallelConv2dPass()])  
with PassContext(config=config):
    mod = custom_pipeline(mod)
assert TuningPass.total_num_evals == 3*2

# Say we have a MockTuningPass with search space of 5
custom_pipeline = TuningLayoutPass(
     eval_passes=[TuningParallelConv2dPass(
                      eval_passes=[MockTuningPass()]
                  )]
)  
with PassContext(config=config):
    mod = custom_pipeline(mod)
assert TuningPass.total_num_evals == 3*2*5

custom_pipeline = TuningLayoutPass(
     eval_passes=[TuningParallelConv2dPass(), MockTuningPass()]
)  
with PassContext(config=config):
    mod = custom_pipeline(mod)
assert TuningPass.total_num_evals == 3*(2+5)

C++ implementation

namespace transform {
// Since both heuristic and tuning methods live in the same file,
// code sharing is natural

Pass MockHeuristic() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        ConstantFolder folder(m);
        return Downcast<Function>(folder(f));
      };
  return CreateFunctionPass(pass_func, 0, "MockHeuristic", {});
}

TVM_REGISTER_GLOBAL("relax.transform.MockHeuristic").set_body_typed(MockHeuristic);

Pass MockTune() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> tune_pass_func =
      [=](Function f, IRModule m, PassContext pc) {
        // Implement your tuner
        auto new_f = MockTuner().tune(f);
        return Downcast<Function>(new_f);
      };
  return CreateFunctionPass(tune_pass_func, 0, "MockTune", /*required=*/{"InferType"});
}
TVM_REGISTER_GLOBAL("relax.transform.MockTune").set_body_typed(MockTune);

Comparison with Prior Work

Opportunity

Limitations

Discussion