lightvector / KataGo

GTP engine and self-play learning in Go
https://katagotraining.org/

Implement "Accelerating Self-Play Learning in Go" #307

CGLemon closed this issue 4 years ago

CGLemon commented 4 years ago

David J. Wu: Thank you for your excellent work. Your research is very helpful. Many Go amateurs use KataGo to analyze their games, and it has also inspired a lot of research.

Recently I tried to implement some of the methods from the paper "Accelerating Self-Play Learning in Go", but my Go program does not perform well. The final score prediction and dynamic komi did not work well.

This is what I did. I designed a network that predicts the final score, board ownership, move probabilities, and multi-labelled values. Board ownership and move probabilities are the same as in the paper. The multi-labelled values are the same as in the paper "Multi-Labelled Value Networks for Computer Go". The final score is a little different: it is predicted without komi. For example, if in a game Black's score is 10, White's score is 9, and komi is 7.5, the final score head ideally predicts 1.
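To make the two possible score targets concrete, here is a minimal sketch of the example above. The function names are hypothetical and are not taken from TemplateGo or KataGo; the only assumption is that scores are measured from Black's perspective and that komi is given to White.

```python
def board_margin(black_points: float, white_points: float) -> float:
    """Score on the board from Black's perspective, ignoring komi."""
    return black_points - white_points


def final_margin(black_points: float, white_points: float, komi: float) -> float:
    """Final score from Black's perspective once White's komi is applied."""
    return board_margin(black_points, white_points) - komi


# The example from the post: Black 10, White 9, komi 7.5.
print(board_margin(10, 9))       # 1
print(final_margin(10, 9, 7.5))  # -6.5
```

A head trained on `board_margin` never sees komi in its target, so whether it can react to komi changes depends entirely on how its output is combined with komi elsewhere.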

I am trying to train a 7x7 network. In self-play there are three cases:

  1. 60% of games use the fair komi (the fair komi is 9).
  2. 20% of games let the program adjust the fair komi by itself.
  3. 20% of games let the program play random moves and then adjust the fair komi by itself.

After training, the final score head outputs 9. This is a good result for fair games, but in unfair games the final score head also often outputs 9. Dynamic komi is not very effective either. For example, the winrate is almost the same at komi 9 and at komi 10.

Here are all the loss functions: https://github.com/CGLemon/TemplateGo/blob/se/train/src/Torch.cc#L771-L781

I hope you can give me some tips. Thank you very much. -- Hung Tse, Lin

lightvector commented 4 years ago

Glad you found it interesting and helpful!

It sounds like you might have some sort of bug? If so, the bug might not be in your loss function, but could be in the way that you are generating your data, or using its output in self-play, or it might be a flaw in the design of how the semantics of the different pieces fit together. You will be way better at debugging than me because you are far more familiar with your code than I am. But if you like, here are some questions to help out as to possible places to look or to test for bugs:

Given that you are doing multi-labeled networks, I presume that you are not providing komi as an input to the net? If so, then have you double and triple checked that all details of your design are consistent with this? For example:

Alternatively, if you are providing komi as an input to the net, then instead some things to double and triple check:

And for that matter, are you handling draws and integer versus half-integer komi the way you intended, no off-by-one errors anywhere?
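As one way to test the draw and komi handling, here is a minimal sketch of a result classifier. The function name is hypothetical, and it assumes the margin on the board is Black's points minus White's points and is always an integer, so a draw is only possible with integer komi.

```python
def outcome_for_black(board_margin: float, komi: float) -> str:
    """Classify the result from Black's perspective.

    board_margin is Black's points minus White's points before komi.
    """
    final = board_margin - komi
    if final > 0:
        return "win"
    if final < 0:
        return "loss"
    return "draw"  # reachable only when komi equals the board margin exactly


# 7x7 with a board margin of 9 for Black, at three nearby komi values.
for komi in (8.5, 9, 9.5):
    print(komi, outcome_for_black(9, komi))
```

Running the same check over every komi value used in self-play is a cheap way to spot an off-by-one between the komi that was played and the komi that was recorded in the training data.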

You certainly don't have to answer all of these questions to me or actually reply with answers, these are just places you might re-check if you suspect a bug or a flaw in the design somewhere.

lightvector commented 4 years ago

Could also be worth checking for sign errors. For example, if the neural net is receiving the board from the perspective of the player to move, are the komi values, the scores, calculation of the loss function flipped correctly based on the player to move? At inference time, are the outputs of the neural net also interpreted correctly taking into account the player to move?

(And obviously, again make sure you are consistent everywhere - it would potentially be a bug to present the neural net data from the perspective of the player to move, but to require it to predict scores that are not flipped according to the player to move).
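A minimal sketch of keeping the perspective consistent is below. The encoding of the side to move and all function names are hypothetical; the sketch only assumes that scores are stored from Black's perspective and that the net sees everything from the side to move.

```python
BLACK, WHITE = 1, -1  # hypothetical encoding of the side to move


def to_mover(black_relative: float, to_move: int) -> float:
    """Flip a Black-relative quantity (e.g. final score margin) to the side to move."""
    return black_relative * to_move


def from_mover(mover_relative: float, to_move: int) -> float:
    """Undo the flip when reading a network output at inference time."""
    return mover_relative * to_move


def komi_for_mover(komi: float, to_move: int) -> float:
    """Komi as an input feature for the mover: a bonus for White, a penalty for Black."""
    return komi if to_move == WHITE else -komi


# Black finished 6.5 points behind after komi; the position has White to move.
black_margin = -6.5
target = to_mover(black_margin, WHITE)   # training target seen by the net
recovered = from_mover(target, WHITE)    # what the search should read back
print(target, recovered, komi_for_mover(7.5, BLACK))  # 6.5 -6.5 -7.5
```

The useful property to test is the round trip: converting to the mover's perspective and back must return the original value for both colors.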

CGLemon commented 4 years ago

Thank you for your suggestions. In my network, komi is one of the input features. The multi-labelled output depends on the input komi: the center of the labels is the current komi.

I also asked the CGI author the same question. He thought the reason was that too many self-play games had the same result. After checking the self-play games, this turned out to be true. I adjusted my self-play pipeline, and it now works better.

I also found another mistake after reading your suggestions. My final score head does not predict the final score; it predicts the score on the board. The mistaken method still works, but it is not effective. I will fix this before the next training run.

In the next training run, I want to combine multi-labelled values with the "katago-dynamic-komi-method" (maybe you can give this method a name). The CGI author thought that multi-labelled values could replace the "katago-dynamic-komi-method", but in my past training the two affected games differently. They have different advantages. I will do a full training run to check this idea.

I still have other questions about the KataGo network structure.

  1. KataGo uses global pooling to extract information from the convolution layers. Another top engine (LeelaChessZero) uses SE units to extract information. Is there any disadvantage to using SE units?
  2. In the residual tower, not every residual block uses global pooling. Why?
  3. KataGo uses "MatMulLayer" to adjust channels. Why not use a fully connected layer?
  4. KataGo predicts probabilities for win, loss, and no-result. Why not use the original AlphaGo method to predict the win rate?
  5. It looks like most of the batchnorm layers are useless. I have no idea why this is done.

lightvector commented 4 years ago

Sounds like those bugs are likely to account for the issues you were seeing. Hope the next run goes well! :)

  1. SE is basically global pooling where you multiply channels by the (sigmoided) result. KataGo's global pooling is one where you simply add them in as biases. Since computing arbitrary simple smooth transformations of values is pretty easy for a neural net, there really shouldn't be much difference between adding (along with RELU) versus multiplying by sigmoid.

    I stuck with my architecture because when I tried to implement SE in OpenCL and CUDA, I was not personally able with my GPU programming skills to make it as fast as a plain resnet, whereas a KataGo-style global pooling network I was able to implement and make about equally fast without too much trouble. But if you are able to get good performance either way, it probably doesn't matter much.

  2. Because I don't think it's necessary, and it might be wasteful to have too much computation spent on it. Global-pooling-based techniques allow the neural net to take into account aggregated non-spatial statistics about the board that convolutions are unable to see because convolutions focus locally, such as "am I ahead or behind" or "how early/late in the game is it" or "how many ko threats do I have" or "how many urgent places there are to play". This is why they are good. But, I don't think there are such a large number of relevant global properties that it matters to have so many of them, particularly if every global pooling layer is already computing 64 or however many different channels in parallel. There's only so much you can do if you're completely ignoring where things are located! So spending computation on convolutions - which encode the understanding of tactics and fights and shapes and everything that depends on where stones actually are placed relative to each other - seems more important. Having global pooling every block is overkill.

  3. Fully connected is the same as matrix multiplication. A fully connected layer means every output value depends linearly on every input value. A matrix multiply means every output value depends linearly on every input value.

  4. Because the original AlphaGo method doesn't have a way to represent no-result. No-result is a possible outcome in Japanese rules. I could certainly have encoded it as a draw, but I was curious to separate it out, since it actually is incomparable to a draw and could be valued differently, and I was also scientifically curious to see how well a neural net could predict it given the rarity of occurrence. Probably draw should be separated out in KataGo too, and predicted separately from win and loss.

    By the way, again keep in mind that neural nets can easily do simple smooth bijective transformations of data. If the final output head is a tanh, versus if it is a sigmoid on a binary prediction, versus if it is anything else, it really, really doesn't matter, so long as you adjust the loss function to be mathematically equivalent and adjust hyperparameters to make the gradients comparable. (Although, adding an entirely new category, like "draw" or "no-result" DOES potentially have a major effect).

  5. KataGo uses Fixup: https://arxiv.org/abs/1901.09321. This is a working technique to train networks without batch normalization. I have used it and it is successful, and for KataGo, seemed to have no disadvantages in the learning efficiency. (I did not even need to add any regularization like they reported in the paper, I only needed to add a gradient clip to restore full training stability). LC0 also abandoned batch norm in favor of batch renorm to fix problems that were caused by batch norm, but batch renorm is messy and complicated and I figured getting rid of batch norm entirely is far simpler.

    Training without batch norm is super nice. You no longer have to worry about side effects due to batch size and virtual batch size on the net: doing a batch of 128 versus doing two batches of 64 on different GPUs and adding the results up are now mathematically equivalent. You no longer have to worry about inference time statistics being different than training time statistics, which is why LC0 got problems with it. You no longer have to worry about the quality and noise in the statistics: if we switched to batch size 1 it would likely be pretty harmful for batch norm, but with Fixup it would probably work just fine. Training is actually much faster too, since batchnorm is computationally expensive. Not such a big deal because self-play cost is by far the dominant overall cost, but still, doing gradient updates faster for free is a small bonus.

    At inference time, both Fixup and BatchNorm are the same. You have a bias term and a scaling term. You multiply by scale and add bias. So the "BatchNormLayer" inside KataGo's C++ code has simply been repurposed to apply Fixup's bias and scale instead of BatchNorm's bias and scale.
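The last point, that batch norm at inference is just a multiply and an add, can be checked with a small sketch. This is a single-channel illustration using the standard batch-norm inference formula; the function and parameter names are hypothetical and are not the fields of KataGo's "BatchNormLayer".

```python
import math


def batchnorm_inference(x, gamma, beta, running_mean, running_var, eps=1e-5):
    """Standard batch-norm formula at inference, using fixed running statistics."""
    return gamma * (x - running_mean) / math.sqrt(running_var + eps) + beta


def fold_batchnorm(gamma, beta, running_mean, running_var, eps=1e-5):
    """Collapse the fixed statistics into a single (scale, bias) pair."""
    scale = gamma / math.sqrt(running_var + eps)
    bias = beta - running_mean * scale
    return scale, bias


def scale_and_bias(x, scale, bias):
    """What a per-channel scale-and-bias layer computes, whatever produced the parameters."""
    return x * scale + bias


# Arbitrary illustrative parameters for one channel.
scale, bias = fold_batchnorm(gamma=1.5, beta=0.2, running_mean=0.4, running_var=2.0)
for x in (-1.0, 0.0, 3.5):
    full = batchnorm_inference(x, 1.5, 0.2, 0.4, 2.0)
    folded = scale_and_bias(x, scale, bias)
    print(math.isclose(full, folded, rel_tol=1e-9, abs_tol=1e-12))
```

Because the folded form needs only `scale` and `bias`, a layer written this way can load parameters from either training scheme without knowing which one produced them.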

CGLemon commented 4 years ago

Thank you for your reply. It is very useful for my future work. The KataGo research is important for other researchers. I hope you can develop even more powerful methods for KataGo.