Zhen-Dong / HAWQ

Quantization library for PyTorch. Supports low-precision and mixed-precision quantization, with hardware implementation through TVM.
MIT License

Similar running time with INT8 and INT4 #6

Closed: haibao-yu closed this issue 3 years ago

haibao-yu commented 3 years ago

Hi, I ran ResNet-50 with the uniform8 and uniform4 bit configurations, but they have similar running times.

I ran INT8 and INT4 with the following script:

#!/bin/bash

# Run inference for one bit configuration, first with the TVM-generated CUDA
# code and then with a manually patched copy of it (--manual-code).
run_inference() {
        bit_config=$1
        num_layers=$2

        printf "%s\n" "$bit_config"

        python test_resnet_inference_time.py --bit-config "$bit_config" --num-layers "$num_layers"

        # Copy the generated CUDA source to a separate file for manual patching.
        cp ./debug_output/resnet_generated.cu ./debug_output/resnet_manual.cu

        # Change the bound of two fused "_inner" loops from 8 to 1 in the copy.
        sed -i 's/h_w_fused_n_fused_i_fused_nn_fused_ii_fused_inner < 8;/h_w_fused_n_fused_i_fused_nn_fused_ii_fused_inner < 1;/g' ./debug_output/resnet_manual.cu
        sed -i 's/ax0_ax1_fused_ax2_fused_ax3_fused_inner < 8;/ax0_ax1_fused_ax2_fused_ax3_fused_inner < 1;/g' ./debug_output/resnet_manual.cu

        sleep 5
        # Rerun inference using the patched CUDA file.
        python test_resnet_inference_time.py --bit-config "$bit_config" --num-layers "$num_layers" --manual-code
}

run_inference "bit_config_resnet50_uniform4"   50
run_inference "bit_config_resnet50_uniform8"   50
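The copy-and-patch step in the script can also be written in Python, which makes the two substitutions easy to inspect and test. This is a sketch only; the names `REPLACEMENTS`, `patch_source` and `make_manual_copy` are illustrative and not part of HAWQ.

```python
from pathlib import Path

# The two literal loop-bound rewrites performed by the sed commands above.
# Like sed's s/.../.../g, str.replace substitutes every occurrence.
REPLACEMENTS = [
    ("h_w_fused_n_fused_i_fused_nn_fused_ii_fused_inner < 8;",
     "h_w_fused_n_fused_i_fused_nn_fused_ii_fused_inner < 1;"),
    ("ax0_ax1_fused_ax2_fused_ax3_fused_inner < 8;",
     "ax0_ax1_fused_ax2_fused_ax3_fused_inner < 1;"),
]


def patch_source(cu_source: str) -> str:
    """Apply the loop-bound rewrites to a CUDA source string."""
    for old, new in REPLACEMENTS:
        cu_source = cu_source.replace(old, new)
    return cu_source


def make_manual_copy(generated: Path, manual: Path) -> None:
    """Write a patched copy of the generated .cu file to `manual`."""
    manual.write_text(patch_source(generated.read_text()))
```

For example, `make_manual_copy(Path("./debug_output/resnet_generated.cu"), Path("./debug_output/resnet_manual.cu"))` has the same effect as the `cp` plus the two `sed -i` lines.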

However, the running times in manual mode are similar:

Performed inference in 17.05ms (std = 0.15) for 8 samples
Average per sample inference time: 2.13ms

and

Performed inference in 20.49ms (std = 0.27) for 8 samples
Average per sample inference time: 2.56ms
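The per-sample figure is the batch latency divided by the batch size of 8, so the two runs differ by only about 1.2x. A short check of that arithmetic (the helper name is illustrative, not part of HAWQ):

```python
BATCH_SIZE = 8  # samples per inference run, as reported in the output above


def per_sample_ms(batch_ms: float, batch_size: int = BATCH_SIZE) -> float:
    """Average per-sample latency in milliseconds."""
    return batch_ms / batch_size


for batch_ms in (17.05, 20.49):
    print(f"{batch_ms:.2f} ms / {BATCH_SIZE} = {per_sample_ms(batch_ms):.2f} ms per sample")

# Ratio between the slower and the faster run.
print(f"ratio: {20.49 / 17.05:.2f}x")
```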
zachzzc commented 3 years ago

The inference times for both INT8 and INT4 look slow. What hardware platform are you running on? Do you see any warnings about TVM convolution workloads?

haibao-yu commented 3 years ago

I am using a Tesla T4 on Google Cloud, and the warnings are as follows:

WARNING:autotvm:Cannot find config for target=cuda, workload=('conv2d_nhwc_tensorcore_im2col.cuda', ('TENSOR', (8, 230, 230, 3), 'int8'), ('TENSOR', (7, 7, 64, 3), 'int8'), (2, 2), (0, 0, 0, 0), (1, 1), 'NHWC', 'int32'). A fallback configuration is used, which may bring great performance regression.

The full log is below (it differs slightly because I reran the experiment):

WARNING:autotvm:Cannot find config for target=cuda, workload=('conv2d_nhwc_tensorcore_im2col.cuda', ('TENSOR', (8, 230, 230, 3), 'int8'), ('TENSOR', (7, 7, 64, 3), 'int8'), (2, 2), (0, 0, 0, 0), (1, 1), 'NHWC', 'int32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=cuda, workload=('conv2d_HWNCnc_tensorcore.cuda', ('TENSOR', (56, 56, 8, 64), 'uint4'), ('TENSOR', (1, 1, 64, 64), 'int4'), (1, 1), (0, 0, 0, 0), (1, 1), 'int32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=cuda, workload=('conv2d_HWNCnc_tensorcore.cuda', ('TENSOR', (58, 58, 8, 64), 'uint4'), ('TENSOR', (3, 3, 64, 64), 'int4'), (1, 1), (0, 0, 0, 0), (1, 1), 'int32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=cuda, workload=('conv2d_HWNCnc_tensorcore.cuda', ('TENSOR', (56, 56, 8, 64), 'uint4'), ('TENSOR', (1, 1, 256, 64), 'int4'), (1, 1), (0, 0, 0, 0), (1, 1), 'int32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=cuda, workload=('conv2d_HWNCnc_tensorcore.cuda', ('TENSOR', (56, 56, 8, 256), 'uint4'), ('TENSOR', (1, 1, 64, 256), 'int4'), (1, 1), (0, 0, 0, 0), (1, 1), 'int32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=cuda, workload=('conv2d_HWNCnc_tensorcore.cuda', ('TENSOR', (56, 56, 8, 256), 'uint4'), ('TENSOR', (1, 1, 128, 256), 'int4'), (2, 2), (0, 0, 0, 0), (1, 1), 'int32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=cuda, workload=('conv2d_HWNCnc_tensorcore.cuda', ('TENSOR', (30, 30, 8, 128), 'uint4'), ('TENSOR', (3, 3, 128, 128), 'int4'), (1, 1), (0, 0, 0, 0), (1, 1), 'int32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=cuda, workload=('conv2d_HWNCnc_tensorcore.cuda', ('TENSOR', (28, 28, 8, 128), 'uint4'), ('TENSOR', (1, 1, 512, 128), 'int4'), (1, 1), (0, 0, 0, 0), (1, 1), 'int32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=cuda, workload=('conv2d_HWNCnc_tensorcore.cuda', ('TENSOR', (56, 56, 8, 256), 'uint4'), ('TENSOR', (1, 1, 512, 256), 'int4'), (2, 2), (0, 0, 0, 0), (1, 1), 'int32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=cuda, workload=('conv2d_HWNCnc_tensorcore.cuda', ('TENSOR', (28, 28, 8, 512), 'uint4'), ('TENSOR', (1, 1, 128, 512), 'int4'), (1, 1), (0, 0, 0, 0), (1, 1), 'int32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=cuda, workload=('conv2d_HWNCnc_tensorcore.cuda', ('TENSOR', (28, 28, 8, 512), 'uint4'), ('TENSOR', (1, 1, 256, 512), 'int4'), (2, 2), (0, 0, 0, 0), (1, 1), 'int32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=cuda, workload=('conv2d_HWNCnc_tensorcore.cuda', ('TENSOR', (16, 16, 8, 256), 'uint4'), ('TENSOR', (3, 3, 256, 256), 'int4'), (1, 1), (0, 0, 0, 0), (1, 1), 'int32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=cuda, workload=('conv2d_HWNCnc_tensorcore.cuda', ('TENSOR', (14, 14, 8, 256), 'uint4'), ('TENSOR', (1, 1, 1024, 256), 'int4'), (1, 1), (0, 0, 0, 0), (1, 1), 'int32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=cuda, workload=('conv2d_HWNCnc_tensorcore.cuda', ('TENSOR', (28, 28, 8, 512), 'uint4'), ('TENSOR', (1, 1, 1024, 512), 'int4'), (2, 2), (0, 0, 0, 0), (1, 1), 'int32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=cuda, workload=('conv2d_HWNCnc_tensorcore.cuda', ('TENSOR', (14, 14, 8, 1024), 'uint4'), ('TENSOR', (1, 1, 256, 1024), 'int4'), (1, 1), (0, 0, 0, 0), (1, 1), 'int32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=cuda, workload=('conv2d_HWNCnc_tensorcore.cuda', ('TENSOR', (14, 14, 8, 1024), 'uint4'), ('TENSOR', (1, 1, 512, 1024), 'int4'), (2, 2), (0, 0, 0, 0), (1, 1), 'int32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=cuda, workload=('conv2d_HWNCnc_tensorcore.cuda', ('TENSOR', (9, 9, 8, 512), 'uint4'), ('TENSOR', (3, 3, 512, 512), 'int4'), (1, 1), (0, 0, 0, 0), (1, 1), 'int32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=cuda, workload=('conv2d_HWNCnc_tensorcore.cuda', ('TENSOR', (7, 7, 8, 512), 'uint4'), ('TENSOR', (1, 1, 2048, 512), 'int4'), (1, 1), (0, 0, 0, 0), (1, 1), 'int32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=cuda, workload=('conv2d_HWNCnc_tensorcore.cuda', ('TENSOR', (14, 14, 8, 1024), 'uint4'), ('TENSOR', (1, 1, 2048, 1024), 'int4'), (2, 2), (0, 0, 0, 0), (1, 1), 'int32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=cuda, workload=('conv2d_HWNCnc_tensorcore.cuda', ('TENSOR', (7, 7, 8, 2048), 'uint4'), ('TENSOR', (1, 1, 512, 2048), 'int4'), (1, 1), (0, 0, 0, 0), (1, 1), 'int32'). A fallback configuration is used, which may bring great performance regression.
WARNING:autotvm:Cannot find config for target=cuda, workload=('dense_int8.cuda', ('TENSOR', (8, 2048), 'int8'), ('TENSOR', (1000, 2048), 'int8'), None, 'int32'). A fallback configuration is used, which may bring great performance regression.

Performed inference in 17.40ms (std = 0.23) for 8 samples
Average per sample inference time: 2.17ms
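Every one of these warnings names a workload that had no tuned configuration. To see at a glance which ones are untuned, the warnings can be parsed. This is a hypothetical helper, assuming each warning sits on a single line and prints the workload as a Python tuple literal exactly as in the log above; it is not part of TVM or HAWQ.

```python
import ast
import re
from collections import Counter

# Assumption: lines look like the autotvm fallback warnings in the log above.
_WARNING_RE = re.compile(
    r"Cannot find config for target=([^,]+), workload=(\(.*\))\. A fallback"
)


def missing_workloads(log_text: str) -> list:
    """Return one tuple per fallback warning:
    (target, task_name, data_shape, data_dtype, weight_shape, weight_dtype)."""
    found = []
    for line in log_text.splitlines():
        match = _WARNING_RE.search(line)
        if match is None:
            continue
        # The workload is a plain tuple literal, so literal_eval can parse it safely.
        workload = ast.literal_eval(match.group(2))
        task_name, data, weight = workload[0], workload[1], workload[2]
        found.append((match.group(1), task_name, data[1], data[2], weight[1], weight[2]))
    return found


def count_by_task(log_text: str) -> Counter:
    """Count untuned workloads per TVM task name."""
    return Counter(entry[1] for entry in missing_workloads(log_text))
```

Lines that are not fallback warnings (such as the timing summary) are simply skipped.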
zachzzc commented 3 years ago

The tuning logs for the convolution workloads are missing, so the convolutions are not running at optimal speed. You can use TVM's auto-tuning to search for the best convolution configurations. Since you are using the same hardware platform as I am, I have uploaded my tuning log to the repo. Just run the inference again and the tuning log will be applied automatically.

I have also uploaded an input image for testing under tvm_benchmark/models.
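Conceptually, a tuning log is a list of measured configurations, and applying it means choosing the fastest successful measurement for each workload. The sketch below performs that selection by hand. It assumes autotvm's JSON-lines record layout (an `input` of target, task name, arguments and keyword arguments, plus a `result` whose first two fields are the measured costs and an error code). It illustrates the idea behind autotvm's `apply_history_best`; it is not TVM's own implementation.

```python
import json
from statistics import mean


def best_configs(log_lines):
    """Map (target, task_name, args) to (mean_cost, config) of the fastest valid record.

    Assumed record layout (one JSON object per line):
      {"input": [target, task_name, args, kwargs],
       "config": {...},
       "result": [costs, error_no, all_cost, timestamp], ...}
    """
    best = {}
    for line in log_lines:
        line = line.strip()
        if not line:
            continue
        rec = json.loads(line)
        target, task_name, args, _kwargs = rec["input"]
        costs, error_no = rec["result"][0], rec["result"][1]
        if error_no != 0:  # skip measurements that failed
            continue
        # args include tensor shapes and dtypes, so they are part of the key.
        key = (target, task_name, json.dumps(args))
        cost = mean(costs)
        if key not in best or cost < best[key][0]:
            best[key] = (cost, rec["config"])
    return best
```

Because the tensor dtypes (for example `uint4` versus `int8`) are part of the workload key, a log tuned for one bit configuration does not by itself provide configurations for the workloads of another.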

haibao-yu commented 3 years ago

Thanks. I reran the inference with the tuning log; here are the results:

INT4 (without --manual-code): 1.10 ms
INT4 (with --manual-code): 0.75 ms
INT8 (without --manual-code): 1.01 ms
INT8 (with --manual-code): 1.02 ms

Compared with the best INT8 result (1.01 ms), INT4 with manual code (0.75 ms) has about 25% lower latency, i.e. a 1.35x speedup. Also, is there any relationship between the tuning logs and the mixed-precision configuration? Thanks
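The two figures in this comment describe the same comparison, best INT8 (1.01 ms) against INT4 with manual code (0.75 ms): a speedup of s corresponds to a latency reduction of 1 - 1/s. A quick check using the numbers reported above (helper names are illustrative):

```python
# Per-sample latencies (ms) reported after applying the tuning log.
# Key: (bit width, whether --manual-code was used).
latency_ms = {
    ("INT4", False): 1.10,
    ("INT4", True): 0.75,
    ("INT8", False): 1.01,
    ("INT8", True): 1.02,
}


def speedup(baseline_ms: float, candidate_ms: float) -> float:
    """How many times faster the candidate is than the baseline."""
    return baseline_ms / candidate_ms


def latency_reduction(baseline_ms: float, candidate_ms: float) -> float:
    """Fraction of the baseline latency saved by the candidate."""
    return 1.0 - candidate_ms / baseline_ms


best_int8 = min(v for (bits, _), v in latency_ms.items() if bits == "INT8")
best_int4 = min(v for (bits, _), v in latency_ms.items() if bits == "INT4")
print(f"speedup: {speedup(best_int8, best_int4):.2f}x")
print(f"latency reduction: {latency_reduction(best_int8, best_int4):.1%}")
```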

haibao-yu commented 3 years ago

Sorry, I had overlooked auto-tuning with the --tuning-enable option. I'll try it.

haibao-yu commented 3 years ago

Hi, with --tuning-enable and --manual-code together, there is a new error:

[Task 16/21]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (50/50) | 132.35 s Done.
Task(func_name=conv2d_HWNCnc_tensorcore.cuda, args=(('TENSOR', (14, 14, 8, 1024), 'int8'), ('TENSOR', (1, 1, 512, 1024), 'int8'), (2, 2), (0, 0, 0, 0), (1, 1), 'int32'), kwargs={}, workload=('conv2d_HWNCnc_tensorcore.cuda', ('TENSOR', (14, 14, 8, 1024), 'int8'), ('TENSOR', (1, 1, 512, 1024), 'int8'), (2, 2), (0, 0, 0, 0), (1, 1), 'int32'))
[Task 17/21]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (50/50) | 132.39 s Done.
Task(func_name=conv2d_HWNCnc_tensorcore.cuda, args=(('TENSOR', (7, 7, 8, 2048), 'int8'), ('TENSOR', (1, 1, 512, 2048), 'int8'), (1, 1), (0, 0, 0, 0), (1, 1), 'int32'), kwargs={}, workload=('conv2d_HWNCnc_tensorcore.cuda', ('TENSOR', (7, 7, 8, 2048), 'int8'), ('TENSOR', (1, 1, 512, 2048), 'int8'), (1, 1), (0, 0, 0, 0), (1, 1), 'int32'))
[Task 18/21]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (50/50) | 122.41 s Done.
Task(func_name=conv2d_HWNCnc_tensorcore.cuda, args=(('TENSOR', (9, 9, 8, 512), 'int8'), ('TENSOR', (3, 3, 512, 512), 'int8'), (1, 1), (0, 0, 0, 0), (1, 1), 'int32'), kwargs={}, workload=('conv2d_HWNCnc_tensorcore.cuda', ('TENSOR', (9, 9, 8, 512), 'int8'), ('TENSOR', (3, 3, 512, 512), 'int8'), (1, 1), (0, 0, 0, 0), (1, 1), 'int32'))
[Task 19/21]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (50/50) | 132.41 s Done.
Task(func_name=conv2d_HWNCnc_tensorcore.cuda, args=(('TENSOR', (7, 7, 8, 512), 'int8'), ('TENSOR', (1, 1, 2048, 512), 'int8'), (1, 1), (0, 0, 0, 0), (1, 1), 'int32'), kwargs={}, workload=('conv2d_HWNCnc_tensorcore.cuda', ('TENSOR', (7, 7, 8, 512), 'int8'), ('TENSOR', (1, 1, 2048, 512), 'int8'), (1, 1), (0, 0, 0, 0), (1, 1), 'int32'))
[Task 20/21]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (50/50) | 132.35 s Done.
Task(func_name=dense_int8.cuda, args=(('TENSOR', (8, 2048), 'int8'), ('TENSOR', (1000, 2048), 'int8'), None, 'int32'), kwargs={}, workload=('dense_int8.cuda', ('TENSOR', (8, 2048), 'int8'), ('TENSOR', (1000, 2048), 'int8'), None, 'int32'))
[Task 21/21]  Current/Best:    0.00/   0.00 GFLOPS | Progress: (50/50) | 132.41 s Done.
Apply tuning log ./mixed_precision_models/tuning_logs/resnet50_HWNC_int4_batch_8.log
Traceback (most recent call last):

  File "test_resnet_inference_time.py", line 259, in <module>
    module.run()

  File "/home/yuhaibao94/tvm/python/tvm/contrib/graph_runtime.py", line 176, in run
    self._run()

  File "/home/yuhaibao94/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 219, in __call__
    raise get_last_ffi_error()

tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (3) /home/yuhaibao94/tvm/build/libtvm.so(TVMFuncCall+0x61) [0x7f78210a1601]
  [bt] (2) /home/yuhaibao94/tvm/build/libtvm.so(std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::detail::PackFuncVoidAddr_<4, tvm::runtime::CUDAWrappedFunc>(tvm::runtime::CUDAWrappedFunc, std::vector<tvm::runtime::detail::ArgConvertCode, std::allocator<tvm::runtime::detail::ArgConvertCode> > const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)+0xbc) [0x7f782110b2cc]
  [bt] (1) /home/yuhaibao94/tvm/build/libtvm.so(tvm::runtime::CUDAWrappedFunc::operator()(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*, void**) const+0x8e7) [0x7f782110b037]
  [bt] (0) /home/yuhaibao94/tvm/build/libtvm.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x32) [0x7f78206faf52]
  File "/home/yuhaibao94/tvm/src/runtime/cuda/cuda_module.cc", line 117
  File "/home/yuhaibao94/tvm/src/runtime/library_module.cc", line 89
CUDAError: Check failed: ret == 0 (-1 vs. 0) : cuModuleGetFunction fused_nn_conv2d_add_nn_relu_cast_cast_left_shift_multiply_add_right_shift_cast_c_3441552496575213188__11_kernel0 failed with error: CUDA_ERROR_NOT_FOUND

terminate called after throwing an instance of 'dmlc::Error'
  what():  [11:13:02] /home/yuhaibao94/tvm/src/runtime/workspace_pool.cc:115: Check failed: allocated_.size() == 1 (2 vs. 1) : 
Stack trace:
  [bt] (0) /home/yuhaibao94/tvm/build/libtvm.so(tvm::runtime::WorkspacePool::Pool::Release(DLContext, tvm::runtime::DeviceAPI*)+0x652) [0x7f7821084302]
  [bt] (1) /home/yuhaibao94/tvm/build/libtvm.so(tvm::runtime::WorkspacePool::~WorkspacePool()+0x3f) [0x7f782108298f]
  [bt] (2) /home/yuhaibao94/anaconda3/bin/../lib/libstdc++.so.6(+0x9b64a) [0x7f781867d64a]
  [bt] (3) /lib/x86_64-linux-gnu/libc.so.6(+0x3a008) [0x7f7832ad1008]
  [bt] (4) /lib/x86_64-linux-gnu/libc.so.6(+0x3a055) [0x7f7832ad1055]
  [bt] (5) /lib/x86_64-linux-gnu/libc.so.6(__libc_start_main+0xf7) [0x7f7832ab7847]
  [bt] (6) python(+0x1e3e32) [0x563cd77e4e32]

./run_resnet_inference_time_int8_int4.sh: line 4: 21045 Aborted                 (core dumped) python test_resnet_inference_time.py --bit-config $bit_config --num-layers $num_layers --manual-code --tuning-enable

Have you ever encountered this problem?
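CUDA_ERROR_NOT_FOUND from cuModuleGetFunction means the loaded CUDA module does not define a kernel with the requested name. One plausible cause here (an assumption, not confirmed in this thread) is that the kernels compiled during the tuning run no longer match the names in the manually edited resnet_manual.cu. A hypothetical helper to list the kernels defined in a .cu file and report which expected names are absent:

```python
import re

# Matches kernel definitions such as
#   extern "C" __global__ void fused_..._kernel0(...)
# including the variant with __launch_bounds__(N) before the name.
# Assumption: the file follows TVM's generated CUDA layout.
_KERNEL_RE = re.compile(r"__global__\s+void\s+(?:__launch_bounds__\(\d+\)\s+)?(\w+)\s*\(")


def kernel_names(cu_source: str) -> set:
    """Return the names of all __global__ kernels defined in a CUDA source string."""
    return set(_KERNEL_RE.findall(cu_source))


def missing_kernels(cu_source: str, expected: list) -> list:
    """Return the expected kernel names that the CUDA source does not define."""
    defined = kernel_names(cu_source)
    return [name for name in expected if name not in defined]
```

For example, pass the contents of ./debug_output/resnet_manual.cu together with the kernel name from the error message; a non-empty result means that kernel is not present in the file.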

zachzzc commented 3 years ago

--manual-code: this option runs inference with a modified CUDA file. TVM generates CUDA code after compilation, but we saw performance drawbacks in TVM's cast computation, so I manually modified the generated CUDA files and used them for inference. Don't specify --tuning-enable and --manual-code together. If you want to tune, specify --tuning-enable only and increase --tuning-trials to about 3000. The tuning log that I uploaded should be good enough since we run on the same GPUs.
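The rule "don't combine --tuning-enable with --manual-code" can be enforced at argument-parsing time. This is a hypothetical parser that only reuses the flag names mentioned in this thread; the real test_resnet_inference_time.py may define them differently, and the --tuning-trials default of 3000 is just the value suggested above, not the script's actual default.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser that rejects --tuning-enable together with --manual-code."""
    parser = argparse.ArgumentParser(description="ResNet inference timing (sketch)")
    parser.add_argument("--bit-config", required=True)
    parser.add_argument("--num-layers", type=int, required=True)
    # Assumed default: the trial count suggested in this thread.
    parser.add_argument("--tuning-trials", type=int, default=3000)
    mode = parser.add_mutually_exclusive_group()
    mode.add_argument("--manual-code", action="store_true",
                      help="run with the manually patched CUDA file")
    mode.add_argument("--tuning-enable", action="store_true",
                      help="auto-tune the workloads before running")
    return parser
```

Passing both flags makes argparse print an error and exit with status 2, instead of failing later at runtime.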

haibao-yu commented 3 years ago

Many thanks. What inference time and speedup do you get with INT4 compared with INT8? Here, compared with INT8, INT4 gives about a 25% latency reduction, i.e. a 1.35x speedup.

haibao-yu commented 3 years ago

To increase --tuning-trials to about 3000 without running out of memory, we should modify "tvm/python/tvm/autotvm/tuner/xgboost_cost_model.py" as follows:

    try:
        # Use xgboost.training.aggcv if this xgboost version provides it,
        # otherwise fall back to the private xgboost.callback._aggcv helper.
        from xgboost.training import aggcv
    except ImportError:
        from xgboost.callback import _aggcv as aggcv

More details can be found at https://zhuanlan.zhihu.com/p/340557261

zachzzc commented 3 years ago

We got a 1.45x speedup. Detailed results can be found at https://github.com/Zhen-Dong/HAWQ/blob/main/model_zoo.md

haibao-yu commented 3 years ago

Thanks. I also ran INT8 inference with TensorRT, which takes 1.4 ms, so I think TVM's inference speed is pretty good.
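To put "pretty good" in numbers, the TensorRT figure can be compared with the TVM results reported earlier in this thread. This assumes the 1.4 ms TensorRT number was measured the same way (per sample, batch size 8) as the TVM numbers, which the thread does not state explicitly.

```python
# Per-sample latencies (ms) quoted in this thread.
# Assumption: all three were measured per sample at batch size 8.
tensorrt_int8_ms = 1.4
tvm_int8_ms = 1.01  # TVM INT8 with the uploaded tuning log
tvm_int4_ms = 0.75  # TVM INT4 with the tuning log and --manual-code


def times_faster(baseline_ms: float, candidate_ms: float) -> float:
    """Ratio > 1 means the candidate is faster than the baseline."""
    return baseline_ms / candidate_ms


print(f"TVM INT8 vs TensorRT INT8: {times_faster(tensorrt_int8_ms, tvm_int8_ms):.2f}x")
print(f"TVM INT4 vs TensorRT INT8: {times_faster(tensorrt_int8_ms, tvm_int4_ms):.2f}x")
```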