oneapi-src / oneDNN

oneAPI Deep Neural Network Library (oneDNN)
https://uxlfoundation.org
Apache License 2.0

help wanted for deconv #568

Closed lzhengchun closed 5 years ago

lzhengchun commented 5 years ago

Hi,

I used deconv with a 2x2 kernel and 2x2 stride for a 2x upscaling. I also created a wrapper function to build deconv layers (because my model has several of those).

For whatever reason, I could not get the same results as TensorFlow.

My code, input, and weights for your testing are available here. The sum of the output should be 62098350.0 (the value TF gives me), but my DNNL code gives a quite different value.

I have spent more than a day on this code. Help wanted.

alokbakshi commented 5 years ago

Hi lzhengchun,

If possible, could you please provide the "tomogan.hpp" header file that is included in your code? In particular, I'm having trouble with the "DNNL_ARG_WEIGHTS_CUS" variable while compiling your test file on my system.

Thanks, Alok

lzhengchun commented 5 years ago

@alokbakshi Thanks for addressing my issue.

Sorry for the missing file, I just added the header.

Thanks

alokbakshi commented 5 years ago

Hi lzhengchun, no worries! :-) The program compiles fine now, and below is the output I got:

8388608 bytes of input data have been successfully read
!!!!! src mem needs reorder for deconv
!!!!! weight mem needs reorder for deconv
Input: 1 x 128 x 128 x 128 => Output: 1 x 128 x 128 x 128
Input: 1 x 128 x 128 x 128 => Output: 1 x 128 x 128 x 128
Input: 128 x 128 x 2 x 2 => Output: 128 x 128 x 2 x 2
Input: 1 x 128 x 128 x 128 => Output: 1 x 128 x 256 x 256
Input: 1 x 128 x 256 x 256 => Output: 1 x 128 x 256 x 256
load Weights: 128 x 128 x 2 x 2
65664 weights loaded!
results checksum: 71149860.499734

(Please let me know if I've stated the issue wrongly.) So TensorFlow reports the results checksum as 62098350.0, which is different from the output above.

lzhengchun commented 5 years ago

@alokbakshi I got the same result as you did; you understood correctly.

A checksum is a quick way to check but not sufficient (e.g., in case of a wrong output layout). I also tried loading the results to compare with the TensorFlow output:

np.fromfile("output_img.bin", dtype=np.float32, count=-1, sep='').reshape((128, 128, 128))

It seems that my code gives quite different results from TensorFlow. I also verified the weights that were loaded into DNNL; they seem correct as well.

BTW, my model was trained using TensorFlow.

alokbakshi commented 5 years ago

[Update] Hi lzhengchun, thanks for the detailed info!

I made a couple of (hopefully) innocent changes to your code, namely:

  1. Initialize deconv_src_md, deconv_weights_md, and deconv_dst_md with the memory tags nchw, oihw, and nchw respectively (rather than any as in the original code).
  2. Set deconv_node_idx = 1 rather than 3 (because there won't be any src/wei/dst reorders now, as the reference deconvolution primitive will be executed).
  3. Changed the printf format specifier from %d to %ld.

But I still get the same outcome, namely 71149860.504502. So the jit_avx2 implementation (along with the src/wei/dst reorders) and the reference implementation (which works with plain layouts and no reorders) agree on the checksum value.
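For anyone following along, here is a minimal sketch of what such an explicit-tag deconvolution setup looks like with the DNNL 1.x C++ API; the shapes are the ones from this issue, but buffer contents and variable names are placeholders rather than the attached code:

#include <vector>
#include "dnnl.hpp"
using namespace dnnl;

int main() {
    engine eng(engine::kind::cpu, 0);
    stream s(eng);

    // Shapes from this issue: 1x128x128x128 -> 1x128x256x256, 2x2 kernel, stride 2, no padding
    memory::dims src_dims = {1, 128, 128, 128}, wei_dims = {128, 128, 2, 2};
    memory::dims bia_dims = {128}, dst_dims = {1, 128, 256, 256};
    memory::dims strides = {2, 2}, padding = {0, 0};

    // Explicit plain tags (nchw/oihw) instead of format_tag::any, so no reorders are needed
    auto src_md = memory::desc(src_dims, memory::data_type::f32, memory::format_tag::nchw);
    auto wei_md = memory::desc(wei_dims, memory::data_type::f32, memory::format_tag::oihw);
    auto bia_md = memory::desc(bia_dims, memory::data_type::f32, memory::format_tag::x);
    auto dst_md = memory::desc(dst_dims, memory::data_type::f32, memory::format_tag::nchw);

    // Deconvolution (transposed convolution) primitive, inference only
    auto pd = deconvolution_forward::primitive_desc(
            deconvolution_forward::desc(prop_kind::forward_inference,
                    algorithm::deconvolution_direct, src_md, wei_md, bia_md, dst_md,
                    strides, padding, padding), eng);
    auto deconv = deconvolution_forward(pd);

    // Placeholder buffers; in the real code these would hold the loaded input/weights/bias
    std::vector<float> src_buf(1 * 128 * 128 * 128), wei_buf(128 * 128 * 2 * 2);
    std::vector<float> bia_buf(128), dst_buf(1 * 128 * 256 * 256);
    memory src_mem(src_md, eng, src_buf.data()), wei_mem(wei_md, eng, wei_buf.data());
    memory bia_mem(bia_md, eng, bia_buf.data()), dst_mem(dst_md, eng, dst_buf.data());

    deconv.execute(s, {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_WEIGHTS, wei_mem},
                       {DNNL_ARG_BIAS, bia_mem}, {DNNL_ARG_DST, dst_mem}});
    s.wait();
    return 0;
}

With format_tag::any instead, the library can pick blocked layouts for the jit implementations, which is why the original code needed the src/weights reorders (and deconv_node_idx = 3).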

I guess we need to dig more to unearth the issue.

In the meantime, if possible, could you please provide the TensorFlow output file (for the out_buffer variable) as well? Then I can directly compare the respective out_buffer array values and see when and where they differ from each other.

PS: For reference, the modified test file is attached. The modified cpp file is attached with a log extension because the forum does not allow me to attach it otherwise: cpu_cnn_deconv_test.log

lzhengchun commented 5 years ago

@alokbakshi Thanks very much! Yes, if you explicitly set the tag (instead of using any), the reorder is not needed and deconv_node_idx needs to be 1 rather than 3. It may lose some performance, but I guess that is fine for debugging for now.

I uploaded the Python script for the TensorFlow computation.

alokbakshi commented 5 years ago

@lzhengchun Yes, I was just trying to see if there was any issue with the reorders. Apparently there isn't :-( so we probably need to dig more!!

Thanks a lot for providing the Python script. I will run it and (hopefully) get to the bottom of the issue.

lzhengchun commented 5 years ago

Thanks very much for your help! @alokbakshi FYI, the code I gave you has relu disabled (line 132); when I calculate the checksum (line 198) I ignore all negative values to mimic relu (I used that to see if my relu had an issue).

You can enable relu (line 132) and remove line 198, which should give you the same results. I also updated the Python script a little to load the DNNL output and compare it with the TensorFlow output pixel by pixel; that may be helpful for you.

Thanks

vpirogov commented 5 years ago

I checked that there's no difference in how padding is done between TF and your code, and that the shape in question produces correct results from the benchdnn perspective. With that, there are two possible sources of discrepancy:

  1. TF does different computations for transposed convolution
  2. There's a bug in the app code or validation code

@lzhengchun, could you please collect MKLDNN_VERBOSE for TF to see what exactly TF does?

lzhengchun commented 5 years ago

@vpirogov Here is the verbose output:

mkldnn_verbose,exec,reorder,jit:uni,undef,in:f32_hwio out:f32_oihw,num:1,128x128x2x2,0.167969
mkldnn_verbose,exec,reorder,jit:uni,undef,in:f32_nhwc out:f32_nchw,num:1,1x128x128x128,8.30908
mkldnn_verbose,exec,convolution,gemm:blas,backward_data,fsrc:nchw fwei:oihw fbia:undef fdst:nchw,alg:convolution_direct,mb1_g1ic128oc128_ih256oh128kh2sh2dh0ph0_iw256ow128kw2sw2dw0pw0,53.998
mkldnn_verbose,exec,reorder,jit:uni,undef,in:f32_nchw out:f32_nhwc,num:1,1x128x256x256,25.3718
mkldnn_verbose,exec,eltwise,jit:avx2,forward_training,fdata:blocked fdiff:undef,alg:eltwise_relu,mb1ic256ih256iw128,2.28296

You can also try it using the Python script.

vpirogov commented 5 years ago

Thanks. Looks like the shape is exactly the same. Let's see what @alokbakshi digs out.

alokbakshi commented 5 years ago

@lzhengchun and @vpirogov

Sorry for the delayed response!

So far, I have verified that the bias vector applied to the dnnl and tf output tensors is the same.

On the other hand, I manually printed the tf_output[0,0,:] and dnnl_output[0,0,:] tensors (since both tensors are in nhwc format, the values obtained are 2D tensors with dimensions corresponding to output_width and output_channel). I observed (reference: the Python script file attached at the end) that tf_output[ow = 0, oc = 0] = 0 exactly, while dnnl_output[ow = 0, oc = 0] != 0.

This seems strange, especially because the bias vector has all non-zero components (I guess there is some issue with the tf code). Anyway, I'm still investigating and will let you know as soon as I find the cause.

tf-deconv-cmp-dnnl python script file

lzhengchun commented 5 years ago

Thanks very much @alokbakshi for your help! The weights of that deconv layer are extracted from a deep learning model with 16 conv2d layers.

The weights I provided are from the 13th layer (index 12), which is the first deconv layer. I have compared the output of the 12th layer with my DNNL-based implementation and they match; the problem comes from the first deconv layer.

In case it helps, I also uploaded the entire Keras model as well as the model input that generates the intermediate output I provided as the input (to the 13th layer).

Thanks again! Let me know if you need anything more.

lzhengchun commented 5 years ago

@alokbakshi @vpirogov I dug more into TensorFlow's conv2d_transpose. The documentation says: "This operation is sometimes called 'deconvolution' after Deconvolutional Networks, but is actually the transpose (gradient) of conv2d rather than an actual deconvolution."

Also, when I looked at the MKLDNN_VERBOSE output (pasted above), it shows that the primitive used was convolution (not deconv). I guess that makes the difference?

As I said in the other issue, I basically need upsampling, but DNNL does not have such a primitive; that's why I turned to deconv.

So, could you provide an example to achieve 2x upsampling (either nearest or bilinear) like TensorFlow's?

Or, an example of how to implement TF's conv2d_transpose?

I only need them for inference.

Thanks again!

alokbakshi commented 5 years ago

[Update] Hi lzhengchun, for the model you have, we have the following parameter values: (KH = KW = 2, SH = SW = 2, N = 1, IH = IW = 128, PH = PW = 0, OH = OW = 256, OC = IC = 128)

With these values (and the fact that the stride is two), one gets the following equation:

output[n = 0, oc = 0, oh = 0, ow = 0] = bias[oc = 0] + \sum_{ic=0}^{127} input[n = 0, ic, ih = 0, iw = 0] * kernel[oc = 0, ic, kh = 0, kw = 0]
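For reference, the same single-output-point calculation written out as a small C++ helper (buffer names are hypothetical; plain NCHW input and OIHW weights are assumed):

// Hand computation of output[n=0, oc=0, oh=0, ow=0] for the 2x2 kernel, stride-2, no-padding case:
// only input[:, :, 0, 0] and kernel[:, :, 0, 0] contribute to that output point.
float manual_output_at_origin(const float *input, const float *kernel, const float *bias) {
    const int IC = 128, IH = 128, IW = 128, KH = 2, KW = 2;
    float acc = bias[0];
    for (int ic = 0; ic < IC; ++ic) {
        float in_val = input[(ic * IH + 0) * IW + 0];   // input[n=0, ic, ih=0, iw=0] in NCHW
        float w_val  = kernel[(ic * KH + 0) * KW + 0];  // kernel[oc=0, ic, kh=0, kw=0] in OIHW
        acc += in_val * w_val;
    }
    return acc;  // should match dnnl_output[n=0, oc=0, oh=0, ow=0]
}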

I see that the dnnl_output and the manual calculation above give the same result (using either of the formats, as done in lines 36-38 of the attached file).

But for some reason the TensorFlow output is different (it is negative here!). So currently I am looking at the TF documentation to see if some parameter values are off or missing.

Please let me know if you see anything wrong with the calculation! For reference I have attached the modified Python file. I have removed the relu layer from the model, done the calculation as written above, and added/removed a few print statements.

Python file, attached with a log extension

lzhengchun commented 5 years ago

@alokbakshi Thanks for the update.

Did you see my previous post?

I believe both my code and TF are correct; the difference is because they actually do different things. TF simply computes the gradient with respect to the input and names it conv2d_transpose, while DNNL does an actual deconv.

Could you provide the examples/instructions I mentioned above?

Thanks very much !

alokbakshi commented 5 years ago

Hi @lzhengchun,

Yes, I am just looking at it now. Will get back to you soon :-)

alokbakshi commented 5 years ago

Hi lzhengchun,

The way I understand it, deconvolution from a neural network perspective is not the inverse of the convolution operator (as described here) but rather a transposed convolution (as replied in the comments here; in particular, please see the nice pictorial replies written by David and Andrei).

In DNNL too, the deconvolution primitive is implemented as a backward convolution with respect to the data. So I guess your verbose output

mkldnn_verbose,exec,convolution,gemm:blas,backward_data,fsrc:nchw fwei:oihw fbia:undef fdst:nchw,alg:convolution_direct,mb1_g1ic128oc128_ih256oh128kh2sh2dh0ph0_iw256ow128kw2sw2dw0pw0,53.998

is expected.

alokbakshi commented 5 years ago

Hi @lzhengchun,

Please let me know at what rate you want to do the upsampling. As far as I understand, if the rate is an integer then we can use backward_data convolution as a proxy for upsampling with the same integral stride size!

Edit: Sorry, I just missed this part in your post. It is 2x sampling. I guess then the use of a transposed convolution (with stride = 2) makes perfect sense. I am currently looking at the TensorFlow documentation to see the reason behind the difference in DNNL and TF outputs.
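In the meantime, a rough sketch of that backward_data-as-upsampling idea with the DNNL 1.x C++ API; the shapes are the ones from this issue, and everything else (buffers, names) is a placeholder rather than tested code:

#include <vector>
#include "dnnl.hpp"
using namespace dnnl;

int main() {
    engine eng(engine::kind::cpu, 0);
    stream s(eng);

    // The 128x128 feature map acts as diff_dst, the 256x256 result as diff_src (no bias on this path)
    memory::dims big_dims = {1, 128, 256, 256}, small_dims = {1, 128, 128, 128};
    memory::dims wei_dims = {128, 128, 2, 2};   // OIHW, forward-convolution convention
    memory::dims strides = {2, 2}, padding = {0, 0};

    auto big_md   = memory::desc(big_dims,   memory::data_type::f32, memory::format_tag::nchw);
    auto small_md = memory::desc(small_dims, memory::data_type::f32, memory::format_tag::nchw);
    auto wei_md   = memory::desc(wei_dims,   memory::data_type::f32, memory::format_tag::oihw);

    // Backward-data primitive descriptors require a forward convolution primitive_desc as a hint
    auto fwd_pd = convolution_forward::primitive_desc(
            convolution_forward::desc(prop_kind::forward_training,
                    algorithm::convolution_direct, big_md, wei_md, small_md,
                    strides, padding, padding), eng);
    auto bwd_pd = convolution_backward_data::primitive_desc(
            convolution_backward_data::desc(algorithm::convolution_direct,
                    big_md, wei_md, small_md, strides, padding, padding),
            eng, fwd_pd);
    auto upsample = convolution_backward_data(bwd_pd);

    std::vector<float> small_buf(1 * 128 * 128 * 128), wei_buf(128 * 128 * 2 * 2);
    std::vector<float> big_buf(1 * 128 * 256 * 256);
    memory small_mem(small_md, eng, small_buf.data()), wei_mem(wei_md, eng, wei_buf.data());
    memory big_mem(big_md, eng, big_buf.data());

    upsample.execute(s, {{DNNL_ARG_DIFF_DST, small_mem},
                         {DNNL_ARG_WEIGHTS, wei_mem},
                         {DNNL_ARG_DIFF_SRC, big_mem}});
    s.wait();
    return 0;
}

This is essentially what the deconvolution primitive does internally, minus the bias, which the backward_data path does not support.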

lzhengchun commented 5 years ago

Yes, I am also trying to use convolution_backward_data but haven't succeeded yet. It would be really helpful if you could provide an example; I guess it would be useful to the community as well.

lzhengchun commented 5 years ago

@alokbakshi Based on the latest reply by @vpirogov in issue #193, it seems that DNNL's deconv actually is a transposed conv, so it should give the same result as TF's.

lzhengchun commented 5 years ago

Hello @alokbakshi and @vpirogov, problem solved.

TF implemented their own conv2d_transpose using convolution_backward_data, which is the same as DNNL's deconv. But TF flipped I and O in the weight layout, i.e., TF's deconv weights are HWOI, whereas conv weights are HWIO if we take I to be the input channel and O the output channel. So, for TF's convolution_backward_data, I and O are flipped (because of the backward pass). Since my I and O are the same size, it was hard to notice.
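In code, the fix amounts to repacking the TF filter before handing it to DNNL. A minimal sketch (the function and buffer names are made up; the dims are the ones from this layer):

#include <vector>

// Repack a TensorFlow conv2d_transpose filter into DNNL deconvolution weights.
// TF filter buffer:   [KH][KW][OC][IC]  (HWOI -- O before I, unlike a regular conv's HWIO)
// DNNL plain weights: [OC][IC][KH][KW]  (oihw)
std::vector<float> hwoi_to_oihw(const float *tf_wei, int OC, int IC, int KH, int KW) {
    std::vector<float> dnnl_wei((size_t)OC * IC * KH * KW);
    for (int kh = 0; kh < KH; ++kh)
        for (int kw = 0; kw < KW; ++kw)
            for (int oc = 0; oc < OC; ++oc)
                for (int ic = 0; ic < IC; ++ic)
                    dnnl_wei[((oc * IC + ic) * KH + kh) * KW + kw] =
                            tf_wei[((kh * KW + kw) * OC + oc) * IC + ic];
    return dnnl_wei;
}

With OC == IC == 128 the shapes are identical either way, which is why the mismatch was so easy to miss.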

Thanks very much for your effort! I still really hope DNNL can provide integer-scale upsampling like TF's; that saves computation in some cases, and in mine it would save about 20%.

vpirogov commented 5 years ago

@lzhengchun, I'm glad it's resolved. Upsampling definitely makes sense and we will consider introducing a primitive with this functionality.