lightvector / KataGo

GTP engine and self-play learning in Go
https://katagotraining.org/

Non-gpool variation #376

Open: usidedown opened this issue 3 years ago

usidedown commented 3 years ago

I'm wondering if there's a mature KataGo variation that doesn't use global pooling. According to the paper, the gpool variations train better, but they come at some cost during inference. Is there a model that has been trained for a significant amount of time without using this operation?

lightvector commented 3 years ago

No, there is no such model, because there is no reason to remove a cheap addition that majorly improves the neural net.

They don't cost much during inference. More precisely: global pooling is done in only a limited number of residual blocks, and only on a limited number of channels. Those pooling channels replace some of the normal convolution channels rather than being added on top of them, so the total convolution FLOPs are held fixed. As a result, the cost is fairly small, while the gain is large.
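To get a feel for the scale, here is a rough op count for one layer on a 19x19 board. The channel counts (256 trunk channels, 64 pooling channels) are hypothetical and not taken from any particular KataGo net, and the function names are made up for illustration. The sketch counts arithmetic only, so it ignores memory traffic and the synchronization cost of a whole-board reduction, which may matter more on a CPU than on a GPU.

```python
# Back-of-the-envelope op counts for one layer on a 19x19 board.
# The channel counts used below are hypothetical, chosen only to show the scale.

BOARD_POINTS = 19 * 19  # 361 intersections


def conv3x3_macs(c_in: int, c_out: int, points: int = BOARD_POINTS) -> int:
    """Multiply-accumulates for a 3x3 convolution evaluated at every point."""
    return points * 9 * c_in * c_out


def gpool_bias_ops(c_pool: int, c_out: int, points: int = BOARD_POINTS) -> int:
    """Approximate ops for the global-pooling bias path: whole-board sum and
    max of c_pool channels, a (3*c_pool x c_out) matrix multiply, and adding
    the resulting per-channel bias at every point."""
    reductions = 2 * points * c_pool  # one pass for the sum, one for the max
    matmul = 3 * c_pool * c_out       # (mean, scaled mean, max) -> c_out biases
    broadcast_add = points * c_out    # add each channel's bias everywhere
    return reductions + matmul + broadcast_add


conv = conv3x3_macs(256, 256)
extra = gpool_bias_ops(64, 256)
print(f"3x3 conv: {conv:,} MACs; gpool bias path: {extra:,} ops; ratio {extra / conv:.4%}")
```

With these assumed sizes the pooling path comes to under 0.1% of the arithmetic of a single 3x3 convolution, which is consistent with the "fairly small" cost described above.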

If you really cared about being equal in cost, I think you could probably remove about 1 block, or maybe a few more for larger nets. For example, instead of a 20-block network you'd have a 19-block network, and I think the 19-block network with pooling should outperform the 20-block network without pooling.

lightvector commented 3 years ago

For reference, global pooling is how the neural net very likely handles things like:

The common theme among all of these is that they require a local analysis to take into account something that is either summed or accumulated globally, or a long-distance interaction that doesn't depend on any particular spatial relationship. With only convolutional layers, transferring this information would be much less efficient and also much harder to make uniform/unbiased (e.g. having a distant ko threat "count" just as much as a medium-distance ko threat when judging a possible ko fight). The net would likely be worse at all of the above, simply because convolution is the wrong operation to compute them.
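One way to see why pure convolution is inefficient here: each stacked 3x3 convolution lets information move at most one point in any direction (including diagonally), so the reach grows by one point per layer. The hypothetical helper below simulates that growth on an empty 19x19 grid and counts how many layers it takes before a value at one point can influence another. This is a sketch of plain 3x3 receptive-field growth only; it does not model dilated convolutions or anything else a real network might use.

```python
# Simulate how far information can travel through stacked 3x3 convolutions.
# Each layer extends the reachable region by one point in every direction.

N = 19  # board width


def layers_until_reached(src: tuple, dst: tuple, n: int = N) -> int:
    """Count 3x3 layers until information at src can reach dst.
    Both points must be on the n x n board."""
    reached = {src}
    layers = 0
    while dst not in reached:
        grown = set()
        for r, c in reached:
            for dr in (-1, 0, 1):
                for dc in (-1, 0, 1):
                    rr, cc = r + dr, c + dc
                    if 0 <= rr < n and 0 <= cc < n:  # stay on the board
                        grown.add((rr, cc))
        reached = grown
        layers += 1
    return layers


# Corner to opposite corner of a 19x19 board.
print(layers_until_reached((0, 0), (N - 1, N - 1)))
```

Connecting opposite corners takes 18 such layers, whereas a single global pooling step makes a whole-board summary available to every point at once, with every location weighted the same regardless of distance.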

usidedown commented 3 years ago

Thank you for your response! I'm trying to run KataGo in a different setting (not on a GPU), and my speed-strength tradeoff is slightly different. I'd like to ask a few more questions about gpool: Does gpool calculate the channel mean/std on the fly, or can it use constants? From the paper I think no mean/std calculations are needed at inference, but I wasn't sure from the code. (Of course, per-channel average pooling is still needed.) Is max pooling really necessary? Isn't average pooling enough?

If I changed gpool slightly, would it be possible to train the network using the current self-play games?

lightvector commented 3 years ago

What setting are you running in such that a few pooling operations are expensive?

You're going to run into the issue that the value head relies on global pooling too, and this time not as an incidental component: it's the entire design of the head (pool some features over the entire board instead of having a huge fully connected network). The policy head also intrinsically relies on global pooling to decide whether it is time to pass (because "pass", unlike other moves, is not associated with any particular location, so it can't be computed convolutionally). So unless you're entirely redesigning the heads, you're going to have to mechanically support these operations for the heads anyway (they are much cheaper than equivalent-sized fully connected networks for the same number of channels), even if you don't want to have pooling within the residual blocks.
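The point about "pass" can be illustrated with a toy policy head. This is not KataGo's actual policy head, which has more layers; the feature layout, the function name, and the weights are all hypothetical. It only shows the structural issue: a per-point (1x1-convolution-like) computation yields one logit per location, while the pass logit has no location and so is computed here from a whole-board pooled vector.

```python
from typing import List


def policy_logits(features: List[List[float]], w_point: List[float],
                  w_pass: List[float]) -> List[float]:
    """Return one logit per board point followed by a single pass logit.

    features[p][c]: feature c at board point p (hypothetical layout).
    w_point: length-C weights applied pointwise, like a 1x1 convolution.
    w_pass: length-2C weights applied to the pooled [means..., maxes...] vector.
    """
    num_channels = len(w_point)
    # Per-point logits: each depends only on that point's own features.
    point_logits = [sum(f[c] * w_point[c] for c in range(num_channels)) for f in features]
    # Global pooling: mean and max of each channel over the whole board.
    means = [sum(f[c] for f in features) / len(features) for c in range(num_channels)]
    maxes = [max(f[c] for f in features) for c in range(num_channels)]
    pooled = means + maxes
    # Pass logit: a linear function of the pooled vector, so any point can affect it.
    pass_logit = sum(p * w for p, w in zip(pooled, w_pass))
    return point_logits + [pass_logit]


# 361 points with 2 channels each, all zero except one corner point.
feats = [[0.0, 0.0] for _ in range(361)]
feats[360] = [0.0, 3.0]
logits = policy_logits(feats, [1.0, 0.5], [1.0, 0.0, 0.0, 1.0])
print(len(logits), logits[0], logits[360], logits[-1])
```

A feature change at one corner moves the pass logit but leaves the logits of other points untouched, which is the behavior a purely convolutional head could not provide for a location-free output.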

Max pooling in addition to average pooling is just a historical choice. It would probably be fine without it, but it probably also does help to have it (e.g. some things in Go behave "max-like": the value of sente is typically going to be about the value of the biggest move anywhere else on the board).
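A small example of the "max-like" case, using made-up per-point "move size" values: one board has a single 12-point move and is otherwise settled, the other spreads the same total evenly across all 361 points. The values and helper name are hypothetical and only illustrate the pooling operations themselves.

```python
def mean_and_max(values):
    """Average pooling and max pooling over one channel's board values."""
    return sum(values) / len(values), max(values)


# Hypothetical per-point "move size" estimates on a 19x19 board (361 points).
one_big_move = [0.0] * 360 + [12.0]    # a single 12-point move, rest settled
many_small_moves = [12.0 / 361] * 361  # the same total spread evenly

print(mean_and_max(one_big_move))
print(mean_and_max(many_small_moves))
```

Both boards have the same average (up to floating-point rounding), so average pooling alone cannot tell them apart, while max pooling reports 12 for the first and only about 0.03 for the second.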

lightvector commented 3 years ago

I think you may also be misunderstanding what global pooling is. There is no mean/standard-deviation normalization; that's what batch norm does, and KataGo doesn't use batch norm anymore (it hasn't for a long time).

It is just (average, scaled average, max pool across 19x19) -> matrix multiply -> broadcast to 19x19 -> add.
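That one-line description can be sketched in plain Python. This is not KataGo's code. The feature layout (g[c][p], x[c][p]), the function names, and the constants B_AVG = 14 (the midpoint of 9 and 19) and SCALE = 10 are assumptions based on one reading of the paper's scaled-average term. The sketch also assumes a full square board with no masking of off-board points.

```python
from typing import List

Matrix = List[List[float]]

B_AVG = 14.0  # assumed: midpoint of the smallest (9) and largest (19) board widths
SCALE = 10.0  # assumed divisor for the scaled-average term


def global_pool(g: Matrix, board_width: int) -> List[float]:
    """g[c][p]: pooling channels over all board points.
    Returns [means..., scaled means..., maxes...], length 3 * len(g)."""
    means = [sum(ch) / len(ch) for ch in g]
    scaled = [m * (board_width - B_AVG) / SCALE for m in means]  # constants, not statistics
    maxes = [max(ch) for ch in g]
    return means + scaled + maxes


def add_gpool_bias(x: Matrix, g: Matrix, w: Matrix, board_width: int) -> Matrix:
    """x[c][p]: regular channels; w: (3 * len(g)) x len(x) weight matrix.
    pool -> matrix multiply -> broadcast to every point -> add."""
    pooled = global_pool(g, board_width)
    bias = [sum(pooled[i] * w[i][c] for i in range(len(pooled))) for c in range(len(x))]
    return [[v + bias[c] for v in x[c]] for c in range(len(x))]


# Tiny 2x2 "board": one pooling channel, two regular channels.
g = [[1.0, 2.0, 3.0, 6.0]]
x = [[0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0]]
w = [[1.0, 0.0], [0.0, 0.0], [0.0, 1.0]]  # rows: mean, scaled mean, max
print(global_pool(g, board_width=2))
print(add_gpool_bias(x, g, w, board_width=2))
```

Note that the only per-position work is the sum, the max, and the final add; the scaling uses fixed constants and the board width, not any running mean or variance.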

usidedown commented 3 years ago

I should have phrased that better. I was asking about the scaling used for the scaled average, which is scaled by values denoted b_avg, b_min and b_max in the paper. I think these are constants, but I can't tell for sure. In the value head, the scaling uses b_avg and sigma values. Again, I think these are constants, but I couldn't tell for sure.

lightvector commented 3 years ago

Ah, thanks, I understand what you meant now. Yes, those are constants; they're just hardcoded in the code.