pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

[RFC] torchvision performance optimization on CPU #6619

Open mingfeima opened 2 years ago

mingfeima commented 2 years ago

🚀 The feature

This RFC targets improving the performance of torchvision operators on CPU.

Motivation, pitch

Generally, performance improvements can be made in three ways:

The plan is to cover both inference and training optimizations at the same time.

Affected Operators

The optimization scope will cover the native kernels from csrc/ops/cpu, including:

These operators are used by models such as FasterRCNN, MaskRCNN, etc.

[Discussion Needed]: need to sort out the priorities of these kernels.

API and Behavior Change

Since all the optimizations will be done on the kernel level, no API change will be required.

Users will be able to run models in channels last, as recommended in the memory_format_tutorial:

### convert input and model from NCHW to NHWC
input = input.to(memory_format=torch.channels_last)
model = model.to(memory_format=torch.channels_last)
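A minimal, self-contained sketch of the conversion above, assuming a small hypothetical two-layer conv model in place of a real torchvision detector. It checks that `Module.to(memory_format=...)` converts the 4D conv weights and that the converted input keeps its logical NCHW shape while its strides become NHWC.

```python
import torch
import torch.nn as nn

# hypothetical stand-in for a real vision model
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 8, kernel_size=3, padding=1),
).eval()
input = torch.randn(2, 3, 32, 32)

# convert input and model from NCHW to NHWC
input = input.to(memory_format=torch.channels_last)
model = model.to(memory_format=torch.channels_last)

# the logical shape is unchanged; only the strides differ
print(input.shape)     # torch.Size([2, 3, 32, 32])
print(input.stride())  # (3072, 1, 96, 3)
print(model[0].weight.is_contiguous(memory_format=torch.channels_last))  # True

with torch.no_grad():
    output = model(input)
print(output.shape)    # torch.Size([2, 8, 32, 32])
```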

To run the model in bfloat16, either with explicit data type conversion or with AMP:

### explicit data type conversion
input = input.to(dtype=torch.bfloat16)
model = model.to(dtype=torch.bfloat16)

### with AMP
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    output = model(input)
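The two bfloat16 options differ in where the casts happen. A hedged sketch, using a hypothetical `nn.Linear` toy model instead of a vision model only to keep it small: explicit conversion stores the parameters in bfloat16 and runs every op in bfloat16, while AMP keeps float32 parameters and lets autocast cast eligible ops (such as linear and conv2d) to bfloat16 on the fly.

```python
import copy

import torch
import torch.nn as nn

# hypothetical toy model; nn.Linear stands in for a vision model
model = nn.Linear(16, 4).eval()
input = torch.randn(2, 16)

# explicit data type conversion: bfloat16 copy of the parameters, bfloat16 input
bf16_model = copy.deepcopy(model).to(dtype=torch.bfloat16)
with torch.no_grad():
    out_explicit = bf16_model(input.to(dtype=torch.bfloat16))

# AMP: parameters stay float32, autocast-eligible ops run in bfloat16
with torch.no_grad(), torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out_amp = model(input)

print(bf16_model.weight.dtype, out_explicit.dtype)  # torch.bfloat16 torch.bfloat16
print(model.weight.dtype, out_amp.dtype)            # torch.float32 torch.bfloat16
```

Explicit conversion halves parameter memory but applies bfloat16 everywhere; AMP keeps the float32 weights as the source of truth and leaves the per-op precision choice to autocast.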

Non-Batch Mode Input

Some models take input in non-batch mode, e.g. CHW (N = 1), which cannot currently be converted to channels last in torch:

### when input is a 3-dimensional tensor, the following line raises a runtime error:
input = input.to(memory_format=torch.channels_last)

torch.nn.Conv2d checks the memory format of both the input and the weight: if either of them is channels last, the convolution uses the channels last path. Therefore, for non-batch mode input, we can convert only the model and the channels last path will still be used.
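A small sketch of the non-batch case, assuming a PyTorch version in which `nn.Conv2d` accepts unbatched (C, H, W) input (recent releases do) and a hypothetical single-conv model. It shows that converting the CHW input fails, while converting only the model still runs and matches an unconverted NCHW copy numerically. Whether the channels last kernel is actually selected is an internal dispatch detail and is not checked here.

```python
import copy

import torch
import torch.nn as nn

model = nn.Conv2d(3, 8, kernel_size=3, padding=1).eval()  # hypothetical model
input = torch.randn(3, 32, 32)  # non-batch mode input (CHW)

# channels_last needs a rank-4 tensor, so converting the CHW input fails
try:
    input.to(memory_format=torch.channels_last)
    input_conversion_failed = False
except RuntimeError as err:
    input_conversion_failed = True
    print("RuntimeError:", err)

# keep an NCHW copy of the model as a numerical reference
ref_model = copy.deepcopy(model)

# convert only the model; the CHW input is passed through unchanged
model = model.to(memory_format=torch.channels_last)
with torch.no_grad():
    out = model(input)
    ref = ref_model(input)

print(out.shape)  # torch.Size([8, 32, 32])
print(torch.allclose(out, ref, atol=1e-4))
```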

This part requires special attention and validation effort.

Parallelization on Multi Core CPUs

We propose following the same parallelization scheme as torch, i.e. using the at::parallel_for wrapper, which can be backed by OpenMP or TBB depending on the build option (OpenMP by default).
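The way the OpenMP backend of at::parallel_for splits a range can be modelled in a few lines of Python. This is a simplified sketch, not ATen's actual code, and the helper names are hypothetical: ranges no larger than grain_size run serially in the caller; otherwise the thread count is capped at the number of grain_size-sized pieces and each thread gets at most one contiguous chunk of ceil(n / num_threads) iterations.

```python
from concurrent.futures import ThreadPoolExecutor


def divup(x, y):
    return (x + y - 1) // y


def parallel_for(begin, end, grain_size, f, num_threads=4):
    """Simplified model of at::parallel_for's chunking (hypothetical helper)."""
    if begin >= end:
        return
    n = end - begin
    # small ranges (or a single thread) run serially in the calling thread
    if n <= grain_size or n == 1 or num_threads == 1:
        f(begin, end)
        return
    # never use more threads than there are grain_size-sized pieces
    if grain_size > 0:
        num_threads = min(num_threads, divup(n, grain_size))
    chunk = divup(n, num_threads)
    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        futures = [pool.submit(f, b, min(end, b + chunk)) for b in range(begin, end, chunk)]
        for fut in futures:
            fut.result()  # re-raise any exception from a worker


# usage: each roi index must be handled by exactly one chunk
n_rois = 10
chunks = []
visited = [0] * n_rois


def roi_kernel(b, e):
    chunks.append((b, e))
    for n in range(b, e):
        visited[n] += 1


parallel_for(0, n_rois, 1, roi_kernel, num_threads=4)
print(sorted(chunks))  # [(0, 3), (3, 6), (6, 9), (9, 10)]
```

Because rois are independent, splitting over n_rois needs no synchronization inside the kernel: each thread writes to a disjoint slice of the output.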

This commit is an example of parallelizing roi_align over the first dimension of the rois input, i.e. n_rois, with the help of at::parallel_for:

  // split the n_rois loop into chunks; each chunk runs on one thread
  at::parallel_for(0, n_rois, 1, [&](int64_t begin, int64_t end) {
    for (int64_t n = begin; n < end; n++) {
      int64_t index_n = n * channels * pooled_width * pooled_height;

      // each roi row is (batch_index, x1, y1, x2, y2)
      const T* offset_rois = rois + n * 5;
      int roi_batch_ind = offset_rois[0];

      /* rest of the function is identical to original kernel */
    }
  });

Vectorization on x86 CPUs

Vectorization can be done in multiple ways, namely:

Auto Vectorization

Let the compiler vectorize automatically with #pragma omp simd. This commit adds channels last support for roi_align and vectorizes over the last dimension, i.e. channels:

  for (int iy = 0; iy < roi_bin_grid_h; iy++) {
    for (int ix = 0; ix < roi_bin_grid_w; ix++) {
      detail::PreCalc<T> pc = pre_calc[pre_calc_index];
      const T* in1 = input + pc.pos1 * channels;
      const T* in2 = input + pc.pos2 * channels;
      const T* in3 = input + pc.pos3 * channels;
      const T* in4 = input + pc.pos4 * channels;

      #pragma omp simd
      for (int c = 0; c < channels; c++) {
        out[c] += pc.w1 * in1[c] + pc.w2 * in2[c] + pc.w3 * in3[c] + pc.w4 * in4[c];
      }
      pre_calc_index += 1;
    }
  }

Note that with NCHW this kernel cannot be vectorized this way, since the channel values of a sampling point are not contiguous in memory.
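To make the memory-layout argument concrete, the Python sketch below (an illustration with torch tensor ops, not the C++ kernel) compares NCHW and NHWC strides, then computes one bilinear sample for all channels at once on the NHWC layout and checks it against a per-channel scalar loop on NCHW. The sample point and tensor sizes are arbitrary assumptions; the weight/neighbour ordering follows the pos1..pos4 and w1..w4 layout used above.

```python
import torch

N, C, H, W = 2, 3, 4, 5
inp = torch.randn(N, C, H, W)
inp_cl = inp.to(memory_format=torch.channels_last)

print(inp.stride())     # (60, 20, 5, 1): channels of one pixel are H*W apart
print(inp_cl.stride())  # (60, 1, 15, 3): channels of one pixel are adjacent

# Hypothetical sample point (y, x) = (1.25, 2.5) in image 0
y_lo, x_lo, y_hi, x_hi = 1, 2, 2, 3
ly, lx = 0.25, 0.5
hy, hx = 1.0 - ly, 1.0 - lx
w1, w2, w3, w4 = hy * hx, hy * lx, ly * hx, ly * lx

# NHWC view of image 0 has shape (H, W, C) and is contiguous, so each
# img[y, x] is a contiguous run of C values: one vector op covers all channels.
img = inp_cl[0].permute(1, 2, 0)
out_vec = w1 * img[y_lo, x_lo] + w2 * img[y_lo, x_hi] + w3 * img[y_hi, x_lo] + w4 * img[y_hi, x_hi]

# NCHW reference: a separate scalar computation per channel
out_ref = torch.empty(C)
for c in range(C):
    p = inp[0, c]
    out_ref[c] = w1 * p[y_lo, x_lo] + w2 * p[y_lo, x_hi] + w3 * p[y_hi, x_lo] + w4 * p[y_hi, x_hi]

print(img.is_contiguous())               # True
print(torch.allclose(out_vec, out_ref))  # True
```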

Manual Vectorization

Vectorize the code explicitly with the at::vec::Vectorized<> struct, which compiles to different instructions depending on the target architecture (AVX2, AVX512 or NEON).

  using Vec = at::vec::Vectorized<T>;
  for (int iy = 0; iy < roi_bin_grid_h; iy++) {
    for (int ix = 0; ix < roi_bin_grid_w; ix++) {
      detail::PreCalc<T> pc = pre_calc[pre_calc_index];
      const T* in1 = input + pc.pos1 * channels;
      const T* in2 = input + pc.pos2 * channels;
      const T* in3 = input + pc.pos3 * channels;
      const T* in4 = input + pc.pos4 * channels;

      int64_t d = 0;
      // main loop: Vec::size() channels per iteration, accumulating into out
      for (; d < channels - (channels % Vec::size()); d += Vec::size()) {
        Vec out_vec =
            Vec::loadu(out + d) +
            Vec(pc.w1) * Vec::loadu(in1 + d) +
            Vec(pc.w2) * Vec::loadu(in2 + d) +
            Vec(pc.w3) * Vec::loadu(in3 + d) +
            Vec(pc.w4) * Vec::loadu(in4 + d);
        out_vec.store(out + d);
      }
      // remainder: fewer than Vec::size() channels left, use scalar code
      for (; d < channels; d++) {
        out[d] += pc.w1 * in1[d] + pc.w2 * in2[d] + pc.w3 * in3[d] + pc.w4 * in4[d];
      }
      pre_calc_index += 1;
    }
  }

From a performance point of view, the two approaches should give similar results.

[Discussion Needed]: need to decide which way to go.

Experiment Results

A demo shows the performance improvement from channels last support on the fast_rcnn_R_50_FPN_1x model from detectron2:

export DETECTRON2_DATASETS=../datasets
python benchmark.py --config-file ../configs/COCO-Detection/fast_rcnn_R_50_FPN_1x.yaml --task eval

torch: 1.13.0a0 torchvision: 0.14.0a0 detectron2: 0.6 cpu: Intel(R) Xeon(R) Gold 6248 CPU @ 2.50GHz

| Time of 300 iters (s) | NCHW (before) | NCHW (after) | NHWC (after) | SpeedUp (NCHW before / NHWC after) |
|---|---|---|---|---|
| single core (C=1) | 638.21 | 639.01 | 503.04 | 126.87% |
| single socket (C=20) | 212.10 | 141.06 | 102.54 | 206.84% |
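The SpeedUp column matches NCHW (before) divided by NHWC (after) to within rounding (638.21 / 503.04 ≈ 1.2687, 212.10 / 102.54 ≈ 2.0685). The short script below recomputes it from the table and splits it into two ratios: NCHW (before) over NCHW (after), and NCHW (after) over NHWC (after). On a single socket these come out at roughly 1.50x and 1.38x; on a single core the NCHW ratio is about 1.00x.

```python
# Time of 300 iterations in seconds, copied from the table above:
# (NCHW before, NCHW after, NHWC after)
times = {
    "single core (C=1)": (638.21, 639.01, 503.04),
    "single socket (C=20)": (212.10, 141.06, 102.54),
}


def ratios(nchw_before, nchw_after, nhwc_after):
    """Return (total speedup, NCHW before/after, NCHW after / NHWC after)."""
    return (
        nchw_before / nhwc_after,  # what the SpeedUp column reports
        nchw_before / nchw_after,  # gain of NCHW (after) over NCHW (before)
        nchw_after / nhwc_after,   # additional gain of NHWC over NCHW (after)
    )


for setting, t in times.items():
    total, nchw_gain, nhwc_gain = ratios(*t)
    print(f"{setting}: SpeedUp {total:.2%}, NCHW {nchw_gain:.2f}x, NHWC over NCHW {nhwc_gain:.2f}x")
```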

Breakdown

Here is the performance breakdown of NCHW (before) vs. NHWC (after):

We can see that the performance improvement primarily comes from:

Additional

[Discussion Needed]: need to decide details of performance benchmarking, such as:

[Discussion Needed]: test cases: we will add new test cases in corresponding modules from vision/test when making pull requests, what else is needed?
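As one possible starting point for the benchmarking question, here is a hedged harness sketch based on torch.utils.benchmark.Timer, whose num_threads argument pins the intra-op thread count (the pool at::parallel_for draws from). It covers the two axes in the experiment table, one thread vs. all threads and NCHW vs. NHWC. The model and input are hypothetical stand-ins; a real run would use the detection model under test.

```python
import copy

import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark

# hypothetical stand-in model and input
model = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.Conv2d(32, 32, 3, padding=1)
).eval()
x = torch.randn(1, 3, 64, 64)


def make_timer(memory_format, num_threads):
    m = copy.deepcopy(model).to(memory_format=memory_format)
    inp = x.contiguous(memory_format=memory_format)
    return benchmark.Timer(
        stmt="with torch.no_grad(): m(inp)",
        globals={"torch": torch, "m": m, "inp": inp},
        num_threads=num_threads,  # intra-op threads used while timing
        label="conv stack",
        sub_label=str(memory_format),
        description=f"{num_threads} thread(s)",
    )


results = [
    make_timer(fmt, n).blocked_autorange(min_run_time=0.2)
    for fmt in (torch.contiguous_format, torch.channels_last)
    for n in (1, torch.get_num_threads())
]
benchmark.Compare(results).print()
```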

NicolasHug commented 2 years ago

Thanks a lot @mingfeima for this very well-put proposal. The benchmarks look promising!

Looking at the targeted operators, we typically use these in the model training stage on GPUs. Thus I assume that the main use-case for optimizing them on CPU would be for CPU inference? Would you have concrete examples where this is applicable?

As a side note, since we're talking about vectorization: I might start taking a look into making our Resize() / interpolate() transform faster (on tensors). Comparing ours with Pillow-SIMD, we're observing major improvements from vectorization. If this is something that can be of interest to you, I'm more than happy to chat more!

vfdev-5 commented 2 years ago

@NicolasHug FYI, interpolation is already vectorized for 2d case by mingfeima : https://github.com/pytorch/pytorch/blob/bd854588fb927371c319d24d31b659731eddc3bc/aten/src/ATen/native/cpu/UpSampleKernel.cpp#L442-L602

However, we can benefit from the vectorization (according to the current implementation) only for inputs with >=4 channels (@mingfeima please correct me if I'm wrong):

https://github.com/pytorch/pytorch/blob/bd854588fb927371c319d24d31b659731eddc3bc/aten/src/ATen/native/cpu/UpSampleKernel.cpp#L509-L514

IMO, the main need in resize optimization is native support for uint8, without copying data to float and back.
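For reference, the copy-based path that native uint8 support would avoid can be sketched as follows. The helper name resize_uint8_via_float, the bilinear mode, align_corners=False and the round-then-clamp step are assumptions for illustration, and no antialiasing is applied.

```python
import torch
import torch.nn.functional as F


def resize_uint8_via_float(img, size):
    """Resize a uint8 (N, C, H, W) tensor by round-tripping through float32.

    Hypothetical helper: one uint8 -> float copy in, one float -> uint8 copy out.
    """
    out = F.interpolate(img.float(), size=size, mode="bilinear", align_corners=False)
    # round to nearest and clamp into the uint8 range before the final copy
    return out.round().clamp_(0, 255).to(torch.uint8)


img = torch.randint(0, 256, (1, 3, 64, 64), dtype=torch.uint8)
img = img.contiguous(memory_format=torch.channels_last)
small = resize_uint8_via_float(img, (32, 48))
print(small.dtype, small.shape)  # torch.uint8 torch.Size([1, 3, 32, 48])
```

Both extra copies touch the full image, which is the overhead a kernel that reads uint8 directly and accumulates in float32 registers would remove.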

mingfeima commented 2 years ago

@NicolasHug First of all, yes, our priority is inference. The model most requested by our customers is MaskRCNN and its variants, so from this point of view the key bottleneck operator is the RoIAlign forward path.

That said, we would certainly like to hear more input from you on which other models/operators are of interest, so that we can sort out the priorities among the TODOs.

Meanwhile, we would also like to contribute to the backward passes (this is driven more by our internal KPIs than by business requirements).

@vfdev-5 Regarding resize/interpolate, the first factor is the memory format: usually we can only vectorize on NHWC (NCHW can be vectorized in some specific cases, such as scale=2, but in general NCHW uses scalar logic).

Second, as you pointed out, the code is vectorized only when C > Vec::size(), where Vec::size() is 8 for float under AVX2, 16 under AVX512, and so on. This is because the current implementation handles the remainder with a memcpy (instead of a masked load), which is not very efficient. Interpolation on uint8 should be done in the accumulation type (float32), but this doesn't mean it has to be slow: we can do the dtype conversion in place and vectorize the whole process.
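The loop bound channels - channels % Vec::size() from the manual vectorization snippet decides how much of the channel dimension actually runs as vector code. A small Python sketch of that arithmetic, using the float lane counts quoted here (8 with AVX2, 16 with AVX512); the helper name is hypothetical:

```python
def split_channels(channels, vec_size):
    """Return (full vector iterations, scalar tail elements) for one pixel (hypothetical helper)."""
    body_end = channels - channels % vec_size
    return body_end // vec_size, channels - body_end


# Vec::size() for float: 8 lanes with AVX2 (256-bit), 16 with AVX512 (512-bit)
for isa, vec_size in [("avx2", 8), ("avx512", 16)]:
    for channels in (3, 64, 100):
        n_vec, n_tail = split_channels(channels, vec_size)
        print(f"{isa}: C={channels} -> {n_vec} vector iterations, {n_tail} tail elements")
```

For an RGB image (C = 3) there is no full vector at all, so the whole channel loop falls into the remainder path, which is consistent with the observation that the vectorized path helps little for 3-channel inputs.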

Anyway, do you have a minimal example/benchmark to reproduce the resize performance? I can give it a try and see how to improve it.

vfdev-5 commented 2 years ago

@mingfeima thanks for your answer about resize. Maybe we can continue the discussion in another issue related to interpolation. There are a few of them, e.g. https://github.com/pytorch/vision/issues/6465 (the image is read in 3D HWC format, but once unsqueezed it was not recognized as 1HWC channels last, so resize went through the channels first fallback, which is very slow).

As for NCHW, I agree with what you say. In our previous approach we relied on implicit compiler vectorization of recurrent ops like out += w * src and a few others.

Anyway, here is a gist that produces a PyTorch vs. PIL benchmark: https://gist.github.com/vfdev-5/7885ee41d31789cd159dcd52b2e8fc6a

We would like to optimize cases like:

mingfeima commented 2 years ago

@NicolasHug @vfdev-5 Oh sorry for the late response, super busy recently, just got time to take a look at this last weekend ...

I opened https://github.com/pytorch/pytorch/pull/87053 to address mode=bilinear (3, H, W) on float. Shall we move the discussion there?

NicolasHug commented 2 years ago

Thanks @mingfeima , I'll take a look

I just want to note that this part has been addressed in https://github.com/pytorch/pytorch/pull/86361, so there's no need to focus on it anymore

mode=nearest for (1, H, W) uint8, where IMO there is a "bug": the implementation goes through your channels last route, which is slow, whereas the channels first implementation could be faster.

zhiqwang commented 2 years ago

Hopefully there will be support for uint8 type input and an accelerated version of it for interpolate() as mentioned in https://github.com/pytorch/pytorch/pull/86361#issuecomment-1269822386 and https://github.com/pytorch/pytorch/issues/5580 .

mingfeima commented 2 years ago

To sum up the status a little:

NicolasHug commented 2 years ago

Just FYI, I started working on support for uint8, mode=bilinear, antialias=True, channels_last, shape=(1,3,H,W) in https://github.com/pytorch/pytorch/pull/87863

FrancescoSaverioZuppichini commented 1 year ago

Hi, any update on this?

mingfeima commented 1 year ago

NicolasHug and vfdev-5 have done a lot of work optimizing int8/uint8 image scaling/resize in torch.