@eqy @merrymercy please comment
Thanks for opening this RFC, graph level optimization is an important step in pushing performance in cases where we have to make inter-layer decisions (such as data layout in this case).
Looking at the code, I think that this is a good opportunity to either refine or more clearly lay out what we should have in our autotvm/general tuning API. Having a good API that is used healthily ensures that we keep things maintainable and gives us confidence that things are extensible.
To that end, it may be a good time to discuss how we should organize graph-level and layer-level tuning. We currently already have cost-model tuners that operate on search space; is it natural to extend the notion of a search space to cover possible graphs as well?
Taking a look at some of the odds and ends of the PR, here are some comments (in no particular order):
One main question: Currently there seem to be no calls into existing autotvm code; can we not define new tasks that the current autotvm can tune to leverage the existing infrastructure?
Currently graph tuner doesn't use any autotvm code; we should definitely reuse the AutoTVM system to tune kernels. Graph tuner is designed as a standalone module and isn't coupled to any specific tensor-level tuner. A few things would need to change to make this happen:
For the search space question, currently it should be enough for AVX512/AVX2.
It is a great step. For graph level layout planning, the current DP solution is nice. For operator level tuning, we can reuse the measurement/tuner infrastructure in autotvm.
It seems that we developed our custom tuner systems at the same time, and there are a lot of things to do for the merge. But keeping a unified infrastructure ensures maintainability. I can give some guidance on porting the executor to autotvm style.
1. Currently the tuning results are stored as Conv2dAVXExecutor.workload_execute entries like
[{"schedule": AVXConvCommonFwd(1, 4, 2, True), "time": 0.04}],
in your code. AutoTVM can serialize them to json and use them as the log file (see tvm/autotvm/record.py).
2. Replace the _get_schedule_conv mechanism in the current x86 topi. In AutoTVM, topi.nn.conv2d and topi.generic.schedule_conv2d will construct a workload and use this workload to query the corresponding ConfigEntity from the dispatch context. Previously, the schedule function topi.generic.schedule_conv2d needed to reconstruct the workload from the dataflow by using _get_workload. Now, in the compute function topi.nn.conv2d, we can attach a tuple to the compute op (https://github.com/dmlc/tvm/blob/54a115ef14fb6dabbf6ea8eb9e6dd85846030c72/topi/python/topi/arm_cpu/conv2d.py#L148). Then in the schedule function, we just fetch the workload via op.attrs['workload'].
3. Port the schedule templates in topi/x86 into autotvm style. Use cfg.define_split and cfg.define_knob to define your space.
4. Use GridSearchTuner for exhaustive search. For layout, you can reuse https://github.com/dmlc/tvm/blob/54a115ef14fb6dabbf6ea8eb9e6dd85846030c72/python/tvm/autotvm/measure/measure_methods.py#L179 for measurement. You can create a LayoutTuner and feed in a customized MeasureInput.
5. When users create the target (tvm.target.avx()), tvm will download them. I am not sure whether you like this style.
@merrymercy Thanks for the suggestion! I'll make changes accordingly.
@kevinthesun do you have previously collected data on the best graph level (data layout) choices for some different C5 instance types (e.g., xlarge, 2xlarge, 4xlarge, 9xlarge) on ResNet-50?
We are planning on doing some experiments with autotvm on EC2 and those would be very valuable for us.
@eqy https://github.com/kevinthesun/intel-benchmark This repo contains link to pre-tuned best schedules for several imagenet models. These schedules are searched on c5.9xlarge, but can be directly applied to other types of c5. Convolution schedules are stored in the ascending order of node index.
@merrymercy I checked arm_cpu conv2d and have several questions about implementation details:
see point 4
Exactly, you can create many dispatch contexts. In autotvm, during tuning we use ApplyConfig to apply the config for tuning; during compilation we use ApplyHistoryBest. For your case, you need to change the dispatch context used during compilation.
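For the compilation side, a minimal sketch of what that could look like (assuming the nnvm-era API; the log file name and the net/params/dshape/target variables are placeholders from your own tuning setup, not from this PR):
import nnvm.compiler
from tvm import autotvm

# net, params, dshape, target are assumed to come from your own frontend / tuning script
with autotvm.apply_history_best("conv2d_tuning.log"):
    graph, lib, params = nnvm.compiler.build(
        net, target=target, shape={"data": dshape}, params=params)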
No. You can call create_measure_batch. It will return this function for you. You can see the usage of create_measure_batch here: https://github.com/dmlc/tvm/blob/7cb85d81968cd69576d923852d812590b93cc26d/python/tvm/autotvm/tuner/tuner.py#L87
You can use the normal extract_from_graph to get conv2d tasks, then transform them into conv2d_NCHWc tasks. The current conv2d task is defined at https://github.com/dmlc/tvm/blob/7cb85d81968cd69576d923852d812590b93cc26d/python/tvm/autotvm/task/nnvm_integration.py#L99-L105
You can define the task for conv2d_NCHWc as follows:
@register("topi_nn_conv2d_NCHWc")
def _topi_nn_conv2d(*args, **kwargs):
    assert not kwargs, "Do not support kwargs in template function call"
    args = deserialize_args(args)
    A, W = args[:2]

    # get config here
    cfg = autotvm.get_config()
    cfg.define_knob('tile_c', [1, 2, 4, 8, 16])

    # change shape with the value in config
    VC = cfg['tile_c'].val
    raw_shape = get_const_tuple(A.shape)
    new_shape = (raw_shape[0], raw_shape[1] // VC, raw_shape[2], raw_shape[3], VC)
    args[0] = tvm.placeholder(new_shape, A.dtype)

    C = topi.nn.conv2d_NCHWc(*args, **kwargs)
    s = topi.generic.schedule_conv2d_NCHWc([C])
    return s, [A, W, C]
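Note that the template rebuilds the data placeholder with the packed shape derived from the tile_c knob before calling topi.nn.conv2d_NCHWc, so each sampled config is compiled and measured on the shapes that config actually implies.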
@merrymercy I can tune conv2d_NCHWc with autotvm now. I have one issue with logging to file. It returns a JSON-not-serializable error: TypeError: Tensor(shape=[1, 3, 4, 4], op.name=data) is not JSON serializable.
@kevinthesun Can you check that the MeasureInput and MeasureResult that you use don't contain any non-serializable data structures? e.g., we only use NamedTuples and Lists.
TVM Tensor will be serialized by https://github.com/dmlc/tvm/blob/b11f2a0495541cb348ae89093fd233d78eefec6e/python/tvm/autotvm/task/nnvm_integration.py#L17-L23 or https://github.com/dmlc/tvm/blob/b11f2a0495541cb348ae89093fd233d78eefec6e/python/tvm/autotvm/task/task.py#L185-L192
Can you check whether there is something missing?
Found the issue: I need to call serialize_args before creating conv2d_NCHWc tasks since I was not using extract_from_graph.
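For reference, a minimal sketch of the fix, mirroring the argument order of the benchmark script later in this thread (it assumes the topi_x86_conv2d_NCHWc template from the graph tuner branch is registered; shapes are illustrative):
import tvm
from tvm import autotvm
from tvm.autotvm.task.nnvm_integration import serialize_args

data = tvm.placeholder((1, 3, 224, 224), name="data")
kernel = tvm.placeholder((64, 3, 7, 7), name="kernel")
# serialize_args turns the placeholders into ("TENSOR", shape, dtype) tuples,
# so the task args (and the log records built from them) are JSON-serializable
args = serialize_args([data, kernel, 3, (7, 7), (2, 2), (3, 3), "NCHW", "NCHW", "float32"])
task = autotvm.task.create("topi_x86_conv2d_NCHWc", args=args,
                           target="llvm -mcpu=skylake-avx512")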
If I want to get MeasureResult from a tuner, do I need to create a callback, or is there an existing API to do this?
There is no existing API.
Can't we pass ConfigSpace into tvm.compute? I got ValueError: don't know how to convert type <class 'tvm.autotvm.task.space.ConfigSpace'> to node.
@eqy @merrymercy I got an issue using autotvm to tune conv2d_NCHWc on Intel CPU. The benchmark results from autotvm don't match the results when I run the same records with pure tvm (tvm.build and feeding in data directly). This usually happens when the average exec time is under 0.1 ms, and autotvm reports a much shorter exec time than the actual result, such as 0.097808883 ms vs 0.19609940052 ms.
I use the following settings:
measure_option = autotvm.measure_option(
    builder=autotvm.LocalBuilder(n_parallel=1),
    runner=autotvm.LocalRunner(number=run_times, min_repeat_ms=1500, cooldown_interval=2))
I use all cpu cores and no parallel jobs. I set min_repeat_ms to 1500 so that the number of executions is sufficient.
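(With kernels around 0.1 ms, min_repeat_ms=1500 corresponds to roughly 15,000 executions per measurement.)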
Did you use time_evaluator for your "pure tvm"?
Autotvm uses time_evaluator, which will skip the first warm up run https://github.com/dmlc/tvm/blob/0c523787297039ce00b320c1d32e022e61e97ac2/python/tvm/autotvm/measure/measure_methods.py#L456-L458
I tried time_evaluator and got similar results. I also applied the results of autotvm to graph tuner. The result doesn't match the records of autotvm. For example, for resnet18 the total exec time of conv2d is around 1.6 ms given the autotvm records, while the actual exec time is around 3.6 ms, which corresponds to the "pure tvm" exec time.
I wonder if this may have something to do with whether layout transformation time is included? Don’t know the details here so that is just a guess.
Can you try to use repeat=1 in LocalRunner? Or can you give me some scripts that I can use to verify?
@eqy I looked at the fused graph; there are only two layout transforms, one at the beginning and one at the end, which is expected, but the e2e performance is not good. @merrymercy This is the graph tuner branch: https://github.com/kevinthesun/tvm/tree/GraphTuner Major changes to use autotvm are under x86/conv2d.py. This is the script I use to benchmark:
import logging
import sys
import time
import numpy as np
import nnvm
import tvm
import topi
from tvm import autotvm
from tvm.autotvm.task import register, get_config
from tvm.autotvm.task.nnvm_integration import deserialize_args
from tvm.autotvm.util import get_const_tuple
from nnvm import symbol as sym
from tvm.contrib import graph_runtime
from nnvm.testing.utils import create_workload
from mxnet.gluon.model_zoo.vision import get_model

if __name__ == "__main__":
    run_times = 10
    model = "resnet18_v1"
    image_shape = (3, 299, 299) if "inception" in model else (3, 224, 224)
    dshape = (1,) + image_shape
    dtype = "float32"
    target = 'llvm -mcpu=skylake-avx512'

    block = get_model(model, pretrained=True)
    net, params = nnvm.frontend.from_mxnet(block)
    tasks = autotvm.task.extract_from_graph(net, target=target, shape={'data': dshape}, dtype=dtype, symbols=(sym.conv2d,))

    logging.getLogger('autotvm').setLevel(logging.DEBUG)
    logging.getLogger('autotvm').addHandler(logging.StreamHandler(sys.stdout))

    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(n_parallel=1),
        runner=autotvm.LocalRunner(number=run_times, min_repeat_ms=1500))

    args_set = set()
    for i, task in enumerate(tasks):
        data, kernel, strides, padding, layout, dtype = task.args
        kernel_size = (kernel[1][2], kernel[1][3])
        data_plc = tvm.placeholder(data[1], name="data")
        kernel_plc = tvm.placeholder(kernel[1], name="kernel")
        args = [data_plc, kernel_plc, data[1][1], kernel_size, strides, padding, layout, layout, dtype]
        args = autotvm.task.nnvm_integration.serialize_args(args)
        #print(args)
        if args in args_set:
            continue
        args_set.add(args)
        task = autotvm.task.create("topi_x86_conv2d_NCHWc", args=args, target=target)
        task.workload = topi.x86.conv2d.conv_NCHWc_arg_to_workload(data_plc, kernel_plc, kernel_size, strides, padding, layout, dtype)
        tuner = autotvm.tuner.GridSearchTuner(task)
        tuner.tune(n_trial=len(task.config_space),
                   measure_option=measure_option,
                   callbacks=[autotvm.callback.log_to_file('resnet.log')])
I run it on AWS c5.9xlarge with:
TVM_NUM_THREADS=18 nohup python test_autotvm.py
using all 18 cores to sequentially benchmark all jobs. I'll try repeat=1.
These are optimal records I got for resnet18:
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 3, 224, 224], "float32"], ["TENSOR", [64, 3, 7, 7], "float32"], 3, [7, 7], [2, 2], [3, 3], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 3, 224, 224, "float32"], [64, 3, 7, 7, "float32"], [2, 2], [3, 3], "NCHW", "float32"], {"i": 41, "c": null, "e": [["tile_ic", "sp", [1, 3]], ["tile_oc", "sp", [1, 64]], ["tile_ow", "sp", [28, 4]], ["unroll_kw", "ot", true]], "t": ""}], "r": [[9.7808883e-05], 0, 1.138369083404541, 1535505799.372973], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 64, 56, 56], "float32"], ["TENSOR", [64, 64, 3, 3], "float32"], 64, [3, 3], [1, 1], [1, 1], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 64, 56, 56, "float32"], [64, 64, 3, 3, "float32"], [1, 1], [1, 1], "NCHW", "float32"], {"i": 139, "c": null, "e": [["tile_ic", "sp", [1, 64]], ["tile_oc", "sp", [2, 32]], ["tile_ow", "sp", [14, 4]], ["unroll_kw", "ot", true]], "t": ""}], "r": [[0.0001027261465], 0, 1.166485071182251, 1535509104.953128], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 64, 56, 56], "float32"], ["TENSOR", [64, 64, 3, 3], "float32"], 64, [3, 3], [1, 1], [1, 1], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 64, 56, 56, "float32"], [64, 64, 3, 3, "float32"], [1, 1], [1, 1], "NCHW", "float32"], {"i": 96, "c": null, "e": [["tile_ic", "sp", [2, 32]], ["tile_oc", "sp", [1, 64]], ["tile_ow", "sp", [28, 2]], ["unroll_kw", "ot", true]], "t": ""}], "r": [[0.00010640119299999999], 0, 1.223567008972168, 1535508795.96213], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 64, 56, 56], "float32"], ["TENSOR", [64, 64, 3, 3], "float32"], 64, [3, 3], [1, 1], [1, 1], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 64, 56, 56, "float32"], [64, 64, 3, 3, "float32"], [1, 1], [1, 1], "NCHW", "float32"], {"i": 139, "c": null, "e": [["tile_ic", "sp", [1, 64]], ["tile_oc", "sp", [2, 32]], ["tile_ow", "sp", [14, 4]], ["unroll_kw", "ot", true]], "t": ""}], "r": [[0.0001027261465], 0, 1.166485071182251, 1535509104.953128], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 64, 56, 56], "float32"], ["TENSOR", [64, 64, 3, 3], "float32"], 64, [3, 3], [1, 1], [1, 1], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 64, 56, 56, "float32"], [64, 64, 3, 3, "float32"], [1, 1], [1, 1], "NCHW", "float32"], {"i": 96, "c": null, "e": [["tile_ic", "sp", [2, 32]], ["tile_oc", "sp", [1, 64]], ["tile_ow", "sp", [28, 2]], ["unroll_kw", "ot", true]], "t": ""}], "r": [[0.00010640119299999999], 0, 1.223567008972168, 1535508795.96213], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 64, 56, 56], "float32"], ["TENSOR", [128, 64, 1, 1], "float32"], 64, [1, 1], [2, 2], [0, 0], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 64, 56, 56, "float32"], [128, 64, 1, 1, "float32"], [2, 2], [0, 0], "NCHW", "float32"], {"i": 209, "c": null, "e": [["tile_ic", "sp", [1, 64]], ["tile_oc", "sp", [4, 32]], ["tile_ow", "sp", [4, 7]], ["tile_oh", "ot", 1]], "t": ""}], "r": [[8.558321500000001e-06], 0, 0.20387506484985352, 1535523457.846529], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 64, 56, 56], "float32"], ["TENSOR", [128, 64, 3, 3], "float32"], 64, [3, 3], [2, 2], [1, 1], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 64, 56, 56, "float32"], [128, 64, 3, 3, "float32"], [2, 2], [1, 1], "NCHW", "float32"], {"i": 447, "c": null, "e": [["tile_ic", "sp", [1, 64]], ["tile_oc", "sp", [1, 128]], ["tile_ow", "sp", [14, 2]], ["unroll_kw", "ot", false]], "t": ""}], "r": [[6.20716565e-05], 0, 0.7738749980926514, 1535516294.860449], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 128, 28, 28], "float32"], ["TENSOR", [128, 128, 3, 3], "float32"], 128, [3, 3], [1, 1], [1, 1], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 128, 28, 28, "float32"], [128, 128, 3, 3, "float32"], [1, 1], [1, 1], "NCHW", "float32"], {"i": 175, "c": null, "e": [["tile_ic", "sp", [1, 128]], ["tile_oc", "sp", [4, 32]], ["tile_ow", "sp", [7, 4]], ["unroll_kw", "ot", true]], "t": ""}], "r": [[0.00010006059279999999], 0, 1.1714038848876953, 1535519139.417935], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 128, 28, 28], "float32"], ["TENSOR", [128, 128, 3, 3], "float32"], 128, [3, 3], [1, 1], [1, 1], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 128, 28, 28, "float32"], [128, 128, 3, 3, "float32"], [1, 1], [1, 1], "NCHW", "float32"], {"i": 509, "c": null, "e": [["tile_ic", "sp", [4, 32]], ["tile_oc", "sp", [1, 128]], ["tile_ow", "sp", [14, 2]], ["unroll_kw", "ot", false]], "t": ""}], "r": [[0.0001036367829], 0, 1.2258639335632324, 1535521977.144608], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 128, 28, 28], "float32"], ["TENSOR", [128, 128, 3, 3], "float32"], 128, [3, 3], [1, 1], [1, 1], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 128, 28, 28, "float32"], [128, 128, 3, 3, "float32"], [1, 1], [1, 1], "NCHW", "float32"], {"i": 175, "c": null, "e": [["tile_ic", "sp", [1, 128]], ["tile_oc", "sp", [4, 32]], ["tile_ow", "sp", [7, 4]], ["unroll_kw", "ot", true]], "t": ""}], "r": [[0.00010006059279999999], 0, 1.1714038848876953, 1535519139.417935], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 128, 28, 28], "float32"], ["TENSOR", [256, 128, 1, 1], "float32"], 128, [1, 1], [2, 2], [0, 0], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 128, 28, 28, "float32"], [256, 128, 1, 1, "float32"], [2, 2], [0, 0], "NCHW", "float32"], {"i": 133, "c": null, "e": [["tile_ic", "sp", [4, 32]], ["tile_oc", "sp", [2, 128]], ["tile_ow", "sp", [7, 2]], ["tile_oh", "ot", 1]], "t": ""}], "r": [[7.8310528e-06], 0, 0.22060489654541016, 1535532290.707203], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 128, 28, 28], "float32"], ["TENSOR", [256, 128, 3, 3], "float32"], 128, [3, 3], [2, 2], [1, 1], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 128, 28, 28, "float32"], [256, 128, 3, 3, "float32"], [2, 2], [1, 1], "NCHW", "float32"], {"i": 133, "c": null, "e": [["tile_ic", "sp", [4, 32]], ["tile_oc", "sp", [2, 128]], ["tile_ow", "sp", [7, 2]], ["unroll_kw", "ot", true]], "t": ""}], "r": [[5.92470041e-05], 0, 0.7696590423583984, 1535524742.028239], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 256, 14, 14], "float32"], ["TENSOR", [256, 256, 3, 3], "float32"], 256, [3, 3], [1, 1], [1, 1], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 256, 14, 14, "float32"], [256, 256, 3, 3, "float32"], [1, 1], [1, 1], "NCHW", "float32"], {"i": 475, "c": null, "e": [["tile_ic", "sp", [2, 128]], ["tile_oc", "sp", [2, 128]], ["tile_ow", "sp", [7, 2]], ["unroll_kw", "ot", false]], "t": ""}], "r": [[0.0001025999385], 0, 1.2034080028533936, 1535531394.87406], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 256, 14, 14], "float32"], ["TENSOR", [256, 256, 3, 3], "float32"], 256, [3, 3], [1, 1], [1, 1], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 256, 14, 14, "float32"], [256, 256, 3, 3, "float32"], [1, 1], [1, 1], "NCHW", "float32"], {"i": 475, "c": null, "e": [["tile_ic", "sp", [2, 128]], ["tile_oc", "sp", [2, 128]], ["tile_ow", "sp", [7, 2]], ["unroll_kw", "ot", false]], "t": ""}], "r": [[0.0001025999385], 0, 1.2034080028533936, 1535531394.87406], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 256, 14, 14], "float32"], ["TENSOR", [256, 256, 3, 3], "float32"], 256, [3, 3], [1, 1], [1, 1], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 256, 14, 14, "float32"], [256, 256, 3, 3, "float32"], [1, 1], [1, 1], "NCHW", "float32"], {"i": 475, "c": null, "e": [["tile_ic", "sp", [2, 128]], ["tile_oc", "sp", [2, 128]], ["tile_ow", "sp", [7, 2]], ["unroll_kw", "ot", false]], "t": ""}], "r": [[0.0001025999385], 0, 1.2034080028533936, 1535531394.87406], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 256, 14, 14], "float32"], ["TENSOR", [512, 256, 1, 1], "float32"], 256, [1, 1], [2, 2], [0, 0], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 256, 14, 14, "float32"], [512, 256, 1, 1, "float32"], [2, 2], [0, 0], "NCHW", "float32"], {"i": 142, "c": null, "e": [["tile_ic", "sp", [2, 128]], ["tile_oc", "sp", [16, 32]], ["tile_ow", "sp", [1, 7]], ["tile_oh", "ot", 1]], "t": ""}], "r": [[7.6947802e-06], 0, 0.18447494506835938, 1535539846.239089], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 256, 14, 14], "float32"], ["TENSOR", [512, 256, 3, 3], "float32"], 256, [3, 3], [2, 2], [1, 1], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 256, 14, 14, "float32"], [512, 256, 3, 3, "float32"], [2, 2], [1, 1], "NCHW", "float32"], {"i": 133, "c": null, "e": [["tile_ic", "sp", [2, 128]], ["tile_oc", "sp", [32, 16]], ["tile_ow", "sp", [1, 7]], ["unroll_kw", "ot", true]], "t": ""}], "r": [[7.73579001e-05], 0, 0.9138240814208984, 1535533580.560437], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 512, 7, 7], "float32"], ["TENSOR", [512, 512, 3, 3], "float32"], 512, [3, 3], [1, 1], [1, 1], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 512, 7, 7, "float32"], [512, 512, 3, 3, "float32"], [1, 1], [1, 1], "NCHW", "float32"], {"i": 154, "c": null, "e": [["tile_ic", "sp", [32, 16]], ["tile_oc", "sp", [16, 32]], ["tile_ow", "sp", [1, 7]], ["unroll_kw", "ot", true]], "t": ""}], "r": [[0.0001014149207], 0, 1.1709089279174805, 1535537031.453512], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 512, 7, 7], "float32"], ["TENSOR", [512, 512, 3, 3], "float32"], 512, [3, 3], [1, 1], [1, 1], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 512, 7, 7, "float32"], [512, 512, 3, 3, "float32"], [1, 1], [1, 1], "NCHW", "float32"], {"i": 155, "c": null, "e": [["tile_ic", "sp", [16, 32]], ["tile_oc", "sp", [16, 32]], ["tile_ow", "sp", [1, 7]], ["unroll_kw", "ot", true]], "t": ""}], "r": [[0.0001060334912], 0, 1.224541187286377, 1535537032.803111], "v": 0.1}
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 512, 7, 7], "float32"], ["TENSOR", [512, 512, 3, 3], "float32"], 512, [3, 3], [1, 1], [1, 1], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 512, 7, 7, "float32"], [512, 512, 3, 3, "float32"], [1, 1], [1, 1], "NCHW", "float32"], {"i": 155, "c": null, "e": [["tile_ic", "sp", [16, 32]], ["tile_oc", "sp", [16, 32]], ["tile_ow", "sp", [1, 7]], ["unroll_kw", "ot", true]], "t": ""}], "r": [[0.0001060334912], 0, 1.224541187286377, 1535537032.803111], "v": 0.1}
Most of them don't match the pure tvm execution time. I use this script to verify:
import logging
import sys
import time
import numpy as np
import nnvm
import tvm
import topi
from tvm import autotvm
from tvm.autotvm.task import register, get_config
from tvm.autotvm.task.nnvm_integration import deserialize_args
from tvm.autotvm.record import load_from_file
from tvm.autotvm.util import get_const_tuple
from nnvm import symbol as sym
from tvm.contrib import graph_runtime
from nnvm.testing.utils import create_workload
from mxnet.gluon.model_zoo.vision import get_model

def helper(cfg, *args):
    data, kernel = args[:2]
    kernel_size = args[3]
    strides = args[4]
    padding = args[5]
    layout = args[6]
    kh, kw = kernel_size if isinstance(kernel_size, (tuple, list)) else (kernel_size, kernel_size)
    is_kernel_1x1 = kh == 1 and kw == 1
    raw_data_shape = get_const_tuple(data[1])
    raw_kernel_shape = get_const_tuple(kernel[1])
    # change shape with the value in config
    ic_bn, oc_bn, ow_bn = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1],
                           cfg["tile_ow"].size[-1])
    new_data_shape = (raw_data_shape[0], raw_data_shape[1] // ic_bn,
                      raw_data_shape[2], raw_data_shape[3], ic_bn)
    data_layout = "NCHW%dc" % ic_bn
    out_layout = "NCHW%dc" % oc_bn
    if is_kernel_1x1:
        new_kernel_shape = (raw_kernel_shape[0] // oc_bn, raw_kernel_shape[1] // ic_bn,
                            ic_bn, oc_bn, raw_kernel_shape[2], raw_kernel_shape[3])
    else:
        new_kernel_shape = (raw_kernel_shape[0] // oc_bn, raw_kernel_shape[1] // ic_bn,
                            raw_kernel_shape[2], raw_kernel_shape[3], ic_bn, oc_bn)
    data = tvm.placeholder(new_data_shape, args[-1])
    kernel = tvm.placeholder(new_kernel_shape, args[-1])
    C = topi.x86.conv2d._declaration_conv_NCHWc(cfg, data, kernel, args[2], args[3], args[4], args[5],
                                                data_layout, out_layout, args[-1])
    s = topi.x86.conv2d._schedule_conv2d_NCHWc(cfg, args[2], args[3], args[4], args[5],
                                               data_layout, out_layout, [C])
    return s, [data, kernel, C]

target = "llvm -mcpu=skylake-avx512"
ctx = tvm.cpu()

a = load_from_file("resnet18_opt.log")
t = 0
for i, o in a:
    args = i.task.args
    print(args)
    cfg = i.config
    print(cfg)
    print("Autotvm time: %f" % (o.costs[0] * 1000))
    s, ts = helper(cfg, *args)
    data, kernel, out = ts
    f = tvm.build(s, [data, kernel, out], target)
    d = tvm.nd.array(np.random.uniform(size=get_const_tuple(data.shape)).astype("float32"), ctx)
    k = tvm.nd.array(np.random.uniform(size=get_const_tuple(kernel.shape)).astype("float32"), ctx)
    o = tvm.nd.array(np.zeros(get_const_tuple(out.shape)).astype("float32"), ctx)
    f_t = f.time_evaluator(f.entry_name, ctx, number=10000, repeat=1)
    cost = f_t(d, k, o).results[0]
    print("Actual time: %f" % (cost * 1000))
    t += cost * 1000
    print("\n")
print(t)
The default repeat for LocalRunner is 3, and it will remove the largest and smallest results:
https://github.com/dmlc/tvm/blob/0c523787297039ce00b320c1d32e022e61e97ac2/python/tvm/autotvm/measure/measure_methods.py#L465-L469
One quick verification you can try is to replace the time evaluator part in your check script with the following code, so the check script uses exactly the same settings as LocalRunner:
f = tvm.build(s, [data, kernel, out], target)
d = tvm.nd.array(np.random.uniform(size=get_const_tuple(data.shape)).astype("float32"), ctx)
k = tvm.nd.array(np.random.uniform(size=get_const_tuple(kernel.shape)).astype("float32"), ctx)
output = tvm.nd.array(np.zeros(get_const_tuple(out.shape)).astype("float32"), ctx)
number = int(1.5 / o.costs[0]) # adjust number
f_t = f.time_evaluator(f.entry_name, ctx, number=number, repeat=3)
cost = list(f_t(d, k, output).results)
cost.sort()
cost = cost[1] # remove largest and smallest.
print("Actual time: %f" % (cost * 1000))
I ran your code on a 16-core AMD ThreadRipper. The autotvm results match your check script for kernels that run < 0.1 ms (without changing the repeat setting). I have sent the limit increase request on AWS. Once I have access to a c5.9xlarge instance, I can verify this there.
I tried the time evaluator settings autotvm uses and still got similar results as before. I randomly picked some records. Even when the exec time is > 0.1 ms, sometimes the result is not accurate, such as 0.17 ms measured vs 0.3 ms real.
This workload and cfg have a 4x gap between measurement and real value: (('TENSOR', (1, 256, 14, 14), 'float32'), ('TENSOR', (256, 256, 3, 3), 'float32'), 256, (3, 3), (1, 1), (1, 1), 'NCHW', 'NCHW', 'float32') [('tile_ic', [2, 128]), ('tile_oc', [2, 128]), ('tile_ow', [7, 2]), ('unroll_kw', False)],,None,475 Autotvm time: 0.102600 Actual time: 0.494045
A phenomenon for workload (('TENSOR', (1, 256, 14, 14), 'float32'), ('TENSOR', (256, 256, 3, 3), 'float32'), 256, (3, 3), (1, 1), (1, 1), 'NCHW', 'NCHW', 'float32') is that when 'tile_oc' is [2, 128] and 'tile_ow' is [7, 2], the gap between measurement and actual result is quite large. You can verify this by skipping other tasks except when i == 6.
Just for a sanity check, I wonder if we can quickly compare the dumped llvm IR source between the two versions.
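For example, a minimal sketch reusing the variables from the check script above (s, data, kernel, out, target are the ones built there):
# dump the generated code so the two versions can be diffed
f = tvm.build(s, [data, kernel, out], target)
print(f.get_source())                                        # LLVM IR for llvm targets
print(tvm.lower(s, [data, kernel, out], simple_mode=True))   # lowered TVM IR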
I confirmed one thing. In autotvm, I use tvm.nd.empty to create input arrays. In this way we don't copy the array (several megabytes) to remote rpc devices.
https://github.com/dmlc/tvm/blob/4c4a8ea47b88677d89468e97e584cdad64b5b88e/python/tvm/autotvm/measure/measure_methods.py#L463
But it results in inaccurate measurement for the workload (('TENSOR', (1, 256, 14, 14), 'float32'), ('TENSOR', (256, 256, 3, 3), 'float32'), 256, (3, 3), (1, 1), (1, 1), 'NCHW', 'NCHW', 'float32'). To reproduce: create the file check_empty.log and run python3 check.py twice.
# one row copied from your log
{"i": ["llvm -mcpu=skylake-avx512", "topi_x86_conv2d_NCHWc", [["TENSOR", [1, 256, 14, 14], "float32"], ["TENSOR", [256, 256, 3, 3], "float32"], 256, [3, 3], [1, 1], [1, 1], "NCHW", "NCHW", "float32"], {}, ["conv2d_NCHWc", [1, 256, 14, 14, "float32"], [256, 256, 3, 3, "float32"], [1, 1], [1, 1], "NCHW", "float32"], {"i": 475, "c": null, "e": [["tile_ic", "sp", [2, 128]], ["tile_oc", "sp", [2, 128]], ["tile_ow", "sp", [7, 2]], ["unroll_kw", "ot", false]], "t": ""}], "r": [[0.0001025999385], 0, 1.2034080028533936, 1535531394.87406], "v": 0.1}
import logging
import sys
import time
import numpy as np
import nnvm
import tvm
import topi
from tvm import autotvm
from tvm.autotvm.task import register, get_config
from tvm.autotvm.task.nnvm_integration import deserialize_args
from tvm.autotvm.record import load_from_file
from tvm.autotvm.util import get_const_tuple
from nnvm import symbol as sym
from tvm.contrib import graph_runtime
from nnvm.testing.utils import create_workload
from mxnet.gluon.model_zoo.vision import get_model

def helper(cfg, *args):
    data, kernel = args[:2]
    kernel_size = args[3]
    strides = args[4]
    padding = args[5]
    layout = args[6]
    kh, kw = kernel_size if isinstance(kernel_size, (tuple, list)) else (kernel_size, kernel_size)
    is_kernel_1x1 = kh == 1 and kw == 1
    raw_data_shape = get_const_tuple(data[1])
    raw_kernel_shape = get_const_tuple(kernel[1])
    # change shape with the value in config
    ic_bn, oc_bn, ow_bn = (cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1],
                           cfg["tile_ow"].size[-1])
    new_data_shape = (raw_data_shape[0], raw_data_shape[1] // ic_bn,
                      raw_data_shape[2], raw_data_shape[3], ic_bn)
    data_layout = "NCHW%dc" % ic_bn
    out_layout = "NCHW%dc" % oc_bn
    if is_kernel_1x1:
        new_kernel_shape = (raw_kernel_shape[0] // oc_bn, raw_kernel_shape[1] // ic_bn,
                            ic_bn, oc_bn, raw_kernel_shape[2], raw_kernel_shape[3])
    else:
        new_kernel_shape = (raw_kernel_shape[0] // oc_bn, raw_kernel_shape[1] // ic_bn,
                            raw_kernel_shape[2], raw_kernel_shape[3], ic_bn, oc_bn)
    data = tvm.placeholder(new_data_shape, args[-1])
    kernel = tvm.placeholder(new_kernel_shape, args[-1])
    C = topi.x86.conv2d._declaration_conv_NCHWc(cfg, data, kernel, args[2], args[3], args[4], args[5],
                                                data_layout, out_layout, args[-1])
    s = topi.x86.conv2d._schedule_conv2d_NCHWc(cfg, args[2], args[3], args[4], args[5],
                                               data_layout, out_layout, [C])
    return s, [data, kernel, C]

ctx = tvm.cpu()
target = "llvm -mcpu=skylake-avx512"

a = load_from_file("check_empty.log")
t = 0
for i, o in a:
    if o.error_no != 0:
        continue
    args = i.task.args
    print(args)
    cfg = i.config
    print("Autotvm time: %f" % (o.costs[0] * 1000))
    s, ts = helper(cfg, *args)
    data, kernel, out = ts
    f = tvm.build(s, [data, kernel, out], target)
    number = int(1.5 / o.costs[0] * 1.1)  # adjust number
    f_t = f.time_evaluator(f.entry_name, ctx, number=number, repeat=3)

    # measure using empty tvm array
    d_empty = tvm.nd.empty(get_const_tuple(data.shape), dtype="float32", ctx=ctx)
    k_empty = tvm.nd.empty(get_const_tuple(kernel.shape), dtype="float32", ctx=ctx)
    output_empty = tvm.nd.empty(get_const_tuple(out.shape), dtype="float32", ctx=ctx)
    cost = f_t(d_empty, k_empty, output_empty).results[-1]
    print("Empty time: %f" % (cost * 1000))

    # measure using random tvm array
    d = tvm.nd.array(np.random.uniform(size=get_const_tuple(data.shape)).astype("float32"), ctx)
    k = tvm.nd.array(np.random.uniform(size=get_const_tuple(kernel.shape)).astype("float32"), ctx)
    output = tvm.nd.array(np.zeros(get_const_tuple(out.shape)).astype("float32"), ctx)
    cost = f_t(d, k, output).results[-1]
    print("Actual time: %f\n" % (cost * 1000))
    t += cost * 1000
print(t)
Autotvm time: 0.102600
Empty time: 0.104760
Actual time: 0.514933
0.5149327384491014
Thank you for identifying this! Do we have any solution, like creating the array on the target device?
Can we just try to change the default allocation (in autotvm) to random uniform? My guess is that the issue is due to NaNs changing the timing.
Copying arrays to remote devices will make tuning devices connected over wifi very slow. So we should make sure the copy happens locally on the remote device.
One quick solution for python servers is registering a python packed function, which uses numpy to randomly initialize a tvm ndarray. Then we get this function from the RPC server so the copy happens locally.
# in ndarray.py
import numpy as np
from ._ffi.function import register_func

@register_func("tvm.nd.random_uniform")
def random_uniform(size, dtype, target):
    size = [int(x) for x in size.split()]
    return array(np.random.uniform(size=size).astype(dtype), context(target))

# in autotvm/measure/measure_methods.py, replace L463
random_uniform = remote.get_function('tvm.nd.random_uniform')
args = [random_uniform(" ".join([str(d) for d in x[0]]), x[1], str(measure_input.target))
        for x in build_result.arg_info]
But this doesn't work for the java runtime (android). We would have to add a C++ API in the tvm runtime. There are some random APIs in tvm.contrib.random, but they are not enabled by default and only support float32 on cpu.
@merrymercy I tried this solution but got error: TVMError: [20:54:16] /home/ubuntu/tvm/src/runtime/rpc/rpc_session.cc:427: RPC cannot handle type NodeHandle
Sorry, we cannot call a remote function with a tuple or list. I updated my code in the above comment to use a string.
I tried the quick solution and am able to reproduce the previously collected optimal results now.
@kevinthesun @merrymercy in the interest of the more recent int8 schedule changes that might also benefit from the AutoTVM x86 port, do you think it makes sense to break the PR into two parts and bring in the AutoTVM x86 port first? If so, please open a separate issue to track it.
@tqchen I opened an issue to track x86 AutoTVM related tasks.
I found a related paper on graph layout tuning from CGO 2018: https://arxiv.org/pdf/1710.01079.pdf I do not believe their benchmark results in the paper (for example, simple im2col matches or outperforms mkldnn). One thing we can learn is that they leverage an off-the-shelf PBQP solver to solve the programming problem. Because this problem seems to be a classical programming problem, I think we can also leverage some existing solvers. This can help the current "too large space" case.
I understand the motivation of this RFC and the paper @merrymercy linked. But I have the following questions regarding this RFC. I would appreciate it if somebody could answer them.
Isn't it the case that keeping the data layout in NCHWc as much as possible, and only inserting layout transforms when necessary, is always best? This is what opt-level = 3 does for the x86 backend. Other operators, such as max pooling, can operate directly on the NCHWc layout.
The CGO 2018 paper is based on an assumption that there is a data layout transform happening between each pair of layers. I think they need to do data layout transforms because they have many variants of convolution algorithms, each requiring different input layouts (otherwise edge cost = 0 for all edges). But the TVM x86 backend has only the direct algorithm at the moment. So I don't quite understand why graph-level auto-tuning in this RFC brings such big improvements over 'TVM with default schedules'. I think a comparison between AutoTVM alone vs AutoTVM + graph-level tuning would clarify the benefit of graph tuner.
One last point: @kevinthesun Is your mxnet built with CUDA off during cmake? Otherwise, elemwise ops are not parallelized with openmp (see here) and the MXNet results would be way worse than they should be.
@masahi
I think the point of introducing graph level tuning (e.g., for data layout) is similar to the original motivation of AutoTVM. In the past many of our schedule configurations, and currently our choices of data layout, are only the result of handcrafted heuristics. If there is a chance that we are leaving some efficiency on the table, then I think it is worthwhile to pursue automated approaches that leave fewer stones unturned. Having flexibility when we eventually get models that break our assumptions about model architectures and data layout transformations is another bonus.
My understanding is that our algorithms (i.e., the best schedule configurations for each) are very shape dependent. So unless every layer has identical input/output/weight shapes, there is a chance that a different data layout will improve performance. The situation becomes even more complicated when we need to support both AVX-2 and AVX-512 CPUs, where the best choice of layout may be different depending on the vector width of the CPU.
I have run some early experiments with AutoTVM optimizing NCHWc schedules under a few hand-tuned, hardcoded layout configurations. The results are OK, but my understanding is that you can only get so far without considering changes to data layout. But to reiterate, I think manually defining heuristics for data layout will be brittle and unmaintainable.
If we introduce support for graph level tuning, this introduces benefits beyond data layout tuning; we can begin to consider joint architecture-schedule tuning and other techniques that are starting to become popular.
@eqy thanks, then I'm interested to know if there is an instance where the best NCHW schedule beats the best NCHWc schedule. My assumption is that the NCHWc layout is always better for the direct and winograd algorithms on x86. I care a lot about TVM performance on x86, so I'd be very happy if the assumption were wrong. The difference between the best AVX2 schedule and the best AVX-512 schedule would also be interesting.
I did realize that when I think of NCHWc layout, I always have one specific layout, such as NCHW8c, in mind. So to me "data layout transform" always meant NCHW <-> NCHW8c conversion. But there are also NCHW16c, NCHW32c, etc., so there are a lot more possible conversions (NCHW8c <-> NCHW16c, etc.). The optimal inner channel dimension can be different for each layer, so it might be better to introduce data layout conversions between different inner channel dimensions. The graph level auto-tuning in this RFC automates this decision. Is this understanding correct? @eqy
@masahi You are right. Actually the "default" schedule for x86 conv2d is a simple heuristic which always chooses the channel factor to be 16 if possible, to minimize the number of layout transformations needed. Graph tuner is an automated way to make these decisions related to data layouts.
The NCHWc data layout is suitable for graph tuner to do this kind of searching. However, if a certain algorithm requires different input and output data layouts, say NCHW vs NCHWc, then those layout transformations cannot be eliminated; in this case graph tuner wouldn't help much. Another scenario which limits graph tuner is when some parameters need to be fixed to fully utilize hardware resources. This happens for Intel graphics, where the output channel factor should be as close to 16 as possible to use virtual threads. I'm also trying to apply autotvm to Intel graphics, and will see how graph tuner works in this case.
thanks @kevinthesun, it makes a lot more sense now. I'm looking forward to testing this feature on my network, where all inputs are NCHW8c.
I agree with @merrymercy that the results of the CGO 2018 paper seem fishy, but I did like their problem formulation. As long as we can define a 'node cost' and an 'edge cost', we can model this as a standard discrete optimization problem that can be solved by off-the-shelf solvers. We can incorporate different algorithms, each with a different layout preference, straightforwardly.
In Section 7 of the paper, they state the following. To me, this sounds a lot like what AutoTVM + graph tuning will enable. Very interesting!
A viable future approach might be to use code generators and auto-tuners to generate the code
and data layouts for given layers and use our approach to combine these code segments to
implement an entire DNN.
@merrymercy @eqy I have an issue dealing with FallbackContext for x86. Currently x86 conv2d will automatically generate default schedules if no pre-tuned schedules are provided. I think this should correspond to FallbackContext? However, even though I didn't specify any dispatch context before compilation, when I called autotvm.task.DispatchContext.current inside x86/conv2d.py, it returned ApplyHistoryBest. Does autotvm automatically convert FallbackContext to ApplyHistoryBest somewhere?
Yes, in nnvm.compiler.build, we will load the tophub context (which is an empty ApplyHistoryBest in your case).
But this is not a problem. Although the current context is ApplyHistoryBest, when you query it, you cannot find the config for your workload, so it will query its upper context, which is the root fallback context, and return a FallbackConfigEntity. You can use cfg.is_fallback to check whether it is a fallback. Don't use something like isinstance(autotvm.task.DispatchContext.current, autotvm.FallbackContext).
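Inside the x86 template the check could then look roughly like this (a minimal sketch, not the actual x86 code):
from tvm import autotvm

cfg = autotvm.get_config()
if cfg.is_fallback:
    # no tuned record was found for this workload; fill in the default
    # x86 heuristic (e.g. a channel blocking factor of 16) here
    pass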
@kevinthesun Is the performance data measured with auto-tuning applied or not? According to the comments, it seems it isn't applied; I'd like to see the performance data updated.
@FrozenGene The default schedule here for x86 eliminates most layout transformations. It should have similar performance to "apply_history_best". I'll update the data for "apply_history_best".
@FrozenGene Data of "apply_history_best" updated. @yzhliu Updated some implementation details.
Motivation
Currently we can tune operators with handcrafted schedules or AutoTVM, which can give us decent kernel performance. However, these kernel templates usually involve transforming the operator data layout to other formats, such as conv2d for Intel and ARM CPUs. In this case, a lot of layout transformations can be introduced into the graph. Graph tuner considers both fast kernel schedules and the extra layout transformations, generating decent end-to-end performance.
Design
There are two steps to achieve graph-level optimal schedules. First, get a set of schedule candidates for each workload in the graph; this step can be finished by AutoTVM. Second, feed the schedule candidates into graph tuner and run graph-level tuning.
In the second step, graph tuner will benchmark all possible layout transformations in the graph, given a set of schedule candidates, and then combine schedule and layout transformation execution times to find the optimal schedule combination.
The current solution for this optimization problem is to model it as a Markov decision process and use dynamic programming to solve it. This gives us the globally optimal solution. However, the time/memory complexity of DP is prohibitively expensive for some networks, such as SSD. In such cases we need to use approximation methods. For now graph tuner provides a graph coloring algorithm (PBQP) for this scenario.
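As an illustration only (not the graph tuner code), here is a minimal sketch of the DP idea on a linear chain of target ops; the function and variable names are made up, and the real tuner handles general graphs with benchmarked layout-transform times:
def dp_linear_chain(kernel_time, transform_time):
    # kernel_time[i][s]: exec time of node i with candidate schedule s
    # transform_time[i][(p, s)]: layout transform time between node i-1 (schedule p)
    #                            and node i (schedule s); missing pairs cost 0
    n = len(kernel_time)
    dp = [dict(kernel_time[0])]
    for i in range(1, n):
        dp.append({})
        for s, t in kernel_time[i].items():
            dp[i][s] = t + min(dp[i - 1][p] + transform_time[i].get((p, s), 0.0)
                               for p in dp[i - 1])
    return min(dp[-1].values())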
API
We provide a base class and built-in subclass:
Implementation details: One key part of graph tuner is to generate all possible layout transformations given a set of workloads and schedules. To hide any operator-related information from graph tuner, we can add a new generic function bound to each topi operator, which accepts a workload and cfg, and returns the input/output shapes and layouts. An example for conv2d:
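The original example is not reproduced here; as an illustration only, a hypothetical sketch could look like the following (the function name, the workload format, and the cfg knob names are assumptions, not the actual TOPI API):
def conv2d_infer_layout(workload, cfg):
    # assumed workload format: (name, data_shape, kernel_shape, strides, padding, layout, dtype)
    _, data_shape, kernel_shape, strides, padding, _, _ = workload
    batch, in_channel, in_h, in_w = data_shape
    out_channel, _, k_h, k_w = kernel_shape
    ic_bn = cfg["tile_ic"].size[-1]
    oc_bn = cfg["tile_oc"].size[-1]
    out_h = (in_h + 2 * padding[0] - k_h) // strides[0] + 1
    out_w = (in_w + 2 * padding[1] - k_w) // strides[1] + 1
    in_shape = (batch, in_channel // ic_bn, in_h, in_w, ic_bn)
    out_shape = (batch, out_channel // oc_bn, out_h, out_w, oc_bn)
    # one (shape, layout) pair per input and per output, so operators with
    # multiple inputs/outputs can return longer tuples
    return ((in_shape, "NCHW%dc" % ic_bn),), ((out_shape, "NCHW%dc" % oc_bn),)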
Note that although graph tuner only supports target ops with a single input and output (conv2d, conv2d_transpose, dense, etc.), we make this API generic enough to support operators with multiple inputs/outputs.
With this function, in graph tuner we only need a dictionary mapping each topi function name to the corresponding infer-layout function, similar to what the autotvm task extraction function does. After shape and layout info is extracted, we can use a generic graph traversal to fetch all possible layout transformations.
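A hypothetical sketch of that mapping (names are illustrative and follow the sketch above):
INFER_LAYOUT_FUNCS = {
    "conv2d_NCHWc": conv2d_infer_layout,   # from the sketch above
    # "dense": dense_infer_layout, ...
}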
Performance Benchmark
Intel Xeon CPU (AWS c5.9xlarge, 18 physical cores)
More benchmark data to be added.
PR: https://github.com/dmlc/tvm/pull/1586