official-stockfish / nnue-pytorch

Stockfish NNUE (Chess evaluation) trainer in Pytorch
GNU General Public License v3.0

Idea for reducing quantization error #47

Closed: ddobbelaere closed this 3 years ago

ddobbelaere commented 3 years ago

This evening I had some fun improving the visualizer. I'll try to create a PR if I find the time for it.

One of the things I investigated was the weights of the fully connected layers after the input layer. Also, as all hidden neurons are "arbitrarily ordered" (a random permutation leads to an equivalent net), I ordered the input weights by their L1-norm (sum of absolute values).

Here are the plots for master net nn-62ef826d1a6d.nnue (an offspring from sergioviero's run) and nn-0f63c1539914.nnue (latest vdv net on fishtest).

[Figures: master, master_fc, master_hist, vdv, vdv_fc, vdv_hist]

Here's the main "observation": the (quantized) FC1 weights are sharply peaked around zero (even more so for the vdv net than for master). This is not the case for the FC2 weights (not shown here, but already visually apparent from the other figure). I suspect that the quantization error leads to some Elo loss (how much is the question :)).

Would it be a good idea to rescale the input weights so that the FC1 weights become larger in magnitude and their quantization error decreases? Denote the output of neuron $j$ at layer $i$ as $x^{(i)}_j$, its bias term $b_j$ and its input weights $a_k$. With a non-linear activation function $\sigma$, this leads to the familiar equation:

$$x^{(i)}_j = \sigma\left(b_j + \sum_k a_k \, x^{(i-1)}_k\right)$$

If $\sigma$ is a clamped ReLU and $q_j > 0$, it holds that

$$q_j \, x^{(i)}_j = \sigma\left(q_j b_j + \sum_k q_j a_k \, x^{(i-1)}_k\right)$$

if and only if $\sigma$ "is not clamping" (to the maximum saturation value).

Maybe it's an idea to multiply the input weights and bias of neuron $j$ by a factor $q_j < 1$ and at the same time divide the FC1 weights connected to this neuron by $q_j$. The resulting unquantized net will be equivalent if no neuron clamping/saturation occurs, but the quantized net might perform better.

Some clarifications:

So, in summary, I argue for a dynamic rescaling of both the input and FC1 weights such that the relative quantization error of the latter weights dramatically decreases (actually, the total relative quantization errors, also taking into account input weights, might be even better). This rescaling is "exact" (leads to an equivalent unquantized net) if no input layer neuron clamping occurs during play. It is expected that quantized net performance will be (much?) better in that case.
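A minimal numpy sketch of the proposed rescaling on a toy layer. The sizes, value ranges and the clamp ceiling of 1.0 are my assumptions, not Stockfish's actual layout. It checks the two claims above: without clamping, moving a factor q_j from the FC1 weights into the input weights and bias leaves the unquantized output unchanged, while the enlarged FC1 weights lose less to rounding on a 1/64 grid (the grid used in the visualize.py diff later in this thread).

```python
import numpy as np


def clipped_relu(z, ceiling=1.0):
    """Clamped ReLU: min(max(z, 0), ceiling). The ceiling of 1.0 is an assumed float-domain value."""
    return np.clip(z, 0.0, ceiling)


def forward(w_in, b_in, w1, x):
    """Input layer with clamped ReLU, followed by the linear FC1 layer."""
    return w1 @ clipped_relu(w_in @ x + b_in)


def rescale(w_in, b_in, w1, q):
    """Multiply neuron j's input weights and bias by q[j]; divide FC1 column j by q[j]."""
    return w_in * q[:, None], b_in * q, w1 / q[None, :]


def rel_quant_error(w, scale=64):
    """Squared rounding error on a 1/scale grid, relative to the squared weights."""
    wq = np.round(scale * w) / scale
    return np.sum((w - wq) ** 2) / np.sum(w ** 2)


rng = np.random.default_rng(0)
K, M, N = 8, 16, 32                          # inputs, hidden neurons, FC1 outputs (toy sizes)
w_in = rng.uniform(0.0, 0.05, size=(M, K))   # pre-activations stay <= 0.5, so no clamping
b_in = rng.uniform(0.0, 0.1, size=M)
w1 = rng.normal(0.0, 0.01, size=(N, M))      # FC1 weights sharply peaked around zero
x = rng.uniform(0.0, 1.0, size=K)

q = np.full(M, 0.25)                         # q_j < 1 makes the FC1 weights 4x larger
w_in_s, b_in_s, w1_s = rescale(w_in, b_in, w1, q)

print(np.allclose(forward(w_in, b_in, w1, x), forward(w_in_s, b_in_s, w1_s, x)))  # True
print(rel_quant_error(w1), rel_quant_error(w1_s))  # the second value is much smaller
```

Choosing q_j large enough that a pre-activation exceeds the ceiling breaks the equivalence, which is exactly the clamping caveat in the post.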

vondele commented 3 years ago

Concerning the quantization error: experiments that decrease the error are usually not so easy, but experiments that increase it should be rather easy (introduce more error on serialization). We could look at the Elo impact of that; if it is small, quantization error is probably not so important.

I'm currently thinking that the clipping could actually be worse, but that's just gut feeling, no data.

ddobbelaere commented 3 years ago

Good points. One argument against it, though, is that some FC1 weights (those connected to a single input neuron, i.e. a pair of columns in the graph) might already be so drowned in noise that increasing the quantization noise leaves performance more or less unchanged, whereas rescaling them would "reveal the signal".

But then again, this claim can be easily verified (how high is the SNR in dB for each connected neuron?) during serialization.
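A sketch of that check, assuming the float L1 weights are laid out with one column per input neuron and that export rounds them to a 1/64 grid (the grid the visualize.py diff later in this thread uses); the function name is hypothetical. The SNR here is simply the inverse of the relative quantization error, expressed in dB.

```python
import numpy as np


def snr_db_per_neuron(l1_weights, scale=64):
    """SNR in dB of the L1 weights after rounding to a 1/scale grid, one value per column.

    A column whose weights all sit exactly on the grid gives +inf; an all-zero column gives nan.
    """
    w = np.asarray(l1_weights, dtype=np.float64)
    noise = w - np.round(scale * w) / scale
    with np.errstate(divide="ignore", invalid="ignore"):
        return 10 * np.log10(np.sum(w ** 2, axis=0) / np.sum(noise ** 2, axis=0))


rng = np.random.default_rng(1)
# Column 0: weights within a few bins of zero; column 1: weights spread over many bins
w = np.stack([rng.normal(0.0, 0.005, 32), rng.normal(0.0, 0.1, 32)], axis=1)
print(snr_db_per_neuron(w))  # column 0 has a far lower SNR than column 1
```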

ddobbelaere commented 3 years ago

@vondele Could you maybe share one (or a few?) checkpoints of converged nets to do some investigation on?

I would like to verify some of the claims I've made above (w.r.t. relative quantization error).

vondele commented 3 years ago

@ddobbelaere I've uploaded three of my best nets & checkpoints so far: https://drive.google.com/file/d/1WmdW5ubcJ5j5TbfUhNgDktpnwz3WArm2/view?usp=sharing. Their Elo performance is also documented here: https://docs.google.com/document/d/1UJe9dT8YAz-Z5sGWD2IwFZHD1F0EL6zjNoeS0gaYpBE/edit#bookmark=id.ias9exeptn5d

Let me know if I can help with running any experiments (e.g. a different export of the net, ideally one not requiring a different Stockfish player binary).

vondele commented 3 years ago

OK, I've created a branch for testing an increased quantization error, assuming you agree that rounding to even is equivalent to introducing additional quantization: https://github.com/vondele/Stockfish/commit/0c8ab4e69c4b4cd36ef0613538ea8553975c4171 https://github.com/vondele/Stockfish/commits/quantize

The results are quite interesting:

   # PLAYER          :  RATING  ERROR   POINTS  PLAYED   (%)  CFS(%)
   1 quantizeNone    :     0.0   ----  38113.5   74492    51     100
   2 quantizeFT      :    -3.3    2.1  18469.0   37292    50     100
   3 quantizeAll     :   -13.0    1.9  17909.5   37200    48     ---

This suggests that the additional quantization costs about 13 Elo, and still about 3 Elo if we apply it only to the feature transformer. That's definitely a scale we should care about.
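As a sanity check on the table, the logistic Elo model turns a score fraction p into an Elo difference of -400·log10(1/p - 1). The sketch below assumes each variant's games were all played against quantizeNone (the table doesn't state this explicitly). Ordo fits all games jointly, so its numbers can differ slightly from this per-pairing estimate.

```python
import math


def elo_from_score(points, played):
    """Elo difference implied by a score fraction p under the logistic model: -400*log10(1/p - 1)."""
    p = points / played
    return -400.0 * math.log10(1.0 / p - 1.0)


for name, points, played in [("quantizeFT", 18469.0, 37292), ("quantizeAll", 17909.5, 37200)]:
    print(f"{name}: {elo_from_score(points, played):+.1f}")
# quantizeFT: -3.3
# quantizeAll: -12.9
```

Both values land within 0.1 of the Ordo ratings (-3.3 and -13.0).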

glinscott commented 3 years ago

Very interesting. Going to full PyTorch quantization is going to be a ton of work, and would require rewriting the C++ inference code. I'm poking at how hard it would be to simulate the quantization errors during training; PyTorch does this using FakeQuantize (https://pytorch.org/docs/stable/_modules/torch/quantization/fake_quantize.html#FakeQuantize). We might be able to repurpose this.
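For reference, the core of simulated quantization fits in a few lines. This is not PyTorch's FakeQuantize module but a hand-rolled straight-through estimator, assuming weights are exported as int8 with a scale of 64 (an assumption matching the 1/64 grid used elsewhere in this thread). The forward pass sees the rounded, clamped value, while the backward pass treats rounding as the identity.

```python
import torch


def fake_quantize_ste(w: torch.Tensor, scale: float = 64.0, qmin: int = -128, qmax: int = 127) -> torch.Tensor:
    """Forward: value rounded/clamped to the int grid. Backward: identity (straight-through)."""
    wq = torch.clamp(torch.round(w * scale), qmin, qmax) / scale
    # (wq - w).detach() carries no gradient, so d(output)/d(w) == 1 everywhere
    return w + (wq - w).detach()


# Usage inside a model's forward(), e.g.:
# out = torch.nn.functional.linear(x, fake_quantize_ste(self.l1.weight), self.l1.bias)
w = torch.tensor([0.01, 0.3, 5.0], requires_grad=True)
y = fake_quantize_ste(w)
y.sum().backward()
print(y.detach(), w.grad)
```

This simple version also passes gradients through weights that were clamped, which is a design choice one may want to revisit.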

ddobbelaere commented 3 years ago

@vondele Very interesting experiment. I agree rounding to even is equivalent to some form of additional quantization error. What you can definitely conclude from the experiment is that the play performance of the quantized net is very sensitive to least-significant-bit changes of the weights and biases, because this is what is happening (you effectively set the LSB of all weights/biases to zero in your experiment).
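To make the "LSB set to zero" interpretation concrete, here is a hypothetical helper (not the code from the linked commit). It rounds already-quantized integer weights to the nearest even value, so each weight moves by at most one LSB. It clips to the largest even values inside the int8 range, since 127 would otherwise round up to 128.

```python
import numpy as np


def zero_lsb(w_int, lo=-128, hi=127):
    """Round integers to the nearest even value (numpy rounds ties to even) and keep them in [lo, hi]."""
    w = np.asarray(w_int, dtype=np.int64)
    even = 2 * np.round(w / 2).astype(np.int64)
    # Clip to the largest even values that still fit in the target integer range
    return np.clip(even, lo + (lo % 2), hi - (hi % 2))


print(zero_lsb([3, -3, 1, 2, 127]).tolist())  # [4, -4, 0, 2, 126]
```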

One relatively non-intrusive idea I have to potentially greatly increase quantized net accuracy, with minimal extra inference time, is to include an integer shift factor for each neuron (totalling exactly 512 + 32 + 32 + 1 integer shifts per inference, which is probably negligible compared with the bulk of the operations), so that we can maximize the dynamic range of the input weights (and bias) connected to this neuron. (This mechanism will not solve "underflow" at the neuron output, though.)
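A rough numpy sketch of the per-neuron shift idea under my own assumptions (none of these details come from the thread): FC weights exported with a base scale of 64 and int8 range, inputs already quantized to 0..127 by the clamped ReLU, and no bias. Each output neuron j gets a shift s_j chosen so its largest weight just fits into int8. The dot product is computed in integers and then shifted right, with rounding, by s_j. Only the FC layers are modeled here, not the 512 feature-transformer outputs.

```python
import numpy as np


def choose_shifts(w, scale=64, qmax=127):
    """Largest s_j >= 0 such that max|w_j| * scale * 2**s_j still fits into [-qmax, qmax]."""
    peak = np.max(np.abs(w), axis=1) * scale
    s = np.floor(np.log2(qmax / np.maximum(peak, 1e-12)))
    return np.maximum(s, 0).astype(np.int64)


def quantize_rows(w, shifts, scale=64, qmax=127):
    """Quantize row j with step 1 / (scale * 2**shifts[j])."""
    factor = scale * np.power(2.0, shifts)[:, None]
    return np.clip(np.round(w * factor), -qmax - 1, qmax).astype(np.int64)


def affine_int(wq, x_int, shifts):
    """Integer dot products, then one rounding right shift per output neuron."""
    acc = wq @ x_int
    return (acc + ((1 << shifts) >> 1)) >> shifts


rng = np.random.default_rng(0)
w = rng.normal(0.0, 0.005, size=(32, 512))   # small FC weights, peaked around zero
x_int = rng.integers(0, 128, size=512)      # clamped-ReLU outputs in 0..127
exact = (w * 64) @ x_int                    # float reference, in the same output units

no_shift = np.zeros(32, dtype=np.int64)
plain = affine_int(quantize_rows(w, no_shift), x_int, no_shift)
s = choose_shifts(w)
shifted = affine_int(quantize_rows(w, s), x_int, s)
print(np.mean(np.abs(plain - exact)), np.mean(np.abs(shifted - exact)))  # shifted error is far lower
```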

vondele commented 3 years ago

One more data point: I expected the last layer (with outputDimension==1) to be the most important in this context, as these weights seemed most important in tuning, but that appears not to be the case:

   # PLAYER                  :  RATING  ERROR   POINTS  PLAYED   (%)  CFS(%)
   1 quantizeNone            :     0.0   ----  52279.0  101768    51     100
   2 quantizeFT              :    -3.3    2.0  18614.0   37583    50     100
   3 quantizeAll             :   -13.1    1.9  18003.0   37400    48      65
   4 quantizeAllButOutput    :   -13.6    2.2  12872.0   26785    48     ---
ddobbelaere commented 3 years ago

@vondele Really interesting. What about the isolated effect of extra quantization noise on the L1 layer neurons (the 32x512 weights for which the sharply peaked histograms are shown)? Or maybe just quantizeAllButL1, to test the opposite.

vondele commented 3 years ago

@ddobbelaere so L1 is important:

   # PLAYER                  :  RATING  ERROR   POINTS  PLAYED   (%)  CFS(%)
   1 quantizeNone            :     0.0   ----  79609.5  154824    51     100
   2 quantizeFT              :    -3.3    1.8  18614.0   37583    50      98
   3 quantizeAllButL1        :    -6.4    2.3  13288.5   27074    49     100
   4 quantizeAll             :   -13.1    2.0  18003.0   37400    48      82
   5 quantizeAllButOutput    :   -14.3    1.7  25309.0   52767    48     ---
ddobbelaere commented 3 years ago

@vondele Hmm, yeah, seems like it (I had a vague suspicion). As mentioned earlier, the performance loss you are measuring from the additional L1 quantization error (about 7 Elo?) might even underestimate the loss caused by the current quantization, because the connections between some "hidden input features" (of which there are 256) and the L1 layer (each corresponding to one column of the plots in #48) might already be drowned in noise (spanning only a few quantization bins around zero, which leads to high relative errors).

I hope to investigate this claim soon (based on your checkpoints).

ddobbelaere commented 3 years ago

I can confirm the above-mentioned effect after analyzing checkpoint epoch=245.ckpt.

Denote by "input neuron pair" the pair of (own, opponent) neurons from the input layer that correspond to the same feature. Its L1 weights (for all L1 neurons), i.e. its only connection with the rest of the net, correspond to a column in the figures in #48. Define the relative quantization error as the squared error between the L1 weights and their quantized counterparts, divided by the squared norm of the L1 weights (rel_error in the code). Relevant diff w.r.t. #48:

+++ b/visualize.py
@@ -134,6 +134,33 @@ class NNUEVisualizer():
             l1_weights[2*i+1] = l1_weights_[i][self.M +
                                                self.ordered_input_neurons]

+        rel_error = np.zeros(self.M)
+        max_s = np.zeros(self.M)
+        min_s = np.zeros(self.M)
+        for i in range(self.M):
+            s = l1_weights[:, i]
+            sq = np.round(64*l1_weights[:, i])/64
+            min_s[i] = np.min(s)
+            max_s[i] = np.max(s)
+            rel_error[i] = np.sum(np.abs(s-sq)**2)/np.sum(np.abs(s)**2)
+
+        sorted_indices = np.argsort(rel_error)
+
+        plt.figure()
+        plt.plot(10*np.log10(rel_error[sorted_indices]))
+        plt.xlabel("Input neuron pair")
+        plt.ylabel("dB")
+        plt.title(
+            "Relative quantization error of L1 weights [{}]".format(net_name))
+
+        plt.figure()
+        plt.plot(max_s[sorted_indices], label='max')
+        plt.plot(min_s[sorted_indices], label='min')
+        plt.xlabel("Input neuron pair")
+        plt.legend()
+        plt.title(
+            "Max/min of L1 weights [{}]".format(net_name))
+
         if vmin >= 0:

Sorted w.r.t. increasing quantization error (in dB), we have the following plots:

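[Figures: relative quantization error of L1 weights in dB (error_dB) and max/min of L1 weights (min_max), per input neuron pair]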

Conclusions:

vondele commented 3 years ago

Since we know that validation error correlates with Elo (https://docs.google.com/document/d/1UJe9dT8YAz-Z5sGWD2IwFZHD1F0EL6zjNoeS0gaYpBE/edit#bookmark=id.mkupyubn5uer) (roughly, a change of 1e-3 corresponds to 20 Elo), I wonder if we could look at the effect of quantization on validation error. Maybe the graph (see the Google doc) is leveling off because the validation error after quantization stops going down?

ddobbelaere commented 3 years ago

@vondele Interesting idea. I've observed (visually) that all three checkpoints are very, very similar. If the weights/biases change only very little (e.g. due to a low LR), they might get mapped onto the same quantized net, or onto one close to it plus "quantization noise". But this can of course easily be quantified/checked.
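One way to quantify it: a hypothetical helper that rounds two checkpoints' weights to the same grid (1/64 assumed, as in the visualize.py diff) and reports the fraction that land on identical integers, plus the largest difference in LSBs.

```python
import numpy as np


def quantized_overlap(w_a, w_b, scale=64):
    """Fraction of weights with identical quantized values, and the largest difference in LSBs."""
    qa = np.round(np.asarray(w_a, dtype=np.float64) * scale).astype(np.int64)
    qb = np.round(np.asarray(w_b, dtype=np.float64) * scale).astype(np.int64)
    return float(np.mean(qa == qb)), int(np.max(np.abs(qa - qb)))


w_a = np.array([0.1, 0.2, 0.3])
print(quantized_overlap(w_a, w_a + 1e-4))          # (1.0, 0): a tiny drift maps onto the same net
print(quantized_overlap(w_a, [0.1, 0.2, 0.35]))    # two of three weights identical, max 3 LSBs apart
```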

vondele commented 3 years ago

Visually, the nets take on their final 'look' quite early on; I would say they have their 'signature' after maybe 10-20 epochs. Good performance comes later. Those were indeed very late epochs, so the learning rate was pretty low (we drop it every 75 epochs).

From the (SPSA/nevergrad) tuning, I recall that the weights changed only very little (maybe 1-2 units), yet this had quite some Elo impact.

vondele commented 3 years ago

I'm copying here a comment from Gary from Discord:


% python cross_check_eval.py --net ../NNUE/nets/vondele/nn-epoch427.nnue --engine ../Stockfish/src/stockfish --data d8_128000.bin --features=HalfKP^ --checkpoint ../NNUE/nets/vondele/epoch=427.ckpt         
Min engine/model eval: -515 / -498.88386726379395
Max engine/model eval: 443 / 465.64918756484985
Avg engine/model eval: -19.73 / -7.815914869308472
Avg abs engine/model eval: 184.05 / 183.7648463845253
Relative engine error: 0.1886259862671465
Relative model error: 0.1986029047235611
Avg abs difference: 19.232174019813538
Sopel97 commented 3 years ago

To add to the previous comment, this is an example output for a net early in training:

Min engine/model eval: -530 / -533.0950498580933
Max engine/model eval: 549 / 550.3568887710571
Avg engine/model eval: -32.98 / -33.84214849770069
Avg abs engine/model eval: 298.82 / 300.4185470491648
Relative engine error: 0.024919539834772474
Relative model error: 0.02349334676848089
Avg abs difference: 2.306553307175636

Hence I think the majority of the error in vondele's net comes from clipped weights.
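A quick way to count such clipped weights, assuming (as in the visualize.py diff) a scale of 64 and an int8 target range. The helper name is hypothetical.

```python
import numpy as np


def count_clipped(w, scale=64, qmin=-128, qmax=127):
    """Count weights whose rounded, scaled value falls outside [qmin, qmax] and would be clipped."""
    q = np.round(np.asarray(w, dtype=np.float64) * scale)
    return int(np.count_nonzero((q < qmin) | (q > qmax)))


print(count_clipped([1.0, 1.99, 2.0, -2.0, -2.1]))  # 2 (2.0 -> 128 and -2.1 -> -134)
```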

ddobbelaere commented 3 years ago

Hmm, I see, nice work @glinscott :)

It seems difficult to prevent weight clipping. The only thing I can think of is an L2-regularization loss term (for example only for weights that are int8 after quantization), but not so strong that it causes "underflow" and too high a quantization error.

EDIT: or maybe keep everything as is, but enforce weight min/max constraints during training (update steps). See e.g. https://discuss.pytorch.org/t/restrict-range-of-variable-during-gradient-descent/1933
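A minimal sketch of one common way to enforce such a constraint: after each optimizer step, clamp the weights in place under torch.no_grad(). The toy model and the bound 127/64 (largest int8 value at an assumed scale of 64) are illustrative assumptions, not nnue-pytorch's actual model.

```python
import torch

MAX_WEIGHT = 127 / 64  # assumed largest representable weight (int8 at scale 64)

torch.manual_seed(0)
# Toy stand-in: 512 -> 32 -> 1, loosely shaped like the layers after the feature transformer
model = torch.nn.Sequential(torch.nn.Linear(512, 32), torch.nn.ReLU(), torch.nn.Linear(32, 1))
opt = torch.optim.SGD(model.parameters(), lr=1.0)


def train_step(x: torch.Tensor, y: torch.Tensor) -> float:
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    opt.step()
    # Project the first layer's weights back into the representable range after the update
    with torch.no_grad():
        model[0].weight.clamp_(-MAX_WEIGHT, MAX_WEIGHT)
    return loss.item()


x = torch.randn(16, 512)
y = torch.full((16, 1), 100.0)
train_step(x, y)  # large lr and target push the weights hard; the clamp keeps them in range
print(model[0].weight.abs().max().item() <= MAX_WEIGHT)
```

Unlike an L2 penalty, this leaves weights inside the range untouched and only acts on those that would be clipped at export.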

ddobbelaere commented 3 years ago

So, I've done some "surgery" to avoid the clipping of the 7 weights in L1 during serialization of epoch=245.ckpt.

The surgery is done as detailed above (in the opening post). The code can be found in https://github.com/ddobbelaere/nnue-pytorch/tree/surgery. Run e.g. python perform_surgery.py ../research/run84run3/epoch=245.ckpt nn-epoch245_no_clipping.nnue/pt --features="HalfKP^".

As mentioned, if clamping occurs in the affected (rescaled) input layer neurons (three have been rescaled), the transformation is not exact. However, there may be an Elo gain, as no weight clipping is present any more.

The new "operated" net can be found here: https://drive.google.com/file/d/1YXt2XUlNTf-JMeS4875-VWAJ2UaI4kJW/view?usp=sharing. I don't have much experience with running an Elo test for it, though...

A comparison between the checkpoints shows that the clipping problem is "resolved". Note that the weights/biases of input neurons [58, 155, 199] (counting from 0) have been rescaled by 1.9851803516778421, 1.2522607037401574 and 1.8282710999015748, respectively.

[Figures: input weights and L1 weight histograms of epoch=245.ckpt and nn-epoch245_no_clipping.pt]

EDIT: I fixed a "bug" (the rescaling was about a factor of 2 too high; not necessarily "bad", but unnecessary). The newly uploaded net is "fixed".
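A sketch of how such per-neuron factors could be picked. This is not necessarily what the surgery branch does. For each input neuron pair j, take the smallest q_j ≥ 1 that brings both of its L1 columns (own at j, opponent at M + j, following the column layout in the visualize.py diff) within ±127/64, an assumed int8 bound at scale 64. Then multiply that neuron's input weights and bias by q_j and divide its L1 columns by q_j. As noted above, the result is only exact while the rescaled neurons do not clamp.

```python
import numpy as np


def remove_l1_clipping(w_in, b_in, l1_weights, limit=127 / 64):
    """Hypothetical surgery. Shapes: w_in (M, K), b_in (M,), l1_weights (n_l1, 2M)."""
    m = w_in.shape[0]
    own = np.abs(l1_weights[:, :m]).max(axis=0)
    opp = np.abs(l1_weights[:, m:]).max(axis=0)
    # Smallest factor >= 1 that maps the pair's largest |L1 weight| onto the limit
    q = np.maximum(np.maximum(own, opp) / limit, 1.0)
    # Both perspectives share the input layer, so one q_j rescales both L1 columns of the pair
    l1_new = l1_weights / np.concatenate([q, q])[None, :]
    return w_in * q[:, None], b_in * q, l1_new, q


w_in = np.ones((2, 3))
b_in = np.zeros(2)
l1 = np.array([[0.5, 0.1, 0.2, 3.0], [-4.0, 0.1, 0.3, 0.2]])
_, _, l1_new, q = remove_l1_clipping(w_in, b_in, l1)
print(q, np.abs(l1_new).max())  # both factors > 1; largest |L1 weight| is now 127/64
```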

glinscott commented 3 years ago

I've kicked off a fishtest run for this here: https://tests.stockfishchess.org/tests/view/600d8adf735dd7f0f0352c52. Fingers crossed :).

Btw, I've also put together a branch for quantization-aware training: https://github.com/glinscott/nnue-pytorch/compare/quant_train?expand=1. I haven't had time to do a "big" run on it yet, but at least it is able to learn at the beginning.

vondele commented 3 years ago

And, if I understand correctly, this is the reference run with the clipped version of this net that I did this morning: https://tests.stockfishchess.org/tests/view/600d41d7735dd7f0f0352c2d

ddobbelaere commented 3 years ago

Doesn't look too promising...

I've checked with https://hxim.github.io/Stockfish-Evaluation-Guide/ (tip: open both nets in two tabs and switch between them), and the only input-neuron difference in the starting position is neuron 156 (155 counting from 0) and its counterpart 412 (411), which go from 117 to 127 and from 75 to 90, respectively. So the rescaling by 1.25 is working, but neuron 156 is already clamping in the starting position (117 × 1.25 ≈ 146 exceeds the ceiling of 127). The changes in the L1/L2 output neurons are limited to +/-1 (with an occasional +2), so everything seems to "work fine", except for the clamping of course.

Note that the weights of neurons 59 (58) and 200 (199) look completely blue (all negative), so there is less chance of clamping with lots of pieces on the board (indeed, they are not excited in the starting position, according to the Evaluation Guide).

ddobbelaere commented 3 years ago

We might be witnessing some bias here, though: epoch 245 was selected for maximum Elo performance among all "clipped nets", whereas its non-clipped counterpart doesn't have that privilege (maximizing over all operated, non-clipped nets would be fairer).

vondele commented 3 years ago

Also, while this might have improved the clipping, it might have changed the quantization error at the same time.

Meanwhile @Sopel97 asked for a net with more clipping, so I've uploaded https://drive.google.com/file/d/1iOvqBPVeFDWvP5m36YFrBhDmJ3GT_oaM/view?usp=sharing, which has about 242 clipped weights. Yet the net tests well (Ordo line: 3 run86/run1/nn-epoch369 : -15.1 3.7).

ddobbelaere commented 3 years ago

also, while this might have improved the clipping, this might have changed the quantization error at the same time.

Yes, and clipping only affects very few weights (7), whereas clamping now negatively affects up to 3 × 2 × 32 = 192 terms in the L1 layer calculations (3 rescaled neurons × 2 perspectives × 32 L1 neurons), depending on the position.

Sopel97 commented 3 years ago

I've made a Stockfish branch that uses floats: https://github.com/Sopel97/Stockfish/tree/float_nnue. It's slower, but reasonably so (a few times slower). It doesn't have an AVX2 affine transform implementation, but that should be fine. To get a compatible net, just remove all conversions to int from serialize.py and save the floats directly. I'm fairly confident the implementation is correct, as a short sanity check shows the playing strength at fixed nodes is comparable to master.
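To illustrate the kind of change described, here is a hypothetical fragment, not serialize.py's actual code. The int path rounds, clips and casts to int8 (scale 64 assumed), while the float path writes the raw little-endian float32 values instead.

```python
import numpy as np


def weight_bytes(w, scale=64, as_float=False):
    """Hypothetical serializer fragment: int8 (quantized) vs float32 (unquantized) payload."""
    w = np.asarray(w, dtype=np.float32)
    if as_float:
        return w.astype("<f4").tobytes()  # little-endian float32, no rounding or clipping
    return np.clip(np.round(w * scale), -128, 127).astype(np.int8).tobytes()


print(len(weight_bytes([0.5, -1.0])), len(weight_bytes([0.5, -1.0], as_float=True)))  # 2 8
```

The reading side then has to change accordingly, which is why a net exported this way only works with the float branch.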

vondele commented 3 years ago

Copy from Discord, for the record:

Results with the float branch at fixed nodes (100k) vs. our int8 reference:

   # PLAYER    :  RATING  ERROR  POINTS  PLAYED   (%)  CFS(%)
   1 ref       :     0.0   ----  8027.0   16014    50      69
   2 float     :    -0.9    3.5  7987.0   16014    50     ---

Basically the same within error bars, which would suggest that the Elo impact of the quantization error is really low. There is one caveat: I picked the best net so far (run84/run3/epoch245), which may be the best net precisely because it has a low quantization error, naturally or by luck (it certainly has very little clipping). I might do a second test with a more random net.

Update: I ran a second test (40k nodes) with a more random net (still good, but not the best net of the run). The result is essentially the same:

   # PLAYER             :  RATING  ERROR  POINTS  PLAYED   (%)  CFS(%)
   1 r86r1e523-int      :     0.0   ----  8220.0   16399    50      67
   2 r86r1e523-float    :    -0.9    4.0  8179.0   16399    50     ---
ddobbelaere commented 3 years ago

Thanks, that's pretty conclusive to me.

As the suggestion in the opening post also seems to be a bad one (judging from the "no clipping" experiment, changing the clamping behavior of input neurons is not recommended), I'll close this one.