apache / mxnet

Lightweight, Portable, Flexible Distributed/Mobile Deep Learning with Dynamic, Mutation-aware Dataflow Dep Scheduler; for Python, R, Julia, Scala, Go, Javascript and more
https://mxnet.apache.org
Apache License 2.0

mx.io.ImageRecordIter does not respect dtype argument / FP16 performance on Volta #9774

Open rahul003 opened 6 years ago

rahul003 commented 6 years ago

Description

mx.io.ImageRecordIter (implemented in src/io/iter_image_recordio_2.cc) does not respect the dtype parameter it takes. It only works with float32, because the class is instantiated with the real_t dtype in src/io/iter_image_recordio_2.cc. Can we make it handle fp16 too? This is important for fp16 training.

Also, training in fp16 seems slower than fp32 for some models.

Environment info (Required)

MXNet 1.0
Package used: Python

Error Message:

Silently generates fp32 data

Minimum reproducible example

N/A

Steps to reproduce

N/A

What have you tried to solve it?

Is there a better way than creating a new operator that passes DType as fp16?

@ptrendx

ptrendx commented 6 years ago

I don't think it would perform better than producing fp32 data and then casting it to fp16 at the beginning of training. 1) You need this double-buffering of data at the beginning of training to hide CPU->GPU transfers. 2) The fp16 computation would then be done on the CPU inside the data iterator, and it is slow there because CPUs lack native fp16 support.
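
To make the "produce fp32, then cast to fp16" step concrete, here is a minimal, MXNet-independent sketch of what an fp32->fp16 cast does to image data. It assumes plain Python and uses the standard struct module's 'e' (IEEE 754 half precision) format as a stand-in for the cast; the helper name to_fp16 is hypothetical.

```python
import struct


def to_fp16(x: float) -> float:
    """Round x to the nearest IEEE 754 half-precision value and widen it back.

    struct's 'e' format packs a float into 2 bytes (binary16), which serves
    here as a stand-in for an fp32 -> fp16 cast on the host.
    """
    return struct.unpack("<e", struct.pack("<e", x))[0]


# Raw 8-bit pixel values are integers in [0, 255]. Half precision represents
# every integer up to 2048 exactly, so casting raw pixels loses nothing.
assert all(to_fp16(float(p)) == p for p in range(256))

# Normalized values get rounded: 0.1 has no exact binary16 representation.
print(to_fp16(0.1))     # 0.0999755859375

# Above 2048 the spacing between representable values becomes 2 (ties go to even).
print(to_fp16(2049.0))  # 2048.0
```
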

rahul003 commented 6 years ago

I see. Could you expand on 1 please?

ptrendx commented 6 years ago

The engine does not seem to differentiate between the first layer and subsequent layers: it considers the data going into the first layer as being modified by the backward pass of the network (even though that does not actually happen). This means that copying the next batch has to wait until the end of the backward pass, which effectively exposes the copy. The double-buffering scheme, either a cast fp32->fp16 or just an identity fp32->fp32, makes sure that the NDArray used to copy the next batch to the GPU is released by the engine before the backward pass ends, which lets the copy happen while the backward computation takes place.
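
The dependency problem described above can be illustrated with a toy timeline model. This is a hypothetical sketch, not MXNet's engine: it assumes the host-to-device copy and the GPU computation are independent resources, that without double buffering the next copy may only start once the current batch's forward and backward passes finish, and that with double buffering the copy buffer is released as soon as the first cast/identity op has consumed it (approximated here as the start of that batch's computation). The function name and the millisecond figures are illustrative.

```python
def total_time_ms(n_batches: int, copy_ms: float, compute_ms: float,
                  double_buffered: bool) -> float:
    """Toy model of wall time for n_batches of copy + (forward + backward).

    copy_ms:    time to copy one batch host -> device
    compute_ms: time for forward + backward on one batch
    """
    copy_end = float(copy_ms)  # batch 0 must be copied before anything runs
    compute_end = 0.0
    for i in range(n_batches):
        # A batch starts computing once its data has arrived and the
        # previous batch has finished.
        compute_start = max(copy_end, compute_end)
        compute_end = compute_start + compute_ms
        if i + 1 < n_batches:
            # Without double buffering the input buffer counts as "in use"
            # until backward ends, so the next copy has to wait for that.
            # With double buffering it is freed once the cast/identity op
            # at the start of the network has read it.
            release = compute_start if double_buffered else compute_end
            copy_end = max(release, copy_end) + copy_ms
    return compute_end
```

In this model, with a 5 ms copy and 20 ms of compute per batch, 100 batches take 2500 ms when serialized and 2005 ms with double buffering: the copy is hidden behind the computation instead of being exposed on every step.
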

rahul003 commented 6 years ago

Thanks for the explanation.

By the way, is training in fp16 supposed to be ~2x faster than fp32 for a given batch size? Or is it only about reduced memory usage, so that we can use larger batch sizes? I have run the fp16 image-classification examples you added, and I see speed changes of maybe ±10-15%, nowhere close to 2x for a given batch size. Is this normal?

I ran these tests on p3.8xlarge and p3.16xlarge EC2 instances, which use Volta (V100) GPUs, with CUDA 9.

ptrendx commented 6 years ago

There are a few possible explanations. The most probable one is the workspace size for convolutions. I tried to convince @piiswrong to change MXNet's default behavior of limiting the results of cudnnFind to the algorithms that fit in the workspace, but had no luck with that. Try MXNET_CUDNN_AUTOTUNE_DEFAULT=2. Also, if you benchmarked with real data, make sure you are not limited by IO (you may need to set --data-nthreads higher than the default of 4). Finally, depthwise/grouped convolutions in networks like ResNeXt do not currently benefit much from Tensor Cores, so if that is what you tested, the benefit should be small.

rahul003 commented 6 years ago

Thanks. This blog post mentions training ImageNet with ResNet-50 on 8 Volta GPUs. Could you share the speedups you observed in such settings? Or, under optimal conditions, roughly how much speedup did you see with MXNet?

KellenSunderland commented 6 years ago

@rahul003 If you're only seeing ~15% speedups, I'd recommend running your training under nvprof. Look at the GEMM kernels and make sure they have s884 in their names. If they don't, then one of these rules is probably not being followed:

A Few Simple Rules

cuBLAS users will notice a few changes from their existing cuBLAS GEMM code:

    The routine must be a GEMM; currently, only GEMMs support Tensor Core execution.
    The math mode must be set to CUBLAS_TENSOR_OP_MATH. Floating point math is not associative, so the results of the Tensor Core math routines are not quite bit-equivalent to the results of the analogous non-Tensor Core math routines.  cuBLAS requires the user to “opt in” to the use of tensor cores.
    All of k, lda, ldb, and ldc must be a multiple of eight; m must be a multiple of four. The Tensor Core math routines stride through input data in steps of eight values, so the dimensions of the matrices must be multiples of eight.
    The input and output data types for the matrices must be either half precision or single precision. (Only CUDA_R_16F is shown above, but CUDA_R_32F also is supported.)

(from https://devblogs.nvidia.com/programming-tensor-cores-cuda-9/) GEMMs that do not satisfy the above rules will fall back to a non-Tensor Core implementation.
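
The rules quoted above can be turned into a quick pre-check for a given GEMM shape. This is a minimal sketch that only mirrors the CUDA 9 cuBLAS rules as quoted from the NVIDIA blog post. It is not a cuBLAS API, the function name and the example shapes are hypothetical, and it does not model cuDNN's own kernel selection.

```python
VALID_TYPES = {"CUDA_R_16F", "CUDA_R_32F"}


def tensor_core_rule_violations(m: int, k: int, lda: int, ldb: int, ldc: int,
                                data_type: str,
                                tensor_op_math: bool) -> list[str]:
    """Return the quoted rules a GEMM violates (an empty list means eligible)."""
    problems = []
    if not tensor_op_math:
        problems.append("math mode is not CUBLAS_TENSOR_OP_MATH")
    # k and all leading dimensions must be multiples of eight.
    for name, value in (("k", k), ("lda", lda), ("ldb", ldb), ("ldc", ldc)):
        if value % 8:
            problems.append(f"{name}={value} is not a multiple of 8")
    # m must be a multiple of four.
    if m % 4:
        problems.append(f"m={m} is not a multiple of 4")
    if data_type not in VALID_TYPES:
        problems.append(f"data type {data_type} is not CUDA_R_16F or CUDA_R_32F")
    return problems


print(tensor_core_rule_violations(256, 64, 256, 64, 256, "CUDA_R_16F", True))
# []
print(tensor_core_rule_violations(256, 27, 256, 27, 256, "CUDA_R_16F", True))
# ['k=27 is not a multiple of 8', 'ldb=27 is not a multiple of 8']
```
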

The tensor cores are a little tricky to use in a lot of cases. Let us know if nvprof shows that your model isn't being run on tensor cores and we might be able to give you some next steps.

rahul003 commented 6 years ago

Hey @KellenSunderland

I ran ResNet-50 on ImageNet and got about a 70% speedup. Some of the top kernels don't have s884 in their names, but some do. Can I improve the speed further? Here's the profiler log:

 GPU activities:    7.58%  179.746s    965213  186.22us  1.3440us  30.955ms  [CUDA memcpy HtoD]
                    6.60%  156.366s   1547648  101.03us  2.0480us  711.93us  void nchwToNhwcKernel<__half, __half, float, bool=1>(int, int, int, int, __half const *, __half*, float, float)
                    5.84%  138.416s    108321  1.2778ms  314.97us  2.4579ms  volta_fp16_scudnn_fp16_128x64_relu_interior_nn_v1
                    5.02%  119.036s    336336  353.92us  88.063us  1.2733ms  void cudnn::detail::bn_bw_1C11_singleread_fp16<int=512, int=1, int=2, int=14>(float, float, float, float, cudnnTensorStruct, __half2 const *, cudnn::detail::bn_bw_1C11_singleread_fp16<int=512, int=1, int=2, int=14>, __half2 const , cudnn::detail::bn_bw_1C11_singleread_fp16<int=512, int=1, int=2, int=14>, cudnnTensorStruct*, float const *, float*, float const *, float const , float const , float, cudnn::reduced_divisor, int, float*, cudnn::detail::bnBwPersistentState*, int, float, float, float, int, float, cudnnStatus_t*, bool)
                    4.83%  114.400s    400400  285.71us  33.088us  1.2040ms  void cudnn::detail::activation_bw_4d_kernel<__half, float, int=128, int=1, int=4, cudnn::detail::relu_func<float, cudnnNanPropagation_t=0, bool=0>>(cudnnTensorStruct, __half const *, __half const , cudnn::detail::activation_bw_4d_kernel<__half, float, int=128, int=1, int=4, cudnn::detail::relu_func<float, cudnnNanPropagation_t=0, bool=0>>, __half const , cudnnTensorStruct*, float, cudnnTensorStruct*, int, cudnnTensorStruct*)
                    4.72%  111.789s     40190  2.7815ms  1.4254ms  8.5978ms  void cudnn::detail::wgrad_alg0_engine<__half, int=128, int=6, int=8, int=3, int=3, int=5, bool=1, int=512>(int, int, int, __half const *, int, cudnn::detail::wgrad_alg0_engine<__half, int=128, int=6, int=8, int=3, int=3, int=5, bool=1, int=512>*, __half const , kernel_grad_params, int, float, int, int, int, int)
                    4.71%  111.533s     16040  6.9534ms  4.7594ms  10.580ms  void cudnn::detail::dgrad2d_alg1_1<__half, int=0, int=6, int=7, int=5, int=4, int=5, bool=1, bool=1>(int, int, int, __half const *, int, __half const , int, cudnn::detail::dgrad2d_alg1_1<__half, int=0, int=6, int=7, int=5, int=4, int=5, bool=1, bool=1>*, kernel_grad_params, int, int, float, int, int)
                    4.14%  98.2065s    304388  322.64us  198.69us  723.45us  volta_s884cudnn_fp16_128x128_ldg8_wgrad_exp_interior_nhwc_nt_v1
                    3.60%  85.2418s     56116  1.5190ms  625.92us  2.4792ms  volta_fp16_scudnn_fp16_128x128_stridedB_interior_nn_v1
                    3.54%  83.8923s    336336  249.43us  65.888us  908.99us  void cudnn::detail::bn_fw_tr_1C11_singleread_fp16<int=512, int=1, int=2, int=20>(cudnnTensorStruct, __half2 const *, cudnn::detail::bn_fw_tr_1C11_singleread_fp16<int=512, int=1, int=2, int=20>, cudnnTensorStruct*, float const *, float const , float, float, float*, float const *, float const *, float const *, float, float, cudnn::reduced_divisor, int, float*, cudnn::detail::bnFwPersistentState*, int, float, float, float, int, float, float, cudnnStatus_t*, bool)
                    3.36%  79.5726s     40168  1.9810ms  801.53us  10.105ms  void cudnn::detail::dgrad_engine<__half, int=128, int=6, int=7, int=3, int=3, int=5, bool=1>(int, int, int, __half const *, int, __half const , int, cudnn::detail::dgrad_engine<__half, int=128, int=6, int=7, int=3, int=3, int=5, bool=1>*, kernel_grad_params, int, int, float, int, int, int)
                    2.96%  70.1196s    416400  168.39us  22.303us  684.35us  void cudnn::detail::activation_fw_4d_kernel<__half, float, int=128, int=1, int=4, cudnn::detail::relu_func<float, cudnnNanPropagation_t=0, bool=0>>(cudnnTensorStruct, __half const *, cudnn::detail::activation_fw_4d_kernel<__half, float, int=128, int=1, int=4, cudnn::detail::relu_func<float, cudnnNanPropagation_t=0, bool=0>>, cudnnTensorStruct*, float, cudnnTensorStruct*, int, cudnnTensorStruct*)
                    2.89%  68.4710s     16082  4.2576ms  1.6871ms  6.6937ms  void cudnn::detail::dgrad_engine<__half, int=512, int=6, int=5, int=3, int=3, int=3, bool=1>(int, int, int, __half const *, int, __half const , int, cudnn::detail::dgrad_engine<__half, int=512, int=6, int=5, int=3, int=3, int=3, bool=1>*, kernel_grad_params, int, int, float, int, int, int)
                    2.86%  67.7144s    174945  387.06us  290.78us  902.65us  volta_fp16_s884cudnn_fp16_256x128_ldg8_relu_f2f_exp_interior_nhwc2nchw_tn_v1
                    2.81%  66.5956s      8008  8.3161ms  8.2356ms  8.5289ms  void cudnn::detail::bn_bw_1C11_kernel_new<__half, float, float2, int=512, bool=1, int=1>(float, cudnn::detail::bn_bw_1C11_kernel_new<__half, float, float2, int=512, bool=1, int=1>, cudnn::detail::bn_bw_1C11_kernel_new<__half, float, float2, int=512, bool=1, int=1>, cudnn::detail::bn_bw_1C11_kernel_new<__half, float, float2, int=512, bool=1, int=1>, cudnnTensorStruct, __half const *, float, __half const , float, cudnnTensorStruct*, cudnn::detail::bn_bw_1C11_kernel_new<__half, float, float2, int=512, bool=1, int=1> const *, cudnn::detail::bn_bw_1C11_kernel_new<__half, float, float2, int=512, bool=1, int=1>*, cudnn::detail::bn_bw_1C11_kernel_new<__half, float, float2, int=512, bool=1, int=1> const *, cudnn::detail::bn_bw_1C11_kernel_new<__half, float, float2, int=512, bool=1, int=1> const , cudnn::detail::bn_bw_1C11_kernel_new<__half, float, float2, int=512, bool=1, int=1> const , cudnn::detail::bn_bw_1C11_kernel_new<__half, float, float2, int=512, bool=1, int=1>)
                    2.71%  64.1674s     16164  3.9698ms  1.2205ms  4.2252ms  volta_fp16_scudnn_fp16_128x128_stridedB_splitK_interior_nn_v1
                    2.64%  62.5345s      8008  7.8090ms  7.7552ms  8.5208ms  void cudnn::detail::bn_fw_tr_1C11_kernel_new<__half, float, int=512, bool=1, int=1>(cudnnTensorStruct, __half const *, cudnn::detail::bn_fw_tr_1C11_kernel_new<__half, float, int=512, bool=1, int=1>, cudnnTensorStruct*, float const *, float const , cudnnTensorStruct*, cudnnTensorStruct*, cudnnTensorStruct**, float const *, float const *, float const *, cudnnTensorStruct*, cudnnTensorStruct*)
                    2.28%  53.9698s    954403  56.548us  1.2160us  256.37ms  [CUDA memcpy DtoH]
                    2.25%  53.3050s    133248  400.04us  115.33us  954.87us  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS1_11op_with_reqINS0_10mshadow_op4plusELi1EEEJPN7mshadow4half7half2_tESA_SA_EEEviDpT0_
                    2.21%  52.3835s    128128  408.84us  114.27us  953.12us  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS0_3SumEJPN7mshadow4half7half2_tENS_9OpReqTypeES7_S7_EEEviDpT0_
                    2.09%  49.4748s     99967  494.91us  437.95us  612.13us  volta_fp16_s884cudnn_fp16_128x128_ldg8_relu_f2f_exp_small_nhwc2nchw_tn_v1
                    1.91%  45.1522s      8016  5.6328ms  5.5531ms  5.9249ms  void cudnn::detail::dgrad2d_alg1_1<__half, int=0, int=6, int=6, int=5, int=4, int=4, bool=1, bool=1>(int, int, int, __half const *, int, __half const , int, cudnn::detail::dgrad2d_alg1_1<__half, int=0, int=6, int=6, int=5, int=4, int=4, bool=1, bool=1>*, kernel_grad_params, int, int, float, int, int)
                    1.75%  41.4783s      8064  5.1436ms  856.09us  7.4089ms  void cudnn::detail::wgrad_alg0_engine<__half, int=512, int=6, int=5, int=3, int=3, int=3, bool=1, int=512>(int, int, int, __half const *, int, cudnn::detail::wgrad_alg0_engine<__half, int=512, int=6, int=5, int=3, int=3, int=3, bool=1, int=512>*, __half const , kernel_grad_params, int, float, int, int, int, int)
                    1.69%  40.0064s    112152  356.72us  277.31us  864.38us  volta_fp16_s884cudnn_fp16_256x128_ldg8_dgrad_f2f_exp_interior_nhwc2nchw_tt_v1
                    1.43%  33.8762s     48075  704.65us  353.79us  870.81us  volta_s884cudnn_fp16_64x256_sliced1x4_ldg8_wgrad_exp_interior_nhwc_nt_v1
                    1.37%  32.5869s      8092  4.0271ms  1.8701ms  12.452ms  void cudnn::detail::dgrad_engine<__half, int=128, int=6, int=8, int=3, int=3, int=5, bool=1>(int, int, int, __half const *, int, __half const , int, cudnn::detail::dgrad_engine<__half, int=128, int=6, int=8, int=3, int=3, int=5, bool=1>*, kernel_grad_params, int, int, float, int, int, int)
                    1.31%  31.0537s     64080  484.61us  441.79us  619.93us  volta_fp16_s884cudnn_fp16_256x128_ldg8_dgrad_f2f_exp_small_nhwc2nchw_tt_v1
                    1.23%  29.2180s      8052  3.6287ms  1.4834ms  3.8291ms  void cudnn::detail::wgrad_alg0_engine<__half, int=128, int=6, int=7, int=3, int=3, int=5, bool=1, int=512>(int, int, int, __half const *, int, cudnn::detail::wgrad_alg0_engine<__half, int=128, int=6, int=7, int=3, int=3, int=5, bool=1, int=512>*, __half const , kernel_grad_params, int, float, int, int, int, int)
                    1.16%  27.5703s    112686  244.66us  1.3750us  749.50us  void scalePackedTensor_kernel<__half, float>(cudnnTensor4dStruct, __half*, float)
                    1.02%  24.0620s      8008  3.0047ms  2.8931ms  3.2269ms  void cudnn::detail::pooling_bw_kernel_max<__half, float, cudnn::detail::maxpooling_func<float, cudnnNanPropagation_t=0>, bool=0>(cudnnTensorStruct, __half const *, cudnn::detail::pooling_bw_kernel_max<__half, float, cudnn::detail::maxpooling_func<float, cudnnNanPropagation_t=0>, bool=0>, __half const , cudnn::detail::pooling_bw_kernel_max<__half, float, cudnn::detail::maxpooling_func<float, cudnnNanPropagation_t=0>, bool=0>, __half const , cudnn::detail::pooling_bw_kernel_max<__half, float, cudnn::detail::maxpooling_func<float, cudnnNanPropagation_t=0>, bool=0>, cudnnTensorStruct*, cudnnPoolingStruct, float, cudnnPoolingStruct, int, cudnn::reduced_divisor, float)
rahul003 commented 6 years ago

But for ResNet-110 on CIFAR-10, fp16 is much slower. Do you see anything suspicious here? There are barely any kernels with s884 in their names, and none of the top ones have it. So would fp16 not help with small networks/models?

fp16
 GPU activities:   64.80%  28.7596s     87602  328.30us  24.128us  444.10us  void cudnn::detail::wgrad_alg0_engine<__half, int=512, int=6, int=5, int=3, int=3, int=3, bool=1, int=512>(int, int, int, __half const *, int, cudnn::detail::wgrad_alg0_engine<__half, int=512, int=6, int=5, int=3, int=3, int=3, bool=1, int=512>*, __half const , kernel_grad_params, int, float, int, int, int, int)
                   11.44%  5.07525s    120474  42.127us  27.776us  70.752us  void cudnn::winograd::winograd3x3Kernel<__half, float, int=4, int=1, int=8, bool=0>(cudnn::maxwell::winograd::KernelParams)
                    9.65%  4.28153s     61959  69.102us  66.112us  88.128us  void cudnn::winograd::winograd3x3Kernel<__half, float, int=2, int=2, int=8, bool=0>(cudnn::maxwell::winograd::KernelParams)
                    1.90%  845.06ms    182434  4.6320us  3.4240us  12.256us  void cudnn::winograd::generateWinogradTilesKernel<int=0, __half, float>(cudnn::winograd::GenerateWinogradTilesParams<__half, float>)
                    1.40%  621.33ms     29716  20.909us  19.744us  26.400us  void cudnn::detail::bn_bw_1C11_singleread_fp16<int=512, int=1, int=2, int=14>(float, float, float, float, cudnnTensorStruct, __half2 const *, cudnn::detail::bn_bw_1C11_singleread_fp16<int=512, int=1, int=2, int=14>, __half2 const , cudnn::detail::bn_bw_1C11_singleread_fp16<int=512, int=1, int=2, int=14>, cudnnTensorStruct*, float const *, float*, float const *, float const , float const , float, cudnn::reduced_divisor, int, float*, cudnn::detail::bnBwPersistentState*, int, float, float, float, int, float, cudnnStatus_t*, bool)
                    1.39%  615.34ms     85238  7.2190us  3.8080us  11.776us  void cudnn::detail::activation_bw_4d_kernel<__half, float, int=128, int=1, int=4, cudnn::detail::relu_func<float, cudnnNanPropagation_t=0, bool=0>>(cudnnTensorStruct, __half const *, __half const , cudnn::detail::activation_bw_4d_kernel<__half, float, int=128, int=1, int=4, cudnn::detail::relu_func<float, cudnnNanPropagation_t=0, bool=0>>, __half const , cudnnTensorStruct*, float, cudnnTensorStruct*, int, cudnnTensorStruct*)
                    1.27%  564.01ms     29716  18.980us  11.648us  22.080us  void cudnn::detail::bn_fw_tr_1C11_singleread_fp16<int=512, int=1, int=2, int=20>(cudnnTensorStruct, __half2 const *, cudnn::detail::bn_fw_tr_1C11_singleread_fp16<int=512, int=1, int=2, int=20>, cudnnTensorStruct*, float const *, float const , float, float, float*, float const *, float const *, float const *, float, float, cudnn::reduced_divisor, int, float*, cudnn::detail::bnFwPersistentState*, int, float, float, float, int, float, float, cudnnStatus_t*, bool)
                    1.01%  446.51ms    102351  4.3620us  3.1360us  11.296us  void cudnn::detail::activation_fw_4d_kernel<__half, float, int=128, int=1, int=4, cudnn::detail::relu_func<float, cudnnNanPropagation_t=0, bool=0>>(cudnnTensorStruct, __half const *, cudnn::detail::activation_fw_4d_kernel<__half, float, int=128, int=1, int=4, cudnn::detail::relu_func<float, cudnnNanPropagation_t=0, bool=0>>, cudnnTensorStruct*, float, cudnnTensorStruct*, int, cudnnTensorStruct*)

API calls:   37.66%  44.8500s   1042145  43.036us  5.0250us  2.6695ms  cudaStreamSynchronize
                   25.68%  30.5900s   1391180  21.988us  6.6070us  12.689ms  cudaLaunch
                   11.41%  13.5923s    266335  51.034us  7.7360us  9.7191ms  cudaMemcpy2DAsync
fp32
GPU activities:   29.29%  6.95644s     87602  79.409us  19.072us  133.28us  void cudnn::detail::wgrad_alg0_engine<float, int=512, int=6, int=5, int=3, int=3, int=3, bool=1, int=512>(int, int, int, float const *, int, cudnn::detail::wgrad_alg0_engine<float, int=512, int=6, int=5, int=3, int=3, int=3, bool=1, int=512>*, float const , kernel_grad_params, int, float, int, int, int, int)
                   16.85%  4.00136s     87609  45.672us  36.960us  71.200us  void cudnn::winograd::winograd3x3Kernel<float, float, int=4, int=1, int=8, bool=0>(cudnn::maxwell::winograd::KernelParams)
                    9.11%  2.16456s     28155  76.879us  72.737us  92.384us  void cudnn::winograd::winograd3x3Kernel<float, float, int=2, int=2, int=8, bool=0>(cudnn::maxwell::winograd::KernelParams)
                    8.96%  2.12819s     33807  62.951us  22.560us  65.729us  volta_scudnn_128x32_relu_small_nn_v1
                    4.66%  1.10676s     86020  12.866us  7.4880us  16.224us  void cudnn::detail::bn_fw_tr_1C11_singleread<float, int=512, bool=1, int=1, int=2, int=0>(cudnnTensorStruct, float const *, cudnn::detail::bn_fw_tr_1C11_singleread<float, int=512, bool=1, int=1, int=2, int=0>, cudnnTensorStruct*, float const *, float const , float, float, float*, float const *, float const *, float const *, float, float, cudnn::reduced_divisor, int, float*, cudnn::detail::bnFwPersistentState*, int, float, float, float, int, float, float, cudnnStatus_t*, bool)
                    4.43%  1.05193s     86020  12.228us  7.5520us  24.448us  void cudnn::detail::bn_bw_1C11_singleread<float, int=512, bool=1, int=1, int=2, int=0>(float, float, float, float, cudnnTensorStruct, float const *, cudnn::detail::bn_bw_1C11_singleread<float, int=512, bool=1, int=1, int=2, int=0>, float const , cudnn::detail::bn_bw_1C11_singleread<float, int=512, bool=1, int=1, int=2, int=0>, cudnnTensorStruct*, float const *, float*, float const *, float const , float const , float, cudnn::reduced_divisor, int, float*, cudnn::detail::bnBwPersistentState*, int, float, float, float, int, float, cudnnStatus_t*, bool)
                    3.24%  769.36ms     85238  9.0260us  4.0960us  16.224us  void cudnn::detail::activation_bw_4d_kernel<float, float, int=128, int=1, int=4, cudnn::detail::relu_func<float, cudnnNanPropagation_t=0, bool=0>>(cudnnTensorStruct, float const *, float const , cudnn::detail::activation_bw_4d_kernel<float, float, int=128, int=1, int=4, cudnn::detail::relu_func<float, cudnnNanPropagation_t=0, bool=0>>, float const , cudnnTensorStruct*, float, cudnnTensorStruct*, int, cudnnTensorStruct*)
                    2.84%  675.48ms    102351  6.5990us  3.9360us  13.344us  void cudnn::detail::activation_fw_4d_kernel<float, float, int=128, int=1, int=4, cudnn::detail::relu_func<float, cudnnNanPropagation_t=0, bool=0>>(cudnnTensorStruct, float const *, cudnn::detail::activation_fw_4d_kernel<float, float, int=128, int=1, int=4, cudnn::detail::relu_func<float, cudnnNanPropagation_t=0, bool=0>>, cudnnTensorStruct*, float, cudnnTensorStruct*, int, cudnnTensorStruct*)
                    2.80%  665.55ms     32869  20.248us  19.680us  39.808us  volta_sgemm_128x64_nn
                    2.66%  631.06ms    116704  5.4070us  4.6400us  22.592us  void cudnn::winograd::generateWinogradTilesKernel<int=0, float, float>(cudnn::winograd::GenerateWinogradTilesParams<float, float>)
                    2.12%  503.18ms    261188  1.9260us  1.6640us  14.976us  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS0_12SGDMomKernelEJPfS4_S4_S4_fffffNS_9OpReqTypeEEEEviDpT0_
                    1.56%  370.84ms    262127  1.4140us  1.2800us  12.384us  [CUDA memcpy DtoD]
                    1.41%  335.58ms     32873  10.208us  8.2240us  18.688us  void cudnn::winograd_nonfused::winogradForwardOutput4x4<float, float>(cudnn::winograd_nonfused::WinogradOutputParams<float, float>)
                    1.30%  308.23ms     32873  9.3760us  8.5760us  23.360us  void cudnn::winograd_nonfused::winogradForwardData4x4<float, float>(cudnn::winograd_nonfused::WinogradDataParams<float, float>)
                    1.22%  289.39ms    203358  1.4230us  1.3120us  12.224us  [CUDA memset]
                    1.21%  286.42ms     42228  6.7820us  3.5200us  13.312us  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS0_3SumEJPfNS_9OpReqTypeES4_S4_EEEviDpT0_
                    1.19%  283.47ms     50706  5.5900us  2.9440us  12.352us  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS1_11op_with_reqINS0_10mshadow_op4plusELi1EEEJPfS7_S7_EEEviDpT0_
                    1.03%  244.04ms    172374  1.4150us  1.1840us  23.425us  _ZN5mxnet2op8mxnet_op20mxnet_generic_kernelINS1_11op_with_reqINS1_10set_to_intILi0EEELi1EEEJPfEEEviDpT0_

API calls:   30.01%  23.7375s   1042031  22.780us  5.0120us  2.8708ms  cudaStreamSynchronize
                   26.13%  20.6708s   1452568  14.230us  5.9830us  17.019ms  cudaLaunch
                   12.79%  10.1212s    267274  37.868us  8.3050us  10.654ms  cudaMemcpy2DAsync
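
A back-of-the-envelope reading of the two ResNet-110 profiles above helps locate the fp16 slowdown. nvprof's Time(%) for GPU activities is each kernel's share of total GPU activity time, so dividing a kernel's time by its share estimates that total. The numbers are copied from the wgrad_alg0_engine rows, which have the same template parameters and the same 87602 calls in both runs; the helper name is hypothetical.

```python
def total_gpu_seconds(kernel_seconds: float, percent: float) -> float:
    """Estimate total GPU activity time from one row's Time and Time(%)."""
    return kernel_seconds / (percent / 100.0)


# wgrad_alg0_engine rows from the fp16 and fp32 profiles above.
fp16_wgrad_s, fp16_wgrad_pct = 28.7596, 64.80
fp32_wgrad_s, fp32_wgrad_pct = 6.95644, 29.29

fp16_total = total_gpu_seconds(fp16_wgrad_s, fp16_wgrad_pct)  # ~44.4 s
fp32_total = total_gpu_seconds(fp32_wgrad_s, fp32_wgrad_pct)  # ~23.8 s
gap = fp16_total - fp32_total                                 # ~20.6 s
wgrad_delta = fp16_wgrad_s - fp32_wgrad_s                     # ~21.8 s

print(f"estimated total GPU time: fp16 {fp16_total:.1f} s, fp32 {fp32_total:.1f} s")
print(f"fp16 wgrad_alg0_engine is {fp16_wgrad_s / fp32_wgrad_s:.1f}x slower "
      f"and adds {wgrad_delta:.1f} s versus a {gap:.1f} s overall gap")
```

In this reading, the slower fp16 weight-gradient kernel alone adds more time than the whole fp16-vs-fp32 gap, which suggests the regression is concentrated in that one (non-s884) kernel rather than spread across the network.
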
ptrendx commented 6 years ago

@rahul003 Could you paste how you invoked the benchmark script? Did you set the autotune env var to 2? Also, which cuDNN version are you using? You should see many more s884 kernels in your profile. CIFAR has very small images, and most of the time the convolution dimensions are not suitable for Tensor Cores.

rahul003 commented 6 years ago

I'm running this command:

python train_cifar10.py --batch-size 256 --network resnet --num-layers 50 --gpus 0,1,2,3,4,5,6,7

Are you sure you have the right variable? MXNET_CUDNN_AUTOTUNE_DEFAULT=2 doesn't seem to be a valid value; it still runs autotuning when I set it to 2.

The relevant code is:

    if (param.cudnn_tune.value() && reg_.size() % 50 == 0) {
      LOG(INFO) << "Running performance tests to find the best convolution "
                   "algorithm, "
      // ... (excerpt truncated)
ptrendx commented 6 years ago

I was asking about the ImageNet script. If you use a smaller batch size such as 256 for 8 GPUs (you will see the best results with a larger batch size such as 1024 for 8 GPUs), consider turning on the NCCL kvstore (--kv-store nccl in the train_imagenet.py script). Yes, 2 is a valid (albeit undocumented :-( ) value. For both 1 and 2 it calls cudnnFind, but the default value of 1 makes it reject the algorithms that require more workspace than a threshold (for convolutions, the default threshold is 1 GB). Setting it to 2 makes it always choose the fastest algorithm:

    DMLC_DECLARE_FIELD(cudnn_tune)
    .add_enum("off", conv::kOff)
    .add_enum("limited_workspace", conv::kLimited)
    .add_enum("fastest", conv::kFastest)
    .set_default(dmlc::optional<int>())
        .describe("Whether to pick convolution algo by running performance test.");
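
The difference between the two autotuning values described above can be sketched as a selection rule over cudnnFind-style results. This is a simplified, hypothetical model, not MXNet's code: the AlgoResult type, the algorithm names, and the timings are made up for illustration, and the 1 GB default limit is the convolution workspace threshold mentioned above.

```python
from dataclasses import dataclass

GB = 1024 ** 3


@dataclass
class AlgoResult:
    name: str
    time_ms: float
    workspace_bytes: int


def pick_algo(results: list[AlgoResult], autotune: int,
              workspace_limit: int = 1 * GB) -> AlgoResult:
    """Model the 1 ("limited_workspace") vs 2 ("fastest") behaviour described above."""
    if autotune == 1:
        # Reject algorithms that need more workspace than the limit.
        candidates = [r for r in results if r.workspace_bytes <= workspace_limit]
    elif autotune == 2:
        # Consider everything cudnnFind returned.
        candidates = list(results)
    else:
        raise ValueError("this sketch only models autotune values 1 and 2")
    if not candidates:
        raise ValueError("no algorithm fits in the workspace limit")
    return min(candidates, key=lambda r: r.time_ms)


# Hypothetical timings for one convolution layer.
found = [
    AlgoResult("algo_small_ws", time_ms=3.0, workspace_bytes=64 * 1024 ** 2),
    AlgoResult("algo_big_ws", time_ms=1.2, workspace_bytes=3 * GB),
]
print(pick_algo(found, autotune=1).name)  # algo_small_ws
print(pick_algo(found, autotune=2).name)  # algo_big_ws
```
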
rahul003 commented 6 years ago

Okay cool, I'll try to document that.

I was using:

python train_imagenet.py --data-train data/imagenet1k-train.rec --data-val data/imagenet1k-val.rec --gpus 0,1,2,3,4,5,6,7 --dtype float16 --num-epochs 1 --batch-size 1280 --data-nthreads 24
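
For reference, here is how the global batch sizes mentioned in this thread split across devices. This assumes, as in MXNet's image-classification examples, that --batch-size is the total batch size divided evenly over the devices listed in --gpus; the helper name is hypothetical.

```python
def per_device_batch(batch_size: int, gpus: str) -> int:
    """Split a global --batch-size evenly across a comma-separated --gpus list."""
    n_devices = len(gpus.split(","))
    if batch_size % n_devices:
        raise ValueError(f"{batch_size} does not split evenly over {n_devices} devices")
    return batch_size // n_devices


eight_gpus = "0,1,2,3,4,5,6,7"
for global_batch in (256, 1024, 1280):
    print(global_batch, "->", per_device_batch(global_batch, eight_gpus), "per GPU")
# 256 -> 32 per GPU
# 1024 -> 128 per GPU
# 1280 -> 160 per GPU
```
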
rahul003 commented 6 years ago

I have cuDNN 7.0.5 (v7005) and CUDA 9.0.
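
The v7005 figure is the integer CUDNN_VERSION value. Here is a minimal sketch of decoding it, assuming the cuDNN 7.x header encoding MAJOR * 1000 + MINOR * 100 + PATCHLEVEL (other major versions may encode it differently); the function name is hypothetical.

```python
def decode_cudnn_version(version: int) -> tuple[int, int, int]:
    """Split a cuDNN 7.x-style CUDNN_VERSION integer into (major, minor, patch)."""
    major, rest = divmod(version, 1000)
    minor, patch = divmod(rest, 100)
    return major, minor, patch


print(decode_cudnn_version(7005))  # (7, 0, 5)
```
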

ptrendx commented 6 years ago

Just in case, try synthetic data with --benchmark 1; even with 24 threads, I bet you are still limited by IO. Also try adding an index file with the --data-train-idx parameter to speed up IO (although, depending on whether your data is cached in RAM, you may need to run it twice to see the actual speedup, because it introduces global random shuffling).

rahul003 commented 6 years ago

Unfortunately, neither suggestion improved the speed. Using MXNET_CUDNN_AUTOTUNE_DEFAULT=2 helped in some cases, but we can't say it helps consistently. If it picks the fastest algorithm, why wouldn't it help in all cases? I understand that in some cases it should match the speed of the other algorithms, but sometimes it is slower than setting it to 1. All else should remain the same, right?

rahul003 commented 6 years ago

Sorry, I was digressing from the topic of the issue. Regarding the iterator, we need to document that it returns fp32 data regardless of the dtype argument. Keeping this open until we fix or document it.