microsoft / CNTK

Microsoft Cognitive Toolkit (CNTK), an open source deep-learning toolkit
https://docs.microsoft.com/cognitive-toolkit/
Other

Small discrepancy in batch evaluating multiple vs single inputs #2962

Closed: JanKrivanek closed this issue 6 years ago

JanKrivanek commented 6 years ago

Can you point us to the native implementation of the entry point for CSharp_CNTK_FunctionEvaluateSWIG_0? Is it expected that the individual outputs of this function differ based on the number of inputs (the column dimension of the matrix) passed in? We observe small discrepancies (e.g. in the 4th or 5th decimal place).

More background:

We have found a discrepancy in the evaluation of CNTK models in C#. Concretely, we have a multi-layer perceptron with the following parameters:

Hidden layer weights:

-0.9684187 0.2783848 -0.01441747 -0.9967358 -0.4482535 0.0593956 0.09004943 -0.4658263

Hidden layer biases:

-1.018078 0.2626715 0.8714445 0.03156471

Output layer weights:

-0.5403279 -0.405864 -0.7699566 -0.3085891 -0.2580264 0.5474245 0.8033602 0.09686555 -0.4840894 0.3834926 -0.6715707 -0.565696

Output layer biases:

-0.2563959 -1.19075 0.4835828

The hidden activation function is CNTKLib.ReLU, the output is linear (that is, no non-linearity applied).
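As a cross-check of the parameters above, here is a minimal pure-Python sketch that evaluates the 2-4-3 network in double precision. It assumes the row-major (2, 4) and (4, 3) weight layout used in the Python repro later in this thread; the C# model may store its parameters differently. Under this assumed layout, the value close to the reported 0.3427377 comes out as the third output rather than the first.

```python
# Double-precision reference evaluation of the 2-4-3 MLP described above.
# Assumption: weights are laid out row-major as (inputs, outputs), matching the
# reshape((2, 4)) / reshape((4, 3)) calls in the Python repro; the C# model may differ.

HIDDEN_W = [
    [-0.9684187, 0.2783848, -0.01441747, -0.9967358],
    [-0.4482535, 0.0593956, 0.09004943, -0.4658263],
]
HIDDEN_B = [-1.018078, 0.2626715, 0.8714445, 0.03156471]
OUTPUT_W = [
    [-0.5403279, -0.405864, -0.7699566],
    [-0.3085891, -0.2580264, 0.5474245],
    [0.8033602, 0.09686555, -0.4840894],
    [0.3834926, -0.6715707, -0.565696],
]
OUTPUT_B = [-0.2563959, -1.19075, 0.4835828]


def dense(inputs, weights, biases):
    """Compute inputs @ weights + biases for a single input vector."""
    return [
        sum(inputs[i] * weights[i][j] for i in range(len(inputs))) + biases[j]
        for j in range(len(biases))
    ]


def forward(inputs):
    """ReLU hidden layer followed by a linear output layer."""
    hidden = [max(0.0, h) for h in dense(inputs, HIDDEN_W, HIDDEN_B)]
    return dense(hidden, OUTPUT_W, OUTPUT_B)


for vector in ([1.0, 2.0], [2.0, 3.0]):
    print(vector, forward(vector))
```

Because this runs entirely in double precision, it gives a reference value to compare both float32 results against, independent of batch size.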

The problem:

We evaluate the network in two ways:

  1. Given two input vectors at once, (1, 2) and (2, 3), in one minibatch. We prepare the input for CNTK using

Value.CreateBatch(new[] {InputDim}, inputVectors, _device)

The output node 0 is 0.342737764f

  2. Given just one input vector, (1, 2). We prepare the input for CNTK using

Value.CreateBatch(new[] { InputDim }, inputVector, _device)

The output node 0 is 0.342737794f

These two values are not equal.

The problem grows with the size of the network. It seems more likely to reproduce when use case 1 has fewer input rows. It also manifests for other combinations of unequal numbers of inputs (not just the single-input case shown in use case 2).

We don't know how to create deterministic repro code of the problem since we have not found out how to inject exact numbers into trainable parameters in C#. Otherwise, we would have sent a complete reproduction of the error.

ke1337 commented 6 years ago

I tried the repro on the CPU device and the numbers match. However, they differ slightly on the GPU. Here's the code:

```python
import cntk as C
import numpy as np

C.try_set_default_device(C.cpu())  # comment out this line to use the GPU if one is available

x = C.input_variable((2,))
h_W = C.constant(np.reshape(np.asarray([-0.9684187, 0.2783848, -0.01441747, -0.9967358, -0.4482535, 0.0593956, 0.09004943, -0.4658263]).astype(np.float32), (2,4)))
h_B = C.constant(np.asarray([-1.018078, 0.2626715, 0.8714445, 0.03156471]).astype(np.float32))

y = C.relu(x @ h_W + h_B)
o_W = C.constant(np.reshape(np.asarray([-0.5403279, -0.405864, -0.7699566, -0.3085891, -0.2580264, 0.5474245, 0.8033602, 0.09686555, -0.4840894, 0.3834926, -0.6715707, -0.565696]).astype(np.float32), (4,3)))
o_B = C.constant(np.asarray([-0.2563959, -1.19075, 0.4835828]).astype(np.float32))
out = y @ o_W + o_B

print(out.eval([[1,2]]))
print(out.eval([[1,2],[2,3]]))
```

I ran the numbers through an online float-to-hex converter and found that 0.342737764 corresponds to hex 0x3eaf7b53, while 0.342737794 is hex 0x3eaf7b54. The difference is in the least significant bit of the mantissa. @FDecaYed, can you comment on the GPU precision for this issue?
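The hex comparison can also be done locally. A minimal sketch using only Python's standard struct module: it rounds a value to IEEE 754 single precision, shows its bit pattern, and measures the distance between two floats in units in the last place (ULPs). The helper names are illustrative.

```python
import struct


def float32_bits(value: float) -> int:
    """Round value to float32 and return its raw 32-bit IEEE 754 pattern."""
    return struct.unpack("<I", struct.pack("<f", value))[0]


def ordered_key(bits: int) -> int:
    """Map a float32 bit pattern to an integer that increases with the float value."""
    # Positive floats already sort correctly; negative floats are sign-magnitude,
    # so map them below zero (both +0.0 and -0.0 map to 0).
    return bits if bits < 0x80000000 else 0x80000000 - bits


def ulp_distance(a: float, b: float) -> int:
    """Distance between a and b in float32 ULP steps (after rounding both to float32)."""
    return abs(ordered_key(float32_bits(a)) - ordered_key(float32_bits(b)))


batched, single = 0.342737764, 0.342737794
print(hex(float32_bits(batched)))     # 0x3eaf7b53
print(hex(float32_bits(single)))      # 0x3eaf7b54
print(ulp_distance(batched, single))  # 1
```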

JanKrivanek commented 6 years ago

@KeDengMS Thanks for looking into this.

This was reported by a colleague who uses GPU training and evaluation; he also provided the sample net that is supposed to reproduce the behavior, so it might be a special case that is observable only on the GPU.

However, I'm certain that the same phenomenon is also observable on the CPU device, since I was able to reproduce it on my dev machine, which has no GPU.

ke1337 commented 6 years ago

Please construct a repro on CPU like the code above. Even better, use C.debugging.set_computation_network_trace_level(1000000) to dump each node's value and pin down the operator that causes the difference.

JanKrivanek commented 6 years ago

@KeDengMS It's possible that GPU evaluation is more prone to this issue. I also wasn't able to reproduce it with this small net (2-4-3) on CPU, but I could by adding hidden layers or widening them; for example, a 2-28-3 network reproduces it quite consistently.

I created a small repro in C# (it builds a sample network, evaluates it, and throws if the results differ): EvalDiscrepancyRepro.cs.txt

I'll try to set up a Python environment and create a repro similar to yours once I have a little more time. But I guess all that's needed is to increase the number of connections in the middle layer (e.g. 2-28-3, as mentioned), while using arbitrary weights. The CPU might need more calculations for the imprecisions to add up. The question is why the floating-point imprecisions show up differently for different numbers of inputs.

Thanks Jan

FDecaYed commented 6 years ago

@KeDengMS There are several possible causes for this difference, including but not limited to differences in internal register width between x86 and the GPU (which can cause different rounding), the order of accumulation, and compiler optimizations. In my opinion, a difference in the last bit is probably not a real bug. But there is no simple formula for deciding whether there is a real problem, and figuring out for sure what's going on is possible but not trivial. The real question here is what accuracy the user @jakrivan really needs, and how we can achieve that. Also, there seems to be more to this case than a CPU/GPU difference.
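To illustrate the "order of accumulation" point, here is a small sketch that emulates float32 arithmetic in Python by rounding to single precision after every operation. This is an assumption standing in for real hardware, which may also use wider registers or fused operations. Adding the same three numbers in two different orders gives two different float32 results.

```python
import struct


def f32(value: float) -> float:
    """Round a Python float (a double) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", value))[0]


def sum_f32(terms):
    """Add terms left to right, rounding the running total to float32 each step."""
    total = 0.0
    for term in terms:
        total = f32(total + f32(term))
    return total


terms = [2.0**24, 1.0, 1.0]
print(sum_f32(terms))                  # 16777216.0: each +1 is lost to rounding
print(sum_f32(list(reversed(terms))))  # 16777218.0: 1 + 1 is formed first
print(sum(terms))                      # 16777218.0 in double precision
```

The magnitudes are deliberately extreme so the effect shows up in a single step; with ordinary activations the per-step difference is usually at the level of the last bit, as in the example earlier in this thread.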

JanKrivanek commented 6 years ago

@FDecaYed In an ideal world we would get identical results when performing batched evaluation (which is faster and so is used in the learning environment) and when evaluating individual vectors one by one (which is what happens in production, as we evaluate data as it comes in), so that we have 100% reproducibility between learning and production usage.

By the way, this is not about GPU vs CPU evaluation; it is a difference in the evaluation of an identical vector in the same setup (just with a different batch size).

I understand that floating-point calculations are not exact, but I still don't understand why the evaluation result for the identical vector V1 depends on the number of other vectors in the same evaluation batch. Shouldn't those calculations be completely independent?

FDecaYed commented 6 years ago

Here is one thing that differs between batched and non-batched evaluation. With one vector, the first layer is essentially a vector-matrix multiply (gemv): (2) x (2,4) = (4). With a batch of vectors, the first layer is a matrix-matrix multiply (gemm): (2,2) x (2,4) = (2,4). These go through different code.
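The gemv/gemm point can be modeled with a toy linear layer. This is not CNTK's or any BLAS library's real code: the rule "sequential dot product for one column, blocked accumulation for several columns" and the block size of 2 are hypothetical, chosen only to show how a shape-dependent code path can change the result for an identical input vector. float32 is emulated by rounding after every operation.

```python
import struct


def f32(value: float) -> float:
    """Round a Python float (a double) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", value))[0]


def dot_sequential(weights, column):
    """gemv-like path (hypothetical): one running float32 sum."""
    total = 0.0
    for w, x in zip(weights, column):
        total = f32(total + f32(f32(w) * f32(x)))
    return total


def dot_blocked(weights, column, block=2):
    """gemm-like path (hypothetical): float32 partial sums per block, then combined."""
    total = 0.0
    for start in range(0, len(weights), block):
        partial = dot_sequential(weights[start:start + block], column[start:start + block])
        total = f32(total + partial)
    return total


def toy_layer(weights, columns):
    """Pick a code path from the batch size, as an optimized library might."""
    if len(columns) == 1:
        return [dot_sequential(weights, columns[0])]
    return [dot_blocked(weights, c) for c in columns]


weights = [2.0**24, 1.0, 1.0, 1.0]
v1 = [1.0, 1.0, 1.0, 1.0]
v2 = [0.5, 0.5, 0.5, 0.5]
print(toy_layer(weights, [v1])[0])      # 16777216.0 (v1 evaluated alone)
print(toy_layer(weights, [v1, v2])[0])  # 16777218.0 (v1 evaluated in a batch)
```

The same vector v1 gets a different result purely because the batch size changed which accumulation order was used.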

JanKrivanek commented 6 years ago

From my very original question:

The problem manifests also for other combinations of unequal number of inputs (not just for single input as shown in use case 2)

Is the underlying gemm implementation really expected to produce different results for the same sub-matrix?

ke1337 commented 6 years ago

Mathematically, identities like (a + b) * c == a * c + b * c hold, but many details of hardware execution can cause the rounding of the LSB to differ. Even for gemm, the implementation may differ depending on the size of the input. Intel used to have 80-bit registers to fight this problem, but they come at a cost in computational complexity. In your case, if precision is more important than performance, you could evaluate your network in double precision and round the result to single precision.
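A concrete instance of (a + b) * c differing from a * c + b * c in float32, with float32 emulated by rounding to single precision after each operation via Python's struct module. It also shows the double-precision suggestion in miniature: computing in double and rounding once at the end makes the two forms agree for these particular values, though that is not guaranteed in general.

```python
import struct


def f32(value: float) -> float:
    """Round a Python float (a double) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", value))[0]


a, b, c = 1.0, 2.0**-24, 3.0

# float32 evaluation, rounding after every operation.
factored32 = f32(f32(a + b) * c)           # a + b is a tie and rounds to even: 1.0
expanded32 = f32(f32(a * c) + f32(b * c))  # 3 + 3*2**-24 rounds up to 3 + 2**-22
print(factored32, expanded32, factored32 == expanded32)

# Double-precision evaluation, rounded to float32 only once at the end.
factored64 = f32((a + b) * c)
expanded64 = f32(a * c + b * c)
print(factored64, expanded64, factored64 == expanded64)
```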

JanKrivanek commented 6 years ago

It's rather that reproducibility is more important than precision. We post-process the output activations anyway, so we could first cut off some of the less significant bits (rather than let them change the decision in some rare cases). This would compensate for the imprecision (for reproducibility between production and the learning environment), while it should not affect the capabilities of the neural net.

Is there any expectation of how closely the results should agree (i.e. how many bits of the 23-bit mantissa of a 32-bit float are expected to match)? We could determine this by trial and error (and then add some margin), but if there is existing knowledge on this topic, we'd like to follow it.

Thanks for all the time and detailed information so far! Jan
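A minimal sketch of the proposed "cut off the less significant bits" post-processing, assuming truncation means zeroing the lowest mantissa bits of the float32 pattern (the helper names are illustrative). It also shows the catch: two floats one ULP apart can straddle a truncation boundary, so such a pair still differs after truncation regardless of how many bits are dropped.

```python
import struct


def to_bits(value: float) -> int:
    """Raw float32 bit pattern of value (after rounding to float32)."""
    return struct.unpack("<I", struct.pack("<f", value))[0]


def from_bits(bits: int) -> float:
    """float32 value whose bit pattern is bits."""
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def truncate_mantissa(value: float, drop_bits: int) -> float:
    """Zero the lowest drop_bits mantissa bits of value's float32 representation."""
    if not 0 <= drop_bits <= 23:
        raise ValueError("drop_bits must be between 0 and 23")
    mask = 0xFFFFFFFF & ~((1 << drop_bits) - 1)
    return from_bits(to_bits(value) & mask)


# The pair from this thread (0x3eaf7b53 vs 0x3eaf7b54) agrees after dropping 3 bits...
a, b = from_bits(0x3EAF7B53), from_bits(0x3EAF7B54)
print(truncate_mantissa(a, 3) == truncate_mantissa(b, 3))  # True

# ...but another adjacent pair crosses the boundary and still differs.
c, d = from_bits(0x3EAF7B57), from_bits(0x3EAF7B58)
print(truncate_mantissa(c, 3) == truncate_mantissa(d, 3))  # False
```

Rounding to fewer bits instead of truncating has the same problem, just with the boundary moved to the midpoints.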

ke1337 commented 6 years ago

IMO the principles are: 1) For the same input data with no randomization, all outputs should be identical, including the LSB. 2) For quasi-equivalent data (as in your case), outputs are expected to be identical except for the LSB. 3) With randomized initialization, the trained model might be different, but the best results should be close to those reported in the paper (<0.5% in test error).

JanKrivanek commented 6 years ago

@KeDengMS Unfortunately, it seems that no amount of LSB truncation would mask the difference.

I wrote a small piece of code that compares the results of evaluating with different batch sizes and counts the number of identical bits starting from the MSB, and I encountered cases with as few as 3 common MSBs (the sign bit and 2 bits of the exponent; the rest of the exponent and the mantissa differed). As expected, these were cases of very small numbers, where the absolute difference was very small.

Example comparisons (activations for the same input vectors, evaluated with different batch sizes):

| Common bits | f1 | f2 | Absolute difference | Relative difference |
|---|---|---|---|---|
| 32 | 0.9999887 | 0.9999887 | 0 | 0 |
| 25 | 1.389807E-13 | 1.389811E-13 | 4.06575814682064E-19 | 2.92540851867259E-06 |
| 28 | 1.133249E-05 | 1.133249E-05 | 2.72848410531878E-12 | 2.40766446274412E-07 |
| 9 | 0.9999998 | 1 | 2.38418579101563E-07 | 2.38418607523275E-07 |
| 3 | 2.453833E-19 | 4.062482E-25 | 2.45382896069308E-19 | 1.99999325806454 |
| 6 | 2.290642E-07 | 4.819688E-09 | 2.24244473656654E-07 | 1.91757134718733 |
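A minimal sketch of a bit-comparison helper like the one described above: it counts equal leading bits (sign, then exponent, then mantissa) of two float32 values. It is an illustrative reconstruction, not the code that produced the numbers above, so its counts need not match those figures exactly.

```python
import struct


def float32_bits(value: float) -> int:
    """Round value to float32 and return its raw 32-bit IEEE 754 pattern."""
    return struct.unpack("<I", struct.pack("<f", value))[0]


def common_leading_bits(a: float, b: float) -> int:
    """Count identical bits from the MSB (sign bit) down, for float32 a and b."""
    diff = float32_bits(a) ^ float32_bits(b)
    return 32 - diff.bit_length()  # 32 when the bit patterns are identical


print(common_leading_bits(0.9999887, 0.9999887))  # 32
print(common_leading_bits(1.0, 1.0 - 2.0**-22))   # 8: the exponent field already differs
```

As the second call shows, two values that are numerically very close can share only a few leading bits when they sit on opposite sides of a power of two.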

Conclusion: there doesn't seem to be a way to make the results identical by adjusting the resolution of the floating-point numbers (the same phenomenon would cause differences even when evaluating in double precision and then casting to float). For now, we allow a small absolute difference in our comparison tests. To prevent cases where a small difference changes the meaning of a result, a more sophisticated output interpreter should be used; for example, a simple argmax would be susceptible to this problem.
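One way to sketch both measures from this conclusion: compare outputs with a combined absolute/relative tolerance (Python's math.isclose), and flag argmax decisions whose top-two margin falls within the tolerance instead of trusting them blindly. The function names and default tolerances are assumptions for illustration, not values derived from this thread.

```python
import math


def outputs_match(a, b, rel_tol=1e-5, abs_tol=1e-6):
    """True if two output vectors agree element-wise within the given tolerances."""
    return len(a) == len(b) and all(
        math.isclose(x, y, rel_tol=rel_tol, abs_tol=abs_tol) for x, y in zip(a, b)
    )


def argmax_with_margin(scores, margin):
    """Return (index of best score, whether it beats the runner-up by more than margin)."""
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    best = ranked[0]
    if len(ranked) == 1:
        return best, True
    return best, scores[best] - scores[ranked[1]] > margin


batched = [0.342737764, -1.2, 0.1]
single = [0.342737794, -1.2, 0.1]
print(outputs_match(batched, single))              # True
print(argmax_with_margin(batched, 1e-5))           # (0, True)
print(argmax_with_margin([0.5, 0.5000001], 1e-5))  # (1, False): too close to call
```

Decisions flagged as not confident can then be routed to a tie-breaking rule or a fallback, so a last-bit difference between batched and single evaluation cannot silently flip the outcome.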