tqchen closed this issue 6 years ago.
I think @adityaatluri / AMD folks will have the most relevant experience as to whether we need embedded asm for AMD GPUs, which are an interesting target. The fact that some of the kernels in MIOpen (e.g., Winograd) do not have exposed source suggests that the asm implementation yields the best performance.
Can any of the AMD folks comment on whether this is just an artifact of how kernels are typically written, limitations in the expressiveness of OpenCL/ROCm, limitations in the compiler support for OpenCL/ROCm, or some mixture of reasons?
Anecdotally, we have been able to unlock very good performance on NVIDIA GPUs with relatively simple schedules for small batches (especially for direct convolution), but simply retuning these schedules on AMD GPUs has not been enough to reach similar utilization. This could be because the CUDA compiler does some great optimizations for NVIDIA GPUs, because AMD GPUs require special optimizations not covered by the CUDA schedules, because the OpenCL/ROCm compiler for AMD GPUs doesn't do some equivalent optimizations, or (again) a mixture of reasons.
Hi folks - I'm interested in (mobile) CPU backend improvements for TVM.
What I've just implemented is slightly more general than what you suggested: patching llvm_module.cc to support loading external LLVM IR modules (via a registered callback, like the ROCm case), plus a bit of Python helper logic for passing these to tensor intrinsics. You can then write essentially arbitrary C or assembly microkernels. A simple example:
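```c
#include <immintrin.h>
#include <stddef.h>
#include <stdint.h>

// z[z_off + i] = x[x_off + i] + y[y_off + i] for i in [0, 16),
// processed as two 8-wide AVX vectors.
void vadd_16_avx2(const float *x, int32_t x_off, const float *y,
                  int32_t y_off, float *z, int32_t z_off) {
  for (size_t i = 0; i < 16; i += 8) {
    _mm256_storeu_ps(z + z_off + i,
                     _mm256_add_ps(_mm256_loadu_ps(x + x_off + i),
                                   _mm256_loadu_ps(y + y_off + i)));
  }
}
```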
and have that passed into a tensor op declaration like:
```python
import tvm


def intrin_vadd_16():
    x = tvm.placeholder((16,), name='vx')
    y = tvm.placeholder((16,), name='vy')
    z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')

    def intrin_func(ins, outs):
        xx, yy = ins
        zz = outs[0]
        irb = tvm.ir_builder.create()
        # Call the external microkernel with (pointer, element offset) pairs.
        extern_call = tvm.call_extern(
            "int32",
            "vadd_16_avx2",
            irb.buffer_ptr(xx),
            xx.elem_offset,
            irb.buffer_ptr(yy),
            yy.elem_offset,
            irb.buffer_ptr(zz),
            zz.elem_offset)
        irb.emit(extern_call)
        return irb.get()

    with tvm.build_config(offset_factor=16):
        return tvm.decl_tensor_intrin(z.op, intrin_func)
```
and then use it in a TVM expression like any other tensor intrinsic. The microkernel can (optionally) be fully inlined into the main TVM loop, inline assembly works, and so on. It also seems to work correctly with simple matmul kernels. Is that similar to what you had in mind?
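To make the calling convention concrete, here is a pure-Python model of what `vadd_16_avx2` does with the (pointer, `elem_offset`) pairs passed through `tvm.call_extern`. It is an illustrative sketch, not part of TVM: `vadd_16_reference` is a hypothetical name, and the offset check is our reading of `offset_factor=16` (every `elem_offset` is a multiple of 16).

```python
def vadd_16_reference(x, x_off, y, y_off, z, z_off, offset_factor=16):
    """Model of vadd_16_avx2: z[z_off + i] = x[x_off + i] + y[y_off + i] for i in [0, 16)."""
    # Assumed guarantee from tvm.build_config(offset_factor=16):
    # every elem_offset handed to the intrinsic is a multiple of offset_factor.
    for off in (x_off, y_off, z_off):
        if off % offset_factor != 0:
            raise ValueError("elem_offset %d is not a multiple of %d" % (off, offset_factor))
    # The C kernel handles the 16 lanes as two 8-wide AVX vectors; here we just loop.
    for i in range(16):
        z[z_off + i] = x[x_off + i] + y[y_off + i]


x = [float(i) for i in range(32)]
y = [1.0] * 32
z = [0.0] * 32
vadd_16_reference(x, 16, y, 0, z, 16)  # second tile of x + first tile of y -> second tile of z
print(z[16:20])  # [17.0, 18.0, 19.0, 20.0]
```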
One thought for an intermediate goal is to get a TVM GEMM that is actually within 5% of the performance of MKL or Accelerate on AVX2. Currently, even with the best schedules I could find in https://github.com/dmlc/tvm/blob/master/tutorials/optimize/opt_gemm.py, it's ~25-50% slower than MKL or Accelerate on my i7-4980HQ, and from looking at the generated code I suspect we need more control over the microkernels to close that last gap. Are you folks interested in this direction as well?
Another example: for the Conv(1x64x56x56, 64x64x1x1) from ResNet-18, the best TVM performance I could get on AVX2 is 35 GFLOP/s single-threaded, while the equivalent GEMM with a BLAS library is 2x faster at 70 GFLOP/s. I'd have thought it would be possible to bridge this gap with architecture-specific microkernels that can be directly tensorized. Does that make sense to you?
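For reference, a stride-1, unpadded 1x1 convolution is exactly a GEMM with M = N·H·W output pixels, N = output channels, and K = input channels, costing 2·M·N·K FLOPs (one multiply and one add per multiply-accumulate). The sketch below does that arithmetic for the layer above and converts the quoted rates into per-call times; the times are derived from the quoted GFLOP/s, not measured, and the helper name is ours.

```python
def conv1x1_as_gemm_flops(n, c_in, h, w, c_out):
    """FLOPs of a stride-1, unpadded 1x1 convolution viewed as an (n*h*w) x c_out x c_in GEMM."""
    m, k = n * h * w, c_in
    return 2 * m * c_out * k  # one multiply + one add per multiply-accumulate


flops = conv1x1_as_gemm_flops(1, 64, 56, 56, 64)
print(flops)  # 25690112
for label, gflops in [("TVM schedule", 35.0), ("BLAS GEMM", 70.0)]:
    print("%s: %.3f ms per call" % (label, flops / (gflops * 1e9) * 1e3))
```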
@ajtulloch Patching the LLVM module seems fine as long as the inlining optimization happens at link time; there is a chance that inlining does not happen, or that the inlined code does not benefit because too much optimization already ran earlier. (Edit: for CPUs this seems not to be a problem; for GPUs we might need to move the linking to an earlier stage.)
Supporting C code directly is a great idea: as long as clang is available for preprocessing, this would be quite convenient.
The general logic makes sense, and I guess the same approach would work if we move it into the codegen phase, so that it also covers cases like GPUs. I still think it belongs in the codegen phase.
Would you be interested in discussing a possible API? For example, we could put things under build_config (either by adding a setter for a thread-local context or by patching the build config context) so that we can pass them as arguments.
The exploration of speedup via tensorizing micro-kernels is definitely interesting.
Given the lively discussion, I suggest we quickly agree on an API and patch TVM to support this feature (it should not take many lines of code).
As an update - with this microkernel
```c
#include <stddef.h>
#include <stdint.h>

// 4x24 single-precision GEMM microkernel for AVX2 + FMA.
// Computes C[r][j] = sum_kk a[kk*4 + r] * b[kk*24 + j] for r < 4, j < 24 and
// overwrites that 4x24 block of C (row stride ldc, in floats).
// The loop is do-while style, so k must be >= 1; vmovaps needs b 32-byte aligned.
void sgemm_4x24__avx2(int32_t k, const float *a, int32_t a_off,
                      const float *b, int32_t b_off, float *c,
                      int32_t c_off, int32_t ldc) {
  a = a + a_off;
  b = b + b_off;
  c = c + c_off;
  size_t k_size_t = k;
  size_t ldc_size_t = ldc;
  asm volatile(
      // Convert ldc to bytes and prefetch the 4 output rows (c ends at row 3).
      "shl $0x2,%[ldc_size_t]\n\t"
      "prefetcht0 (%[c])\n\t"
      "add %[ldc_size_t],%[c]\n\t"
      "prefetcht0 (%[c])\n\t"
      "add %[ldc_size_t],%[c]\n\t"
      "prefetcht0 (%[c])\n\t"
      "add %[ldc_size_t],%[c]\n\t"
      "prefetcht0 (%[c])\n\t"
      // Zero the 12 accumulators (ymm4-ymm15).
      "vzeroall\n\t"
      "LOOP_START%=:\n\t"
      // Load one 24-float row of the B panel.
      "vmovaps (%[b]),%%ymm3\n\t"
      "vmovaps 0x20(%[b]),%%ymm2\n\t"
      "vmovaps 0x40(%[b]),%%ymm1\n\t"
      "add $0x60,%[b]\n\t"
      // For each of the 4 A values: broadcast, then FMA into that row's accumulators.
      "vbroadcastss (%[a]),%%ymm0\n\t"
      "vfmadd231ps %%ymm3,%%ymm0,%%ymm8\n\t"
      "vfmadd231ps %%ymm2,%%ymm0,%%ymm9\n\t"
      "vfmadd231ps %%ymm1,%%ymm0,%%ymm10\n\t"
      "vbroadcastss 0x4(%[a]),%%ymm0\n\t"
      "vfmadd231ps %%ymm3,%%ymm0,%%ymm11\n\t"
      "vfmadd231ps %%ymm2,%%ymm0,%%ymm12\n\t"
      "vfmadd231ps %%ymm1,%%ymm0,%%ymm13\n\t"
      "vbroadcastss 0x8(%[a]),%%ymm0\n\t"
      "vfmadd231ps %%ymm3,%%ymm0,%%ymm14\n\t"
      "vfmadd231ps %%ymm2,%%ymm0,%%ymm15\n\t"
      "vfmadd231ps %%ymm1,%%ymm0,%%ymm7\n\t"
      "vbroadcastss 0xc(%[a]),%%ymm0\n\t"
      "vfmadd231ps %%ymm3,%%ymm0,%%ymm6\n\t"
      "vfmadd231ps %%ymm2,%%ymm0,%%ymm5\n\t"
      "vfmadd231ps %%ymm1,%%ymm0,%%ymm4\n\t"
      "add $0x10,%[a]\n\t"
      "dec %[k_size_t]\n\t"
      "jne LOOP_START%=\n\t"
      // Store rows 3, 2, 1, 0 of the C block.
      "vmovups %%ymm6,(%[c])\n\t"
      "vmovups %%ymm5,0x20(%[c])\n\t"
      "vmovups %%ymm4,0x40(%[c])\n\t"
      "sub %[ldc_size_t],%[c]\n\t"
      "vmovups %%ymm14,(%[c])\n\t"
      "vmovups %%ymm15,0x20(%[c])\n\t"
      "vmovups %%ymm7,0x40(%[c])\n\t"
      "sub %[ldc_size_t],%[c]\n\t"
      "vmovups %%ymm11,(%[c])\n\t"
      "vmovups %%ymm12,0x20(%[c])\n\t"
      "vmovups %%ymm13,0x40(%[c])\n\t"
      "sub %[ldc_size_t],%[c]\n\t"
      "vmovups %%ymm8,(%[c])\n\t"
      "vmovups %%ymm9,0x20(%[c])\n\t"
      "vmovups %%ymm10,0x40(%[c])\n\t"
      "vzeroupper\n\t"
      : [c] "+r"(c), [b] "+r"(b), [a] "+r"(a),
        [k_size_t] "+r"(k_size_t), [ldc_size_t] "+r"(ldc_size_t)
      :
      : "cc", "memory", "%ymm0", "%ymm1", "%ymm2", "%ymm3", "%ymm4",
        "%ymm5", "%ymm6", "%ymm7", "%ymm8", "%ymm9", "%ymm10",
        "%ymm11", "%ymm12", "%ymm13", "%ymm14", "%ymm15");
}
```
I get essentially equivalent performance (+/-10%) to Accelerate/MKL single-threaded on most GEMMs, without any scheduling work involved (i.e. just constructing full panels separately, then tensorizing the reduction) - and presumably a scheduling guru could get this even better. AFAIK I haven't seen any TVM schedules that can get within 30% of this number on my i7-4980HQ - if there are, please let me know.
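Reading the assembly above, the kernel keeps a 4x24 block of C in 12 ymm accumulators (4 rows × 3 vectors of 8 floats) and uses three more registers for the current B-panel row plus one for the broadcast A value, i.e. all 16 AVX2 registers. Each reduction step reads 4 floats of the packed A panel and 24 floats of the packed B panel, and at the end the 4x24 block of C is overwritten, not accumulated into. The Python model below is our reading of that code, intended as a test oracle; it is not the kernel itself, and the name is hypothetical.

```python
def sgemm_4x24_reference(k, a, a_off, b, b_off, c, c_off, ldc):
    """Model of sgemm_4x24__avx2 as read from its assembly (a test oracle, not the kernel).

    a: packed A panel, 4 floats per reduction step (a[a_off + kk*4 + r]).
    b: packed B panel, 24 floats per reduction step (b[b_off + kk*24 + j]).
    c: output, row-major with row stride ldc; the 4x24 block is overwritten.
    """
    assert k >= 1, "the asm loop is a do-while: k == 0 would wrap around"
    for r in range(4):
        for j in range(24):
            acc = 0.0  # vzeroall clears the accumulators before the loop
            for kk in range(k):
                acc += a[a_off + kk * 4 + r] * b[b_off + kk * 24 + j]
            c[c_off + r * ldc + j] = acc  # vmovups stores; old C is never read


# k = 2, a = all ones, b[kk*24 + j] = j  ->  c[r][j] = 2*j; columns 24..29 untouched.
c = [-1.0] * (4 * 30)
sgemm_4x24_reference(2, [1.0] * 8, 0, [float(j) for j in range(24)] * 2, 0, c, 0, 30)
print(c[30 + 5])  # row 1, column 5 -> 10.0
```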
I think this could be pretty interesting for FP32 NN performance; e.g., it could be used directly to improve 1x1 convolution performance for FP32 MobileNet on AVX2.
I think in terms of a TVM API, it'd be nice to directly support C/ASM source files that can be automatically compiled and linked against (handling cases like just standard intrinsics or C functions, using inline assembly, and even directly using ASM from an assembler like PeachPy, etc). I think something like:
```python
with tvm.build_config(extern_cpp_source=["/path/to/kernels.cpp"],
                      extern_asm_source=["path/to/kernels.S"]):
    extern_call = tvm.call_extern("int32", "sgemm_only_4x24__avx2")  # defined in `kernels.cpp`
    ...
```
could be pretty usable. I'm still very new to this project so I don't have any strong opinions here TBH - I just think this is a pretty useful bit of functionality for getting peak performance out of this system :)
@ajtulloch for tvm, are you using the target "llvm -mcpu=core-avx2"? With the target "llvm", you only get SSE even if your cpu supports avx2.
OK, how about we support the following API:
```python
with tvm.build_config(extern_llvm_source=[
        "/path/to/kernel.ll",
        my_llvm_str,
        contrib.clang.compile_llvm(open("/path/to/kernel.cpp").read()),
        my_tools_decorate_asm_to_llvm(open("/path/to/myasm.S").read())]):
    pass
```
@ajtulloch @eqy
That looks great @tqchen.
@masahi, @tqchen for example, check https://gist.github.com/ajtulloch/0087cb2dee2fab580ec60cc87220fbde. With this, running `MKL_NUM_THREADS=1 VECLIB_MAXIMUM_THREADS=1 TVM_NUM_THREADS=1 python tutorials/optimize/opt_gemm.py 2>&1 | grep Opt`, I see the following (all numbers are in GFLOP/s):
```
Opt1: 4.909991
Opt2: 23.436729
Opt3: 44.081897
Opt4: 50.382429
Opt5: 60.753072
Opt6: 59.855750
OptBLAS: 88.236252
OptTensorize: 88.171094
```
The microkernel peaks at about 100GFLOP/s, so this is within ~10% of peak performance for this CPU.
This represents roughly a 45% speedup (60.75 -> 88.17 GFLOP/s) over the previous best Opt routine, and matches the vendor-tuned BLAS (as desired).
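The ratios quoted above can be checked directly against the output (values copied from it; the ~100 GFLOP/s microkernel peak is the figure stated above, not a measured CPU peak):

```python
opt_gflops = {
    "Opt5": 60.753072,
    "Opt6": 59.855750,
    "OptBLAS": 88.236252,
    "OptTensorize": 88.171094,
}
microkernel_peak = 100.0  # approximate microkernel peak quoted above, GFLOP/s

best_schedule = max(opt_gflops["Opt5"], opt_gflops["Opt6"])
speedup = opt_gflops["OptTensorize"] / best_schedule
fraction_of_peak = opt_gflops["OptTensorize"] / microkernel_peak
vs_blas = opt_gflops["OptTensorize"] / opt_gflops["OptBLAS"]
print("speedup over best schedule: %.0f%%" % ((speedup - 1) * 100))  # 45%
print("fraction of microkernel peak: %.0f%%" % (fraction_of_peak * 100))  # 88%
print("relative to BLAS: %.3f" % vs_blas)  # 0.999
```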
Incidentally, the schedule for the Tensorize version is:
```
// attr [APanel] storage_scope = "global"
allocate APanel[float32 * 192 * 768 * 4]
// attr [BPanel] storage_scope = "global"
allocate BPanel[float32 * 32 * 768 * 24]
produce APanel {
  for (mtile, 0, 192) {
    for (k, 0, 768) {
      APanel[(((mtile*768) + k)*4)] = A[((mtile*3072) + k)]
      APanel[((((mtile*768) + k)*4) + 1)] = A[(((mtile*3072) + k) + 768)]
      APanel[((((mtile*768) + k)*4) + 2)] = A[(((mtile*3072) + k) + 1536)]
      APanel[((((mtile*768) + k)*4) + 3)] = A[(((mtile*3072) + k) + 2304)]
    }
  }
}
produce BPanel {
  for (ntile, 0, 32) {
    for (k, 0, 768) {
      for (n, 0, 24) {
        BPanel[((((ntile*768) + k)*24) + n)] = B[(((ntile + (k*32))*24) + n)]
      }
    }
  }
}
produce C {
  for (m.outer, 0, 24) {
    for (n.outer, 0, 32) {
      for (m.inner.outer, 0, 8) {
        sgemm_only_4x24__avx2(768, APanel, (((m.outer*8) + m.inner.outer)*3072), BPanel, (n.outer*18432), C, ((((m.outer*1024) + n.outer) + (m.inner.outer*128))*24), 768)
      }
    }
  }
}
```
and I'm sure it could be improved as well.
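The lowered IR above packs A into 4-row panels (`APanel`) and B into 24-column panels (`BPanel`), then calls the microkernel once per 4x24 block of C over the full reduction depth. The pure-Python sketch below mirrors that index arithmetic for general M, N, K (assuming M is a multiple of 4 and N a multiple of 24) and checks it against a naive matmul. The function names are ours, and `micro_4x24` is a model of the kernel, not the AVX2 code.

```python
def micro_4x24(k, a, a_off, b, b_off, c, c_off, ldc):
    # Model of sgemm_4x24__avx2: overwrite a 4x24 block of C.
    for r in range(4):
        for j in range(24):
            c[c_off + r * ldc + j] = sum(
                a[a_off + kk * 4 + r] * b[b_off + kk * 24 + j] for kk in range(k))


def packed_gemm(A, B, M, N, K):
    """C = A @ B for row-major flat lists A (M x K) and B (K x N), via panel packing."""
    assert M % 4 == 0 and N % 24 == 0 and K >= 1
    # APanel[mtile][k][r] = A[mtile*4 + r][k], as in `produce APanel`.
    a_panel = [0.0] * (M * K)
    for mtile in range(M // 4):
        for k in range(K):
            for r in range(4):
                a_panel[(mtile * K + k) * 4 + r] = A[(mtile * 4 + r) * K + k]
    # BPanel[ntile][k][n] = B[k][ntile*24 + n], as in `produce BPanel`.
    b_panel = [0.0] * (K * N)
    for ntile in range(N // 24):
        for k in range(K):
            for n in range(24):
                b_panel[(ntile * K + k) * 24 + n] = B[k * N + ntile * 24 + n]
    # Same offsets as the lowered call: a_off = m*4*K, b_off = ntile*24*K, ldc = N.
    C = [0.0] * (M * N)
    for mtile in range(M // 4):
        for ntile in range(N // 24):
            micro_4x24(K, a_panel, mtile * 4 * K, b_panel, ntile * 24 * K,
                       C, mtile * 4 * N + ntile * 24, N)
    return C


M, N, K = 8, 48, 5
A = [float((i * 7) % 11) for i in range(M * K)]
B = [float((i * 3) % 5) for i in range(K * N)]
C = packed_gemm(A, B, M, N, K)
```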
@tqchen since you can't really go from ASM to LLVM IR, what about also supporting linking against object files directly? That's the only way to support a PeachPy-style case, right?
After thinking more carefully about the problem, here is an alternative proposal: introduce module importing in the IR.
```
for i in range(10):
    attr pragma import "/path/to/file.ll"
    attr pragma import ll_source_in_str
```
We can make the LLVM code generator handle imports gracefully by linking and inlining the module. The advantage over the build config is that we can attach source code during scheduling (via a pragma), or insert it directly when creating the intrinsic.
For example, we can do the following, so that the asm is tied directly to the tensor intrinsic. This also allows different target platforms to use different asm kernels, without users having to deal with it in build_config:
```python
import tvm


def intrin_vadd_16():
    x = tvm.placeholder((16,), name='vx')
    y = tvm.placeholder((16,), name='vy')
    z = tvm.compute(x.shape, lambda i: x[i] + y[i], name='z')

    def intrin_func(ins, outs):
        xx, yy = ins
        zz = outs[0]
        irb = tvm.ir_builder.create()
        # Maybe need to update irb (proposed API: attach the import pragma here)
        irb.attr_scope("pragma", "import", "/path/to/xx.ll")
        extern_call = tvm.call_extern(
            "int32",
            "vadd_16_avx2",
            irb.buffer_ptr(xx),
            xx.elem_offset,
            irb.buffer_ptr(yy),
            yy.elem_offset,
            irb.buffer_ptr(zz),
            zz.elem_offset)
        irb.emit(extern_call)
        return irb.get()

    with tvm.build_config(offset_factor=16):
        return tvm.decl_tensor_intrin(z.op, intrin_func)
```
@tqchen I don’t quite understand the advantages of the second proposal. Could you flesh out the example a bit please?
@ajtulloch I updated my post. Also, in terms of directly linking ASM like PeachPy's, do you think it would be easy to decorate the asm string with LLVM syntax? Ideally, if we could patch PeachPy to do this, it would be really great, as it would give us the ability to inline.
@eqy @ajtulloch can you comment on the new embedding API, which introduces an import pragma into the IR? If you think it is good, we can proceed with the implementation.
@tqchen it looks good to me. I think this is going to be very useful for CPU performance (I've already got a ~2x speedup for AVX 1x1 kernels vs the conv2d_avx_1x1.py ones).
@ajtulloch awesome, can you open a PR to bring the support in? I will change the issue to WIP and list the action items.
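One way to "decorate" an assembly string with LLVM syntax is module-level inline assembly: LLVM IR accepts `module asm "..."` lines, with special characters in the string written as `\XX` hex escapes. The sketch below (the helper name is hypothetical) does that wrapping. Note the limitation: module-level asm only defines out-of-line symbols, so it allows linking but not inlining; inlining would need the code wrapped in a function body using an inline `asm` call instead. Symbol names must also follow the platform convention (e.g. a leading underscore for C symbols on macOS).

```python
def asm_to_llvm_module(asm_text, target_triple=None):
    """Wrap assembly source as LLVM IR module-level inline asm (hypothetical helper)."""

    def escape(line):
        # LLVM IR string literals write '"', '\' and non-printable bytes as \XX (two hex digits).
        out = []
        for byte in line.encode("utf-8"):
            if byte in (0x22, 0x5C) or not 0x20 <= byte < 0x7F:
                out.append("\\%02X" % byte)
            else:
                out.append(chr(byte))
        return "".join(out)

    lines = []
    if target_triple is not None:
        lines.append('target triple = "%s"' % target_triple)
    for line in asm_text.splitlines():
        lines.append('module asm "%s"' % escape(line))
    return "\n".join(lines) + "\n"


asm = ".globl my_add\nmy_add:\n\tvaddps %ymm1, %ymm0, %ymm0\n\tret"
print(asm_to_llvm_module(asm, "x86_64-unknown-linux-gnu"))
```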
@ajtulloch to comment on the AVX-2/AVX-512 schedules currently in tvm: I believe these are currently tuned mainly for high core count c5 AVX-512 instances, where the end2end performance should be better than with mkldnn on Resnet-18. Retuning them for AVX-2 can recover some performance, but overall I think the current schedules heavily target AVX-512 over AVX-2 (@yzhliu knows more details).
@eqy, @yzhliu, @tqchen I'll send out my results. I'm mostly interested in performance with a) relatively low core counts and b) AVX-2 (Haswell/Broadwell) architectures, which is where these results come from. E.g., for end-to-end MobileNet (I've only done this for 1x1 convs), I get about a 30% reduction in runtime (0.04 s -> 0.028 s) from just microkernels + tensorize (with a very dumb schedule, e.g. fully instantiating the intermediate panel x panels tensor due to tensorize limitations).
Do you know what the best result + schedule anyone has obtained (single-threaded) for relatively large GEMMs (e.g. M=N=K={768,1024}) for TVM on AVX-2 machines? I can't seem to get higher than 60 GFLOP/s (which is also the peak I get from LLVM-generated microkernels due to bad register allocation + scheduling), whereas tensorize + microkernels gets to like 90GFLOP/s comfortably. I haven't looked much at the AVX-512 code that LLVM generates but if it's similarly bad then I'd have thought we could get similar gains on those architectures.
I believe we do have room for AVX2 optimization, as well as the (AVX-512/AVX2) 1x1-kernel convolution. Looking forward to the numbers & implementation.
Cool, sounds good. Re code, https://github.com/ajtulloch/tvm/tree/tvm-tulloch is my WIP branch; the key changes (compute + schedule for 1x1 conv) are in https://github.com/ajtulloch/tvm/blob/5c93b249a961d522265de36a007ed16b24d3e728/topi/python/topi/x86/conv2d_avx_1x1.py.
Current numbers:
| (bs, in_channel, in_size, num_filter, kernel, stride, padding) | AVXConv2DTensorize GFLOP/s | AVXConv2D GFLOP/s |
|---|---|---|
| (1, 128, 28, 256, 1, 1, 0) | 66.171627 | 38.920085 |
| (1, 64, 56, 128, 1, 1, 0) | 38.293128 | 27.431091 |
| (1, 128, 28, 128, 1, 1, 0) | 58.991617 | 36.568366 |
| (1, 128, 28, 256, 1, 1, 0) | 67.684656 | 38.701075 |
| (1, 256, 14, 256, 1, 1, 0) | 73.674180 | 35.469258 |
| (1, 256, 14, 512, 1, 1, 0) | 70.014039 | 38.548551 |
| (1, 256, 14, 512, 1, 1, 0) | 65.487479 | 35.186892 |
| (1, 512, 7, 512, 1, 1, 0) | 47.640705 | 36.984628 |
I have modified the first post to list the actionable items. This issue is changed to WIP. @ajtulloch would you be interested in sending a PR?
Thanks @ajtulloch! Have you tried other kernel sizes? I'm curious whether assembly intrinsics can improve avx_common as well.
@yzhliu it seems like it can. I did some experiments with a simple im2col + GEMM, using tensorize + the microkernel for the GEMM. The geometric-mean speedups over the workloads in https://github.com/dmlc/tvm/blob/master/topi/python/topi/x86/conv2d.py#L16 and https://github.com/dmlc/tvm/blob/master/topi/python/topi/nn/conv2d.py#L24 are:
| Network | Speedup (geometric mean) |
|---|---|
| ResNet-50 | 169% |
| ResNet-18 | 56% |
| MobileNet | 134% |
Check out https://github.com/ajtulloch/tvm/blob/nhwc-optimization/topi/python/topi/x86/conv2d_avx_nhwc.py. You can replicate the results by running https://github.com/ajtulloch/tvm/blob/4f624f888a78534b7491459e8667c192777d9880/topi/python/topi/x86/conv2d_avx_nhwc.py with `TVM_NUM_THREADS=1 python topi/python/topi/x86/conv2d_avx_nhwc.py`.
That naive im2col/GEMM implementation could also be improved in a number of ways; for example, we don't currently block the reduction, which hurts performance for large reductions (K = KH × KW × C).
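The im2col + GEMM lowering above turns an NHWC convolution into a GEMM whose reduction dimension is KH·KW·C, and "blocking the reduction" means accumulating over that dimension in chunks of `kc`. A pure-Python sketch, assuming batch 1, stride 1, and no padding; the names are illustrative and this is not the code in the linked branch. In Python the blocking only shows the loop structure, not a cache benefit.

```python
def im2col_nhwc(x, H, W, C, KH, KW):
    """Batch-1 NHWC input x[h][w][c] -> one row per output pixel, columns ordered (kh, kw, c)."""
    OH, OW = H - KH + 1, W - KW + 1  # stride 1, no padding
    return [[x[oh + kh][ow + kw][c]
             for kh in range(KH) for kw in range(KW) for c in range(C)]
            for oh in range(OH) for ow in range(OW)]


def gemm_blocked_k(A, B, kc):
    """C = A @ B, with the reduction dimension processed in blocks of kc."""
    M, K, N = len(A), len(B), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for k0 in range(0, K, kc):  # outer loop over reduction blocks, as a cache-blocked GEMM would do
        k1 = min(k0 + kc, K)
        for i in range(M):
            for j in range(N):
                C[i][j] += sum(A[i][k] * B[k][j] for k in range(k0, k1))
    return C


def conv2d_nhwc_im2col(x, w, H, W, C, KH, KW, CO, kc=4):
    """w[kh][kw][c][co]; returns one row of CO outputs per output pixel."""
    cols = im2col_nhwc(x, H, W, C, KH, KW)
    wmat = [[w[kh][kw][c][co] for co in range(CO)]
            for kh in range(KH) for kw in range(KW) for c in range(C)]
    return gemm_blocked_k(cols, wmat, kc)


H, W, C, KH, KW, CO = 5, 4, 3, 3, 2, 2
x = [[[float((h * 7 + w * 3 + c) % 5) for c in range(C)] for w in range(W)] for h in range(H)]
wt = [[[[float((kh + 2 * kw + c * co) % 3) for co in range(CO)] for c in range(C)]
       for kw in range(KW)] for kh in range(KH)]
out = conv2d_nhwc_im2col(x, wt, H, W, C, KH, KW, CO, kc=4)  # K = 3*2*3 = 18 -> 5 blocks
```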
Raw data is:
```
N: 1, CIn: 512, H/W: 7, COut: 1024, KH/KW: 1
BaselineNHWC: 19.45, BaselineNCHW: 41.12, TensorNHWC: 59.23
N: 1, CIn: 1024, H/W: 7, COut: 1024, KH/KW: 1
BaselineNHWC: 17.72, BaselineNCHW: 38.49, TensorNHWC: 52.74
N: 1, CIn: 32, H/W: 112, COut: 64, KH/KW: 1
BaselineNHWC: 63.10, BaselineNCHW: 13.36, TensorNHWC: 46.64
N: 1, CIn: 64, H/W: 56, COut: 128, KH/KW: 1
BaselineNHWC: 48.82, BaselineNCHW: 13.72, TensorNHWC: 63.89
N: 1, CIn: 128, H/W: 56, COut: 128, KH/KW: 1
BaselineNHWC: 35.01, BaselineNCHW: 14.88, TensorNHWC: 64.69
N: 1, CIn: 128, H/W: 28, COut: 256, KH/KW: 1
BaselineNHWC: 17.04, BaselineNCHW: 40.19, TensorNHWC: 82.56
N: 1, CIn: 256, H/W: 28, COut: 256, KH/KW: 1
BaselineNHWC: 16.29, BaselineNCHW: 41.91, TensorNHWC: 86.27
N: 1, CIn: 256, H/W: 14, COut: 512, KH/KW: 1
BaselineNHWC: 16.31, BaselineNCHW: 41.50, TensorNHWC: 73.89
N: 1, CIn: 512, H/W: 14, COut: 512, KH/KW: 1
BaselineNHWC: 16.07, BaselineNCHW: 40.55, TensorNHWC: 80.26
MOBILENET: 2.34
N: 1, CIn: 3, H/W: 224, COut: 64, KH/KW: 7
BaselineNHWC: 45.61, BaselineNCHW: 12.43, TensorNHWC: 39.06
N: 1, CIn: 64, H/W: 56, COut: 64, KH/KW: 3
BaselineNHWC: 45.32, BaselineNCHW: 50.42, TensorNHWC: 42.15
N: 1, CIn: 64, H/W: 56, COut: 64, KH/KW: 1
BaselineNHWC: 60.32, BaselineNCHW: 14.33, TensorNHWC: 53.53
N: 1, CIn: 64, H/W: 56, COut: 128, KH/KW: 3
BaselineNHWC: 21.33, BaselineNCHW: 51.49, TensorNHWC: 46.77
N: 1, CIn: 64, H/W: 56, COut: 128, KH/KW: 1
BaselineNHWC: 36.88, BaselineNCHW: 26.19, TensorNHWC: 65.28
N: 1, CIn: 128, H/W: 28, COut: 128, KH/KW: 3
BaselineNHWC: 21.58, BaselineNCHW: 52.02, TensorNHWC: 60.27
N: 1, CIn: 128, H/W: 28, COut: 256, KH/KW: 3
BaselineNHWC: 21.72, BaselineNCHW: 38.94, TensorNHWC: 49.22
N: 1, CIn: 128, H/W: 28, COut: 256, KH/KW: 1
BaselineNHWC: 15.84, BaselineNCHW: 29.45, TensorNHWC: 78.21
N: 1, CIn: 256, H/W: 14, COut: 256, KH/KW: 3
BaselineNHWC: 21.65, BaselineNCHW: 38.56, TensorNHWC: 58.04
N: 1, CIn: 256, H/W: 14, COut: 512, KH/KW: 3
BaselineNHWC: 16.93, BaselineNCHW: 35.30, TensorNHWC: 39.40
N: 1, CIn: 256, H/W: 14, COut: 512, KH/KW: 1
BaselineNHWC: 16.35, BaselineNCHW: 39.39, TensorNHWC: 64.80
N: 1, CIn: 512, H/W: 7, COut: 512, KH/KW: 3
BaselineNHWC: 13.54, BaselineNCHW: 30.03, TensorNHWC: 26.24
RESNET-18: 1.56
N: 1, CIn: 64, H/W: 56, COut: 256, KH/KW: 1
BaselineNHWC: 17.29, BaselineNCHW: 18.88, TensorNHWC: 67.97
N: 1, CIn: 256, H/W: 56, COut: 64, KH/KW: 1
BaselineNHWC: 38.36, BaselineNCHW: 18.12, TensorNHWC: 65.66
N: 1, CIn: 256, H/W: 56, COut: 128, KH/KW: 1
BaselineNHWC: 31.74, BaselineNCHW: 16.42, TensorNHWC: 74.59
N: 1, CIn: 128, H/W: 28, COut: 512, KH/KW: 1
BaselineNHWC: 15.72, BaselineNCHW: 19.93, TensorNHWC: 82.93
N: 1, CIn: 256, H/W: 56, COut: 512, KH/KW: 1
BaselineNHWC: 15.27, BaselineNCHW: 17.82, TensorNHWC: 87.61
N: 1, CIn: 512, H/W: 28, COut: 128, KH/KW: 1
BaselineNHWC: 20.92, BaselineNCHW: 19.08, TensorNHWC: 73.68
N: 1, CIn: 512, H/W: 28, COut: 256, KH/KW: 1
BaselineNHWC: 16.95, BaselineNCHW: 28.22, TensorNHWC: 70.45
N: 1, CIn: 256, H/W: 14, COut: 1024, KH/KW: 1
BaselineNHWC: 20.53, BaselineNCHW: 40.72, TensorNHWC: 77.74
N: 1, CIn: 512, H/W: 28, COut: 1024, KH/KW: 1
BaselineNHWC: 20.01, BaselineNCHW: 30.35, TensorNHWC: 77.84
N: 1, CIn: 1024, H/W: 14, COut: 256, KH/KW: 1
BaselineNHWC: 16.24, BaselineNCHW: 37.88, TensorNHWC: 69.48
N: 1, CIn: 1024, H/W: 14, COut: 512, KH/KW: 1
BaselineNHWC: 16.45, BaselineNCHW: 23.24, TensorNHWC: 60.36
N: 1, CIn: 512, H/W: 7, COut: 2048, KH/KW: 1
BaselineNHWC: 18.67, BaselineNCHW: 33.24, TensorNHWC: 47.70
N: 1, CIn: 1024, H/W: 14, COut: 2048, KH/KW: 1
BaselineNHWC: 13.87, BaselineNCHW: 23.35, TensorNHWC: 35.61
N: 1, CIn: 2048, H/W: 7, COut: 512, KH/KW: 1
BaselineNHWC: 14.35, BaselineNCHW: 26.08, TensorNHWC: 46.14
RESNET-50: 2.69
```
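The per-network figures in this raw data can be reproduced as the geometric mean, over a network's layers, of TensorNHWC / BaselineNCHW. That choice of baseline is our inference from the numbers rather than something stated in the thread, but it gives the reported 2.34 for MobileNet (i.e. the 134% speedup in the summary table):

```python
import math

# (BaselineNCHW, TensorNHWC) GFLOP/s for the MobileNet layers in the raw data above.
mobilenet = [
    (41.12, 59.23), (38.49, 52.74), (13.36, 46.64),
    (13.72, 63.89), (14.88, 64.69), (40.19, 82.56),
    (41.91, 86.27), (41.50, 73.89), (40.55, 80.26),
]


def geomean_speedup(pairs):
    """Geometric mean of new/baseline throughput ratios."""
    logs = [math.log(new / base) for base, new in pairs]
    return math.exp(sum(logs) / len(logs))


print("%.2f" % geomean_speedup(mobilenet))  # 2.34
```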
@ajtulloch I quickly tried your tvm-tulloch branch to reproduce the GEMM results but failed. Can you rebase your code onto the latest master and make sure `python opt_gemm.py` works? Thanks!
FYI, the errors I got are something like " Cannot find function tvm.contrib.cblas.matmul in the imported modules or global registry" and "Fail to load bitcode file gemmMxN__avx2.bc".
@yidawang you need to enable cblas during cmake.
@masahi thanks! This fixed the first error, but the second one remains.
@yidawang you probably need to regenerate the .bc file with your current clang version (and ensure it's in your current working directory). `clang -c gemmMxN__avx2.c -O3 -march=core-avx2 -emit-llvm -o gemmMxN__avx2.bc` or similar should work.
@ajtulloch Thanks, running from that working directory works (with some warnings though, e.g. `*** is not a recognized feature for this target`; I haven't looked into them yet). The performance looks promising! But it looks like it only works for the matrix shape you gave?
Sorry I missed the party. @eqy Yes, assembly kernels have proven to be more performant than OpenCL or HIP kernels. @tqchen I am working on generating microkernels as LLVM basic blocks; the kernels I am generating use many more registers than they should. For the basic 4kx4kx4k GEMM I am seeing 9 TFLOP/s on Vega64 (peak is 12). I made a few changes today to bring the register count down. Hopefully by the end of this weekend I'll have something solid. I'll contact you about API design later next week.
@yidawang check out the nhwc-optimization branch in that repo, it has generic convolutions built on these microkernels (1x1 and MxN, arbitrary size/stride/padding) which have similar efficiency benefits.
@ajtulloch would you be interested in contributing a PR to support the first part of the action items (as updated in the first post of this issue)?
@tqchen sure, me (or someone I loop in) will work on that part.
@tqchen I got a version of the micro-kernel generator working. I am seeing peak throughput (12 TFLOP/s) on Vega64. I am turning it into a proper library now; code will be hosted here
@ajtulloch what is the status on this? If you haven't started yet, I may have some free cycles to get it in this week.
Closed by #1486. Let us open follow-up issues for potential improvements and micro-kernel exploration. Thanks @ajtulloch @eqy @adityaatluri @yidawang @yzhliu for the helpful discussions in this RFC.
Ah sorry @tqchen, I was out the last few weeks. This looks great, thanks so much. I've restarted working on these tensorized implementations in https://github.com/ajtulloch/tvm/blob/43f50eba9eab45ec505e92ce372c45d91227778a/tensorize/gemm__neon.c + https://github.com/ajtulloch/tvm/blob/43f50eba9eab45ec505e92ce372c45d91227778a/tensorize/gemm__avx2.c, and there are some pretty promising results on ARMv7 in addition to AVX2 already.
Now that the feature has been upstreamed to the main branch, let us continue exploring what benefits this direction can bring. @ajtulloch's results look encouraging. I created a thread at https://discuss.tvm.ai/t/inline-assembly-micro-kernel-and-performance-experiment-for-arm-and-x86/527 for follow-up discussions on this.
@ajtulloch can you share performance numbers?
This issue arose through discussions with several people: @eqy @adityaatluri @ajtulloch @cowanmeg. So far we can use the IR to embed a sequence of LLVM intrinsics to generate a microkernel. This may not be ideal: in certain cases we need to use hand-crafted micro-kernels, or use an assembly generator such as PeachPy to generate them and embed them into the code. This issue is for discussing possible API proposals and ways to do so.
Notably, LLVM already supports inline assembly, so the only gap is adding support on the IR intrinsic side. There are two possible ways to make this happen:
Link to a module with assembly intrinsic
This is the easiest approach; @eqy already has a hacked-up version that does this. The idea is to load the asm module as a separate LLVM module (wrapped in LLVM IR), mark it so that it must be force-inlined, and load and link this module in the code generation phase.
To implement this, we can specify the external modules to link (possibly via arguments, or via a thread-local scoped context that sets the linked modules).
Ideally, we want to support an inline asm string as well as a path to an asm file.
One key characteristic of this approach is that all the asm is wrapped in functions. LLVM's SSA form may make certain things hard; for example, an array of registers has to be passed by address, and we need to rely on the register promotion pass. It is unclear whether there are cases we cannot handle.
Action Items
- `pragma_import` support in LLVM code generation, keeping an `unordered_set<std::string>` of imported modules (deduplication)
- Python API for `pragma_import` (schedule and tensor intrinsic): `sch[xx].pragma("import", "path/to/file.ll")`
- `contrib.clang.create_llvm` to make it easy to generate LLVM from embedded C/asm. Can be in a separate PR.
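The deduplication item amounts to remembering which modules have already been linked, so that a `.ll` file or IR string imported from several places (e.g. every use of one tensor intrinsic) is linked only once. A Python sketch of that bookkeeping with hypothetical names; the actual change would live in the C++ LLVM code generator, using the `unordered_set<std::string>` mentioned above.

```python
class ImportTracker:
    """Hypothetical model of pragma_import handling: link each module at most once."""

    def __init__(self):
        self._seen = set()  # plays the role of unordered_set<std::string>
        self.linked = []    # modules in the order they would be linked

    def visit_pragma_import(self, source):
        """source: a path to a .ll file or an LLVM IR string; returns True if newly linked."""
        if source in self._seen:
            return False  # already linked; skipping avoids duplicate symbol definitions
        self._seen.add(source)
        self.linked.append(source)
        return True


tracker = ImportTracker()
for src in ["/path/to/file.ll", "/path/to/file.ll", "define void @f() { ret void }"]:
    tracker.visit_pragma_import(src)
print(tracker.linked)  # ['/path/to/file.ll', 'define void @f() { ret void }']
```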