Samsung / ONE (On-device Neural Engine)

Generate GEMM kernel with NN Compiler technique #5951

chunseoklee opened this issue 3 years ago

chunseoklee commented 3 years ago

Let's try to generate an optimized GEMM kernel for Hybrid FC (https://github.com/Samsung/ONE/issues/5950). Note that Hybrid FC relies on an (int8*int8) -> int32 GEMM kernel.
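For reference, the computation in question is roughly the following (a plain C++ sketch for illustration only; the function and argument names are made up and this is not ONE's actual kernel):

```cpp
#include <cstdint>

// (int8 * int8) -> int32 GEMM with accumulation in int32:
// out[n][m] = sum_k lhs[n][k] * rhs[m][k]
// The goal of this issue is to generate a fast version of exactly this loop nest.
void GemmS8S8S32(const int8_t* lhs, const int8_t* rhs, int32_t* out,
                 int N, int K, int M) {
  for (int n = 0; n < N; ++n) {
    for (int m = 0; m < M; ++m) {
      int32_t acc = 0;
      for (int k = 0; k < K; ++k) {
        acc += static_cast<int32_t>(lhs[n * K + k]) *
               static_cast<int32_t>(rhs[m * K + k]);
      }
      out[n * M + m] = acc;
    }
  }
}
```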

wateret commented 3 years ago

I'll take a look at Tiramisu.

periannath commented 3 years ago

TVM auto-tuning method

TVM has two ways to auto-tune kernels for a model.

Performance of GEMM kernel generated by TVM

| NxKxM | TVM | TVM (Weight transform) | ONE |
| --- | --- | --- | --- |
| 1x256x6144 | 840 | 510 | 392 |
| 1x192x6144 | 460 | 330 | 265 |
| 1x128x6144 | 350 | 220 | 150 |
| 8x1536x64 | 230 | 190 | 94 |
| 4x1000x64 | 80 | 60 | 36 |
binarman commented 3 years ago

@periannath Hi! FYI, I started an implementation of a Halide-based compiler in #5836. You can see a basic FC implementation here: https://github.com/Samsung/ONE/blob/5fba53110a0b6aa763fe5d3c4f34d73226910b6f/compiler/luci-codegen/src/KernelBuilder.cpp#L312

Also note that I used the autoscheduling functionality of Halide for now (it shows reasonable performance).
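For readers who have not used Halide before, here is a minimal sketch of how such an FC/GEMM can be expressed (illustrative only, not the actual KernelBuilder code; the shape is one of the sizes discussed in this thread):

```cpp
#include "Halide.h"
using namespace Halide;

int main() {
  const int N = 1, K = 256, M = 6144;  // example shape, for illustration only

  // Note: in Halide the first buffer dimension is the fastest-varying one.
  Buffer<float> in(K, N), weight(K, M);
  in.fill(1.f);
  weight.fill(1.f);

  Var y("y"), x("x");
  RDom r(0, K, "r");

  Func out("out");
  out(y, x) = 0.f;                           // pure definition
  out(y, x) += in(r.x, x) * weight(r.x, y);  // reduction over K

  // The performance comes from the schedule (manual or autoscheduled);
  // the default schedule is used here.
  Buffer<float> result = out.realize({M, N});
  return 0;
}
```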

periannath commented 3 years ago

@binarman Thanks! I'll look into it.

I found that Halide has three autoschedulers (Mullapudi2016, Li2018, and Adams2019). https://github.com/halide/Halide/releases/tag/v10.0.0

Which autoscheduler did you use?

binarman commented 3 years ago

> Which autoscheduler did you use?

I tried all of them, but with a different workload:

1) I used an "arithmetic" graph (the attached "subgraph" image), not FC.

2) I did not try to parallelize computations yet. For now my main goal is to optimize single-threaded computation. The rationale is as follows: I want to speed up subgraphs with lots of small operators (I think this is not the best target for parallelization) and reduce runtime overhead on low-end devices (like microcontrollers, which often do not have multiple cores, so there is nothing to parallelize).

I tried to run the model above on a Galaxy S20 with the fastest generated code; my results are:

| Scheduler | Time in microseconds |
| --- | --- |
| None | ~7.80 |
| Mullapudi2016 | ~2.01 |
| Li2018 | ~500 |
| Adams2019 | ~2.16 |

Li2018 is extremely slow because it emits parallelization/synchronization code despite the fact that I configured it with 1 available processor.

binarman commented 3 years ago

@periannath FYI, the compiler I am working on currently has no convenient switch to choose the scheduler (you need to change it in code and recompile), but I hope to fix that soon.

I also want to note that all computations I have tested so far were in float. I am not sure it will work correctly with the FC layers described in this issue yet.

Could you tell me more about what the int8*int8 -> int32 FC is? In particular, does it involve quantization of the output?

binarman commented 3 years ago

@periannath @wateret Sorry, another question about this FC =_= I want to test it using my draft.

Could you clarify where to make the casts (int8 -> int32)? I can imagine two options:

1. `output += (int32)weights * (int32)input`
2. `output += (int32)(weights * input)`

periannath commented 3 years ago

@binarman Thank you for sharing your numbers. I tried autoscheduling on an Odroid-XU4, and Adams2019 works best for me.

> Could you tell me more about what the int8*int8 -> int32 FC is? In particular, does it involve quantization of the output?

The hybrid FullyConnected kernel in TFLite takes a float input and int8 weights, and its result is float. The kernel quantizes the input to int8 at inference time and does the int8*int8 computation. The result is an int32 tensor, which is then transformed back into float.

We're trying to generate a better kernel for the (int8*int8) -> int32 GEMM in hybrid FC.

You might read @chunseoklee's analysis on hybrid FC link.
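Putting it together, here is a hedged C++ sketch of the whole hybrid FC flow (the per-row symmetric quantization and all names are assumptions for illustration, not the exact TFLite code); the inner int8 loop is the GEMM this issue wants to generate:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Hybrid FC: float input (N x K), int8 weights (M x K), float output (N x M).
void HybridFCReference(const std::vector<float>& input,
                       const std::vector<int8_t>& weight, float weight_scale,
                       int N, int K, int M, std::vector<float>& output) {
  std::vector<int8_t> quant_input(N * K);
  std::vector<float> input_scale(N);

  // 1) Quantize the float input to int8 at inference time
  //    (symmetric, one scale per input row -- an assumption for this sketch).
  for (int n = 0; n < N; ++n) {
    float max_abs = 0.f;
    for (int k = 0; k < K; ++k)
      max_abs = std::max(max_abs, std::fabs(input[n * K + k]));
    input_scale[n] = max_abs > 0.f ? max_abs / 127.f : 1.f;
    for (int k = 0; k < K; ++k)
      quant_input[n * K + k] =
          static_cast<int8_t>(std::round(input[n * K + k] / input_scale[n]));
  }

  // 2) (int8 * int8) -> int32 GEMM: the kernel to be generated.
  for (int n = 0; n < N; ++n) {
    for (int m = 0; m < M; ++m) {
      int32_t acc = 0;
      for (int k = 0; k < K; ++k)
        acc += static_cast<int32_t>(quant_input[n * K + k]) *
               static_cast<int32_t>(weight[m * K + k]);
      // 3) Rescale the int32 accumulator back to float.
      output[n * M + m] = static_cast<float>(acc) * input_scale[n] * weight_scale;
    }
  }
}
```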

> Could you clarify where to make the casts (int8 -> int32)? I can imagine two options:
>
> 1. `output += (int32)weights * (int32)input`
> 2. `output += (int32)(weights * input)`

~~AFAIK, option 2 is right but TVM can't support it. I used option 1 for TVM.~~

Updated: TFLite uses option 1 with the int16 type. https://github.com/Samsung/ONE/issues/5951#issuecomment-785652647

periannath commented 3 years ago

Performance of GEMM kernel generated by Halide

Configuration

Device: Odroid-XU4

| NxKxM | ONE | Halide |
| --- | --- | --- |
| 1x256x6144 | 392 | 2067 |
| 1x192x6144 | 265 | 1493 |
| 1x128x6144 | 150 | 993 |
| 8x1536x64 | 94 | 1300 |
| 4x1000x64 | 36 | 421 |

Device: Raspberry Pi 3 B+

| NxKxM | ONE | Halide |
| --- | --- | --- |
| 1x256x6144 | 2619 | 9896 |
| 1x192x6144 | 2058 | 7160 |
| 1x128x6144 | 1426 | 4651 |
| 8x1536x64 | 544 | 6496 |
| 4x1000x64 | 187 | 2111 |

Updated (02-23)

binarman commented 3 years ago

@periannath First of all, thank you for your answers!

Second, I want to clarify this item:

> Halide produced better kernel for float GEMM #5440 (comment)

This result was achieved:

1) with manual scheduling,
2) with 1xKxM (N == 1) operations, and
3) on a Galaxy S20 phone (it has significantly newer hardware compared to the RPI3 and Odroid).

So I think I'll retest these operations in the new environment (probably today).

binarman commented 3 years ago

@periannath Sorry for disturbing you again, but I want to clarify two items:

1) I investigated the TFLite code and found only case 1 (output += (int32)weights * (int32)input). Could you check it again, please?

This should have a performance impact on the generated code.

2) I noticed that you address Halide buffers in your code the way it is done in C/C++. AFAIK, Halide uses the reverse order of dimensions, so the first dimension is the fastest-changing one. For example, to represent int a[N][M]; you should use Halide::Buffer<int> a(M, N);
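A tiny example of that dimension-order point (illustrative names and sizes only):

```cpp
#include "Halide.h"

int main() {
  const int N = 4, M = 8;
  // C/C++: int a[N][M]; element a[i][j], with j the fastest-varying index.
  // Halide: the first dimension is the fastest-varying one, so the extents
  // are reversed and the same element is addressed as a(j, i).
  Halide::Buffer<int> a(M, N);
  for (int i = 0; i < N; ++i)
    for (int j = 0; j < M; ++j)
      a(j, i) = i * M + j;  // same memory layout as the C array above
  return 0;
}
```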

periannath commented 3 years ago

@binarman

> 1) I investigated the TFLite code and found only case 1 (output += (int32)weights * (int32)input). Could you check it again, please?

AFAIK, the TFLite NEON kernel uses case 2 for hybrid FC. Here is the NEON kernel code for hybrid FC:

https://github.com/tensorflow/tensorflow/blob/43d85bc8bbfe984c0b8185a575fa25d6dff26271/tensorflow/lite/kernels/internal/optimized/neon_tensor_utils.cc#L804-L819

The input and weights are loaded into int8x16_t NEON registers. The kernel multiplies the low halves into an int16x8_t NEON register and does the same for the high halves. The result is accumulated into an int32x4_t NEON register.

In the C++ kernel, TFLite uses code like int output = (int8_t)weight * (int8_t)input;. I thought it worked as case 2, but it doesn't. Thanks to you, I've noticed that int8_t is promoted to int by implicit conversion: https://en.cppreference.com/w/c/language/conversion#Integer_promotions
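A tiny standalone check of that promotion behaviour (illustrative only):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  int8_t weight = 100, input = 100;
  // Integer promotion: both int8_t operands are converted to int before the
  // multiplication, so the product is computed at full precision (case 1).
  int promoted = weight * input;                        // 10000
  // Forcing the product back into int8_t is what a true 8-bit multiply
  // (case 2) would mean; the value wraps around (10000 mod 256 == 16).
  int truncated = static_cast<int8_t>(weight * input);  // 16 on typical targets
  std::printf("promoted=%d truncated=%d\n", promoted, truncated);
  return 0;
}
```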

In summary: if NEON is supported on the target device, TFLite will use the NEON kernel by default.

> 2) I noticed that you address Halide buffers in your code the way it is done in C/C++. AFAIK, Halide uses the reverse order of dimensions, so the first dimension is the fastest-changing one. For example, to represent int a[N][M]; you should use Halide::Buffer<int> a(M, N);

Thanks! I'll measure the performance again after I've fixed it.

binarman commented 3 years ago

@periannath I checked the NEON kernel you mentioned. It uses case 1 too. Please take a look:

// load vector of 16 int8
const int8x16_t s1_8x16 = vld1q_s8((const int8_t*)(aligned_vec + col));

// load vector of 16 int8
const int8x16_t s2_8x16 = vld1q_s8((const int8_t*)(row_ptr + col));

// mult two int8 vectors and store result in vector of int16
int16x8_t prod_16x8 = vmull_s8(vget_low_s8(s1_8x16), vget_low_s8(s2_8x16));

// mult two int8 vectors and add result in previously computed vector of int16
prod_16x8 = vmlal_s8(prod_16x8, vget_high_s8(s1_8x16), vget_high_s8(s2_8x16));

// add pairs of elements in prod_16x8 and accumulate them into a 32-bit integer vector
dotprod_32x4 = vpadalq_s16(dotprod_32x4, prod_16x8);

You can test it with the code extracted from the kernel: arm_long_test.zip

P.S. This is because the kernel uses the "long" versions of the instructions. You can see it by the additional "l" in the name: vmlal_s8 vs vmla_s8, vmull_s8 vs vmul_s8.
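A minimal standalone check of that behaviour (assumes a build target with NEON, e.g. armv7/aarch64; illustrative only): every int8*int8 product is kept at full precision in an int16 lane, i.e. case 1 semantics.

```cpp
#include <arm_neon.h>
#include <cstdint>
#include <cstdio>

int main() {
  int8_t a[8], b[8];
  for (int i = 0; i < 8; ++i) { a[i] = 100; b[i] = 100; }

  const int8x8_t va = vld1_s8(a);
  const int8x8_t vb = vld1_s8(b);

  // "Long" multiply: each int8*int8 product is widened to int16,
  // so 100*100 = 10000 is kept exactly instead of wrapping in int8.
  const int16x8_t prod = vmull_s8(va, vb);

  // Pairwise add the int16 products and accumulate into int32 lanes.
  int32x4_t acc = vdupq_n_s32(0);
  acc = vpadalq_s16(acc, prod);

  int32_t out[4];
  vst1q_s32(out, acc);
  std::printf("%d %d %d %d\n", out[0], out[1], out[2], out[3]);  // 20000 each
  return 0;
}
```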

periannath commented 3 years ago

@binarman I wasn't aware of the long versions of the ARM intrinsics. Thanks for letting me know.

I've checked the NEON kernel in detail. As you said, the kernel uses case 1, but with the int16_t type.

// mult two int8 vectors and store result in vector of int16
int16x8_t prod_16x8 = vmull_s8(vget_low_s8(s1_8x16), vget_low_s8(s2_8x16));

To mimic this computation, the Halide program should be changed as below.

out(y, x) = 0;
out(y, x) += i16(in(r.x,x)) * i16(weight(r.x, y));

Assembly

     9f0:       f2c8ccaa        vmull.s8        q14, d24, d26
     9f4:       e35c0c01        cmp     ip, #256        ; 0x100
     9f8:       f2c98caa        vmull.s8        q12, d25, d26
     9fc:       f2d001ac        vaddw.s16       q8, q8, d28
     a00:       f2d661ad        vaddw.s16       q11, q11, d29
     a04:       f2d221a9        vaddw.s16       q9, q9, d25
     a08:       f2d441a8        vaddw.s16       q10, q10, d24

With the i32 type

out(y, x) = 0;
out(y, x) += i32(in(r.x,x)) * i32(weight(r.x, y));

Assembly

     9f0:       f2c8aa38        vmovl.s8        q13, d24
     9f4:       e35c0c01        cmp     ip, #256        ; 0x100
     9f8:       f2c88a39        vmovl.s8        q12, d25
     9fc:       ee8c1bb0        vdup.16 d28, r1
     a00:       f2da08ac        vmlal.s16       q8, d26, d28
     a04:       f2db68ac        vmlal.s16       q11, d27, d28
     a08:       f2d848ac        vmlal.s16       q10, d24, d28
     a0c:       f2d928ac        vmlal.s16       q9, d25, d28

I hope there are no more errors in this analysis. 😅

binarman commented 3 years ago

@periannath Thank you for your efforts! I think everything works well now =)