microsoft / CNTK

Microsoft Cognitive Toolkit (CNTK), an open source deep-learning toolkit
https://docs.microsoft.com/cognitive-toolkit/

CNTK 2.4: Why is the convolution operation so much slower? #2966

Open MatthiasRock opened 6 years ago

MatthiasRock commented 6 years ago

Hi,

I've noticed that my BrainScript network trains much slower with version 2.4 than with version 2.3. In CNTK 2.3 it trains at 25.4 samples/s, but in CNTK 2.4 at only 11.7 samples/s. With the Python API it is slow in both versions.

I found that the problem is caused by the convolutions with a constant 45x45 kernel. When I remove these convolutions, training is even faster with version 2.4 (37.5 samples/s) than with 2.3 (35.6 samples/s).

So why is the convolution operation so much faster in BrainScript with CNTK 2.3 than in the other cases?
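Putting the reported throughput figures side by side makes the size of the regression concrete. A minimal sketch in Python that uses only the numbers quoted above (the helper name `speedup` is illustrative):

```python
# Throughput figures (samples/s) reported in this issue for the BrainScript network.
with_conv = {"2.3": 25.4, "2.4": 11.7}
without_conv = {"2.3": 35.6, "2.4": 37.5}

def speedup(new: float, old: float) -> float:
    """Ratio > 1 means the new version is faster, < 1 means it is slower."""
    return new / old

with_ratio = speedup(with_conv["2.4"], with_conv["2.3"])
without_ratio = speedup(without_conv["2.4"], without_conv["2.3"])

print(f"with 45x45 convolutions:    2.4 runs at {with_ratio:.2f}x the 2.3 speed")
print(f"without those convolutions: 2.4 runs at {without_ratio:.2f}x the 2.3 speed")
```

With the convolutions, 2.4 is therefore roughly 2.2x slower. Without them, it is about 5% faster.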

ke1337 commented 6 years ago

Are you using GPU or CPU? Can you provide more details?

ke1337 commented 6 years ago

BTW, below is my benchmarking script for the scenario you described. I tried it on both GPU (Titan XP) and CPU (Intel Xeon E5-2690) and didn't find a big perf difference:

```python
import cntk as C
import numpy as np
import time
import tqdm

iter = 1000
# Uncomment to benchmark on CPU instead (with far fewer iterations):
#C.try_set_default_device(C.cpu()); iter = 10

# A single 45x45 convolution on a 3x224x224 input, reduced to a scalar.
x = C.input_variable((3,224,224))
y = C.reduce_sum(C.reduce_sum(C.layers.Convolution((45,45))(x)), C.Axis.default_batch_axis())
data = C.Value.create(x, np.random.rand(1,3,224,224).astype(np.float32))

# warm up before timing using 1/10th iter
for i in range(iter // 10):
    y.eval(data)

start = time.time()
acc = 0
for i in tqdm.tqdm(range(iter)):
    acc += y.eval(data)
print(time.time() - start)
```

MatthiasRock commented 6 years ago

When I run your script on my GeForce GTX 1080 Ti with 10000 iterations:

On my CPU (i7-7700K) there is no performance difference between the CNTK versions.

MatthiasRock commented 6 years ago

I don't know if the following information helps: benchmark

ke1337 commented 6 years ago

You are right. I made a mistake in my test settings. There does seem to be a big perf regression between 2.3 and 2.4, and it appears to be related to a different choice of convolution algorithm. Here are the top 5 entries from nvprof in 2.3 for the script above:

```
Time(%)      Time     Calls       Avg       Min       Max  Name
 48.25%  2.61736s     11000  237.94us  228.30us  315.34us  void fermiPlusCgemmLDS128_batched<bool=1, bool=0, bool=0, bool=0, int=4, int=4, int=4, int=3, int=3, bool=1, bool=0>(float2**, float2**, float2**, float2*, float2 const *, float2 const *, int, int, int, int, int, int, __int64, __int64, __int64, float2 const *, float2 const *, float2, float2, int)
  9.08%  492.24ms     11000  44.749us  44.162us  57.859us  compute_gemm_pointers(float2**, float2 const *, int, float2 const *, int, float2 const *, int, int)
  8.37%  454.19ms     33000  13.763us  7.8400us  23.105us  void transpose_readWrite_alignment_kernel<float2, float2, int=1, bool=0, int=6, int=4, int=4>(cublasTransposeParams<float2>, float2 const *, float2*, float2 const *)
  8.24%  446.82ms     11000  40.619us  39.362us  52.675us  void zeroPad256<float, float, bool=0>(float const *, float*, int, int, int, int, int, int, int, int)
  7.71%  418.24ms     11000  38.021us  37.345us  48.994us  void zeroPad256<float, float, bool=1>(float const *, float*, int, int, int, int, int, int, int, int)
```

Compare with 2.4:

```
            Type  Time(%)      Time     Calls       Avg       Min       Max  Name
 GPU activities:   93.20%  26.9577s     11000  2.4507ms  2.4238ms  3.2737ms  maxwell_cgemm_64x64_tn
                    1.68%  487.26ms     11000  44.296us  43.745us  57.635us  compute_gemm_pointers(float2**, float2 const *, int, float2 const *, int, float2 const *, int, int)
                    1.56%  450.56ms     33000  13.653us  7.6480us  22.849us  void transpose_readWrite_alignment_kernel<float2, float2, int=1, bool=0, int=6, int=4, int=4>(cublasTransposeParams<float2>, float2 const *, float2*, float2 const *)
                    1.06%  307.63ms     22000  13.983us  10.880us  19.457us  void DSE::regular_fft_pad<int=0, int=1, int=256, int=16, int=16, int=1, float, float, float2>(float2*, float*, int, int3, float*, int, float*, float*, int, int, int, int, int, bool)
                    0.82%  236.38ms     11000  21.489us  21.025us  25.921us  void DSE::regular_fft_clip<int=1, int=2, int=256, int=16, int=16, int=1, float, float, float2>(float*, float2*, int, int3, float2*, int, float2*, float2*, int, int, int, int, int, float, float, bool, int, float, float)
```

It seems that 2.4 chooses an FFT-based algo while 2.3 chooses a GEMM-based algo. I think it's time to consult the NVIDIA gurus. @FDecaYed, can you comment on the algo differences here?
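Both algorithm families compute the same result. FFT-based convolution relies on the convolution theorem: circular convolution in the spatial domain equals pointwise multiplication in the frequency domain, and this pays off for large kernels such as 45x45. The following is a minimal 1-D sketch in pure Python. It uses a naive O(n^2) DFT as a stand-in for a real FFT, whereas cuDNN's kernels are 2-D and far more elaborate:

```python
import cmath

def dft(x, inverse=False):
    """Naive O(n^2) discrete Fourier transform (stand-in for an FFT)."""
    n = len(x)
    sign = 1 if inverse else -1
    out = [sum(x[t] * cmath.exp(sign * 2j * cmath.pi * k * t / n) for t in range(n))
           for k in range(n)]
    return [v / n for v in out] if inverse else out

def circular_conv_direct(x, h):
    """Direct circular convolution: y[i] = sum_j x[j] * h[(i - j) mod n]."""
    n = len(x)
    return [sum(x[j] * h[(i - j) % n] for j in range(n)) for i in range(n)]

def circular_conv_fft(x, h):
    """Same result via the convolution theorem: IDFT(DFT(x) * DFT(h))."""
    X, H = dft(x), dft(h)
    y = dft([a * b for a, b in zip(X, H)], inverse=True)
    return [v.real for v in y]  # imaginary parts are rounding noise for real inputs

x = [1.0, 2.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0]   # signal, zero-padded to length 8
h = [0.25, 0.5, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0]  # small smoothing kernel, zero-padded

direct = circular_conv_direct(x, h)
via_fft = circular_conv_fft(x, h)
print([round(v, 6) for v in direct])
print([round(v, 6) for v in via_fft])
```

Direct convolution costs work proportional to the kernel size for every output, while the FFT route costs O(n log n) overall, so the FFT path wins as the kernel grows. In the 2-D, multi-channel case the pointwise products become complex matrix products over channels. That is presumably what the cgemm kernels in the profiles above compute.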

rhy-ama commented 6 years ago

@KeDengMS Maybe it makes sense to set up regression tests as part of the build process (rather than only at release points)?
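A perf regression gate of that kind could be as simple as comparing a measured throughput against a stored baseline with a tolerance. This is a minimal sketch. The function name `check_throughput`, the choice of baseline and the 10% tolerance are hypothetical and not part of CNTK's build:

```python
def check_throughput(measured: float, baseline: float, tolerance: float = 0.10) -> bool:
    """Hypothetical regression gate: pass unless throughput (samples/s)
    drops more than `tolerance` (a fraction) below the stored baseline."""
    if baseline <= 0:
        raise ValueError("baseline must be positive")
    return measured >= baseline * (1.0 - tolerance)

# Figures from this issue: 2.3 as the baseline, 2.4 as the candidate.
print(check_throughput(measured=11.7, baseline=25.4))  # with the 45x45 convolutions
print(check_throughput(measured=37.5, baseline=35.6))  # without them
```

A relative tolerance rather than an exact match keeps ordinary run-to-run timing noise from failing the build.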

FDecaYed commented 6 years ago

@GhMaRo It is likely cuBLAS/cuDNN; I will try to figure out the exact cause. Meanwhile, could you provide a bit more information on the sizes you are running? At a minimum I need the batch size, input size, output channel count, stride, and padding.

MatthiasRock commented 6 years ago

I use a constant kernel of size [1 x 1 x 45 x 45]. Basically, there are many convolutions of the following form:

```python
out = C.convolution(kernel, x)  # [1 x 96 x 96]
```

So stride, padding, etc. are all at their defaults.

But the script posted by KeDengMS also reproduces the performance difference very well.
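For reference, the spatial output size of a convolution along one axis follows the usual formula out = floor((in + 2·pad − kernel) / stride) + 1. The sketch below assumes stride 1 and that the default auto-padding amounts to "same"-style padding, that is, pad = (kernel − 1) / 2 = 22 for an odd 45-wide kernel. The helper name is illustrative:

```python
def conv_output_size(in_size: int, kernel: int, stride: int = 1, pad: int = 0) -> int:
    """Spatial output size of a convolution along one axis."""
    return (in_size + 2 * pad - kernel) // stride + 1

kernel = 45
same_pad = (kernel - 1) // 2   # 22 for an odd kernel
print(conv_output_size(96, kernel, stride=1, pad=same_pad))  # "same" padding
print(conv_output_size(96, kernel, stride=1, pad=0))         # no padding ("valid")
```

With "same" padding the 96x96 map keeps its size. Without padding, a 45x45 kernel would shrink it to 52x52.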

FDecaYed commented 6 years ago

@GhMaRo Thanks for the reply, that's useful information. May I ask what kind of task you are doing? I am very curious. Such a large filter size is very uncommon for us, so it may not be well tuned. If there are solid use cases, we would love to hear about them from our users so we can keep improving our libraries. We may also try to find a workaround, since the library release cycle is measured in months.

@KeDengMS Both 2.3 and 2.4 use FFT, which is correct given how large the filter is. The problem does not seem to be in CNTK but in the different choice of kernel for the cgemm. The newly picked kernel may be faster for larger sizes, but it is slower in this case (batched small cgemms). I will work with our library team to see what we can do.
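The "batched small cgemms" can be understood as follows. After inputs and filters are transformed, each frequency bin needs its own complex product of an (N x C) matrix of transformed inputs and a (C x K) matrix of transformed filters. Here N is the batch size, C the input channel count and K the output channel count. The sketch below only counts those GEMMs and their sizes. It is a simplified model: the hypothetical helper `fft_conv_gemm_shapes`, the real-FFT bin count h·(w/2 + 1), the 140x140 transform size and batch size 1 are all assumptions, not what cuDNN actually does internally:

```python
def fft_conv_gemm_shapes(n: int, c: int, k: int, fft_h: int, fft_w: int):
    """Simplified model of FFT convolution as batched complex GEMMs.

    Returns (number_of_gemms, (n, c, k)): one GEMM per frequency bin, each
    multiplying an (n x c) matrix of transformed inputs by a (c x k) matrix
    of transformed filters. Uses the real-FFT bin count h * (w // 2 + 1).
    """
    bins = fft_h * (fft_w // 2 + 1)
    return bins, (n, c, k)

# Assumed setting: batch 1, 1 input channel, 1 output channel,
# 96x96 input with a 45x45 kernel -> linear-convolution size 96 + 45 - 1 = 140.
size = 96 + 45 - 1
bins, (m, inner, cols) = fft_conv_gemm_shapes(n=1, c=1, k=1, fft_h=size, fft_w=size)
flops_per_gemm = 8 * m * inner * cols   # a complex multiply-add is ~8 real flops
print(bins, (m, inner, cols), flops_per_gemm)
```

Under these assumptions each of the roughly ten thousand GEMMs is essentially a single complex multiply, so per-GEMM overheads dominate. This is consistent with the explanation above that a cgemm kernel tuned for larger matrices can be slower here.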

MatthiasRock commented 6 years ago

@FDecaYed We are using such large filter sizes for facial landmark detection with a fully convolutional network. For this we need several convolutions with Gaussian-like 45x45 kernels. Our paper will be published at CVPR this year, so in a few weeks we will also publish our network and all the other code. It would be great if there were a way to use CNTK 2.4 with at least the same speed as 2.3. At the moment, unfortunately, I have to use 2.3 for training.
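To reproduce a setup like this, you can build a Gaussian 45x45 kernel as a normalized outer product of 1-D Gaussians. The following is a minimal pure-Python sketch. The function name is illustrative, and the sigma value is an arbitrary assumption because the actual kernel parameters are not given in this thread:

```python
import math

def gaussian_kernel_2d(size: int = 45, sigma: float = 7.0):
    """Return a size x size list-of-lists Gaussian kernel that sums to 1.
    sigma is an assumed value for illustration."""
    center = (size - 1) / 2
    g1d = [math.exp(-((i - center) ** 2) / (2 * sigma ** 2)) for i in range(size)]
    total = sum(g1d)
    g1d = [v / total for v in g1d]              # normalize the 1-D profile
    return [[a * b for b in g1d] for a in g1d]  # separable: outer product

k = gaussian_kernel_2d()
print(len(k), len(k[0]))                 # 45 45
print(round(sum(map(sum, k)), 6))        # 1.0
print(k[22][22] == max(map(max, k)))     # True: peak at the center
```

If the kernel is exactly Gaussian, and therefore separable, two 1-D passes (45 + 45 taps) could in principle replace one 45x45 pass (2025 taps). A merely Gaussian-like kernel may not allow this.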

FDecaYed commented 6 years ago

@GhMaRo Thanks. I'll definitely take a look at the code and paper once they are published.