lightvector / KataGo

GTP engine and self-play learning in Go
https://katagotraining.org/

Fixup vs Batch Renormalization #689

Open CGLemon opened 2 years ago

CGLemon commented 2 years ago

This is a training result from my project. The fixup network uses Fixup Initialization (with gradient clipping); the renorm network uses Batch Renormalization. The two networks use the same training data set, learning rate schedule, and number of training steps. The final result shows that fixup is weaker than renorm. Has KataGo run a Batch Renormalization experiment? I would like to compare my results with KataGo's. Thanks!

1. renorm(B) vs fixup(W)

Black is renorm and White is fixup. Played 20 games.

|         | renorm (B) | fixup (W) |
|---------|------------|-----------|
| W/L     | 13/7       | 7/13      |
| winrate | 65%        | 35%       |


2. renorm(W) vs fixup(B)

Black is fixup and White is renorm. Played 20 games.

|         | renorm (W) | fixup (B) |
|---------|------------|-----------|
| W/L     | 15/5       | 5/15      |
| winrate | 75%        | 25%       |


3. Final result

|         | renorm | fixup |
|---------|--------|-------|
| W/L     | 28/12  | 12/28 |
| winrate | 70%    | 30%   |
lightvector commented 2 years ago

Thanks for the results! I haven't done detailed strength testing, but yes, your results are consistent with some of my more recent neural net training: fixup underperforms batch norm a bit. So I recommend avoiding fixup, which your results confirm as well.

The thing about batch renorm that I've had a little bit of trouble with is tuning the schedule for the r and d clip parameters and getting the training to be stable at high learning rates. Also, the batch norm running averages of the mean and variance are sometimes poor at inference time at medium to high learning rates (presumably because the running averages are inaccurate when the net is changing so fast), which makes validation and testing frustrating.
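
For anyone unfamiliar with what the r and d clips do, here is a minimal sketch of the batch renorm transform from Ioffe (2017) in PyTorch; the shapes, defaults, and momentum are illustrative, not KataGo's actual code:

```python
# Minimal sketch of batch renormalization, just to make the r/d clip parameters concrete.
import torch
import torch.nn as nn

class BatchRenorm(nn.Module):
    def __init__(self, num_features, momentum=0.01, eps=1e-5, rmax=3.0, dmax=5.0):
        super().__init__()
        self.momentum, self.eps = momentum, eps
        # rmax/dmax are the clip limits; the paper ramps them up from 1 and 0 over training.
        self.rmax, self.dmax = rmax, dmax
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_std", torch.ones(num_features))

    def forward(self, x):  # x: (batch, num_features)
        if self.training:
            mean = x.mean(dim=0)
            std = (x.var(dim=0, unbiased=False) + self.eps).sqrt()
            # r and d pull the batch statistics toward the running statistics,
            # with gradients stopped through them.
            r = (std / self.running_std).clamp(1.0 / self.rmax, self.rmax).detach()
            d = ((mean - self.running_mean) / self.running_std).clamp(-self.dmax, self.dmax).detach()
            x_hat = (x - mean) / std * r + d
            with torch.no_grad():
                self.running_mean += self.momentum * (mean - self.running_mean)
                self.running_std += self.momentum * (std - self.running_std)
        else:
            x_hat = (x - self.running_mean) / self.running_std
        return self.gamma * x_hat + self.beta
```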

As a result, KataGo's candidate new net (trained in the pytorch-rewrite branch and likely to replace the 40b net in a month or two) is using batch norm (batch norm, not batch renorm), and there is only one batch norm layer in the entire net, at the end of the residual block stack just before the output heads. Surprisingly, I've found that this one batch norm layer alone recovers most of the benefit of batch norm (but not all, so there is a good chance that properly-tuned batch renorm would be a little better!). Then, because there is only a single batch norm layer at the end of the net, all the activations in the net up to the output heads are identical between training and inference. So we can train a second set of output heads to predict the output given the activations without batch norm. E.g. if X is the final trunk encoding after the residual blocks, we train policy head F as F(BatchNorm(X)) and use F to drive the gradients, but we also jointly train policy head G as G(X), and at inference time we can just directly use G.
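
To make the two-headed scheme concrete, here is a rough sketch in PyTorch. The module name, the 1x1-conv heads, and the detach on G's input (my reading of "use F to drive the gradients") are illustrative assumptions, not the actual KataGo code:

```python
# Rough sketch of the two-headed output scheme; real policy heads are more complex.
import torch
import torch.nn as nn

class TwoHeadedPolicy(nn.Module):
    def __init__(self, trunk_channels):
        super().__init__()
        self.final_norm = nn.BatchNorm2d(trunk_channels)   # the single batch norm layer in the net
        self.head_f = nn.Conv2d(trunk_channels, 1, 1)      # sees normalized trunk, drives the gradients
        self.head_g = nn.Conv2d(trunk_channels, 1, 1)      # sees raw trunk, used at inference

    def forward(self, x):
        # x: final trunk encoding after the residual blocks, shape (batch, channels, 19, 19)
        logits_f = self.head_f(self.final_norm(x)).flatten(1)
        # Detaching x here is one interpretation: G learns to match the same targets
        # without pushing its own gradients into the trunk.
        logits_g = self.head_g(x.detach()).flatten(1)
        return logits_f, logits_g

# Training: sum the cross-entropy losses of both heads against the same policy target.
# Inference: use logits_g only, so no batch norm statistics are needed.
```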

(Also, during initialization, everywhere one would normally put a batch norm layer in the residual blocks, I add a fixed scalar multiplier which mimics in expectation what the batch norm division by sigma would do if there were a batch norm layer there. This keeps the training of the net stable despite having no batch norm in the blocks.)
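
A sketch of what such a fixed scalar multiplier could look like; computing the scale from a calibration batch at initialization is my assumption about the details, not necessarily how KataGo does it:

```python
# Sketch of a fixed-scale stand-in for batch norm inside the residual blocks.
import torch
import torch.nn as nn

class FixedScale(nn.Module):
    """Multiplies by a constant chosen at initialization to mimic batch norm's 1/sigma."""
    def __init__(self):
        super().__init__()
        self.register_buffer("scale", torch.ones(()))

    @torch.no_grad()
    def calibrate(self, activations):
        # Called once at initialization on sample activations; afterwards the scale is frozen,
        # so training and inference behave identically at this layer.
        self.scale.copy_(1.0 / activations.std().clamp_min(1e-5))

    def forward(self, x):
        return x * self.scale
```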

Anyways, if you have tips on tuning batch renorm hyperparameters for very long runs or for high learning rates, or know how they should be set long-term within RL, maybe that would still be better than this two-headed scheme.

Edit: fixed some errors regarding gamma vs sigma, minor edits.

lightvector commented 2 years ago

Here's an example plot I dug up from some old short training experiments, comparing fixup (blue) and batch norm (orange). No Elo testing, just visualizations of the policy loss during training. These are of course far from conclusive, but they do show that the batch norm training loss in a short training run is much better than the fixup equivalent given similar short training, and when the learning rate is dropped (the discontinuities near the end of both runs are 2x LR drops, the fixup one was dropped several times), it improves by a larger amount.

[plot: p0loss during training, fixup (blue) vs batch norm (orange)]

So this is also consistent with your own Elo testing results, even if it is a very limited data point. (edit: The current one-batch-norm-layer + two-headed scheme is not compared in that plot. I need to do some fresh runs using the latest architecture just to get a clean comparison on that; I don't have a clean comparison right now.)

CGLemon commented 2 years ago

Thanks for the information. I still have a question. You said that it is necessary to schedule the batch renorm parameters. I just picked the values from the original paper and they seem to work. Is a tuned schedule significantly better than an untuned one? Additionally, the Lc0 project uses fixed parameters for batch renorm (if I am right), so I guess the parameter values are not very important?

CGLemon commented 2 years ago

I have another question which is not about batch renorm. I find that my latest network is bad at life and death for big dragons. The following SGF file is an example: my bot thought the white dragon was alive.

dragon-bug.txt

Do you know of any helpful method to deal with this, other than RL? Thanks!

lightvector commented 2 years ago

I don't know of a good method to deal with it. Even RL-trained nets can have a lot of difficulty with large dragons - KataGo can misevaluate them too. They require integrating a lot of very precise information across a large area, and the number of permutations of different kinds of dragons and shapes is large so you need a lot of data.

One thing to be aware of is that a conv net that is too shallow will be unable to learn large dragons even with unlimited data. For example, a 6-block net will never be capable of correctly judging some dragons on 19x19, because 6 blocks is around 12-14 conv layers, which, if using 3x3 convolutions, means information can propagate at most 12-14 points through the convolutions. In practice, I think you also usually get a bit less than the theoretical max; I wouldn't be surprised by 60% or 70% efficiency, and of course the longer the propagation distance, the fewer layers are left over for complex local calculations. By the time you reach 15 to 20 blocks, though, I think this mostly stops being a concern for 19x19.
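
As a rough illustration of the propagation-distance arithmetic above (the convs-per-block count is an assumption, and the 60-70% efficiency figure is empirical and not modeled here):

```python
# Theoretical max distance information can travel through a stack of 3x3 convolutions:
# each 3x3 conv extends the receptive field by one point in every direction.
def max_propagation_distance(num_blocks, convs_per_block=2, kernel_size=3):
    radius_per_conv = (kernel_size - 1) // 2
    return num_blocks * convs_per_block * radius_per_conv

print(max_propagation_distance(6))   # 12 points: too short to judge some dragons on 19x19
print(max_propagation_distance(20))  # 40 points: comfortably spans a 19x19 board
```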