lightvector / KataGo

GTP engine and self-play learning in Go
https://katagotraining.org/

New architecture #703

Open aki65 opened 2 years ago

aki65 commented 2 years ago

After testing the new architecture for some time, I am very impressed by its playing strength after such a relatively short training period. However, the new mish activation has proved to be quite a computational burden: it rules out fp16 calculation and int8 quantization completely due to lack of accuracy, and with fp32 it's around 15% slower than the same network with an activation that avoids transcendental functions.

Because of similar performance issues, several networks that use the swish activation ( x * sigmoid(x) ) have recently moved to hardswish ( x * relu6(x+3)/6 ) with good results. So perhaps hardswish (or x * relu5(x+3)/5, which is closer to mish) might be a good activation choice for KataGo, too ...
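To make the comparison concrete, here is a minimal sketch of the three activations in TensorFlow (the framework used for the mobile experiments later in this thread); relu5 is not a built-in op, so clip_by_value stands in for it:

import tensorflow as tf

def mish(x):
    # Current new-architecture activation: x * tanh(softplus(x))
    return x * tf.math.tanh(tf.math.softplus(x))

def hardswish(x):
    # x * relu6(x + 3) / 6, a piecewise-linear approximation of swish
    return x * tf.nn.relu6(x + 3.0) / 6.0

def hard_mish_variant(x):
    # x * relu5(x + 3) / 5, the variant suggested above as being closer to mish
    return x * tf.clip_by_value(x + 3.0, 0.0, 5.0) / 5.0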

lightvector commented 2 years ago

Thanks for the feedback! I'm a little surprised that you report it rules out fp16 calculation. Could you explain a bit more? For example, the implementation right now in https://github.com/lightvector/KataGo/tree/gpu-backend-cleanup supports FP16 for CUDA, OpenCL, and Eigen. Are you observing very large accuracy degradation in this current implementation?

Also, how well do you think it would work to fine-tune the mish activation network to convert it to a new activation function? For example, having it use the activation function (1 - alpha) * mish(x) + alpha * x * relu5(x+3)/5 and then smoothly transitioning alpha from 0 to 1 over the course of several million training samples? Since you're already doing int8 quantization and various experiments of your own, I was wondering if this is something you had also considered, given that we already have an almost fully trained net that uses mish.
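A minimal sketch of the blended activation described above (not KataGo's actual training code; the hard variant and the schedule length are illustrative placeholders):

import tensorflow as tf

def blended_activation(x, alpha):
    # (1 - alpha) * mish(x) + alpha * x * relu5(x + 3) / 5
    mish = x * tf.math.tanh(tf.math.softplus(x))
    hard = x * tf.clip_by_value(x + 3.0, 0.0, 5.0) / 5.0
    return (1.0 - alpha) * mish + alpha * hard

def alpha_schedule(samples_seen, migration_samples=5_000_000):
    # alpha ramps smoothly from 0 to 1 over the migration window
    # (the window size here is an arbitrary placeholder)
    return min(1.0, samples_seen / migration_samples)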

aki65 commented 2 years ago

Could you explain a bit more?

I'm sorry, I should have been more precise here: all my fp16 tests were run with tensorflow lite as the backend on android. There I indeed got significant accuracy degradation (up to 7% deviation in winrate, for example), which I never encountered with the old architecture, so I attributed it to mish. This might generalize to other backends (if it's due to fp16 limitations) or not (if it's due to over-aggressive optimization or a bad implementation in tensorflow lite), I don't know. So my "...rules out fp16" currently refers only to the android platform, but that's already a pity, since fp16 is particularly interesting on weaker hardware.
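For reference, an fp16 model for tensorflow lite is typically produced with post-training float16 conversion along these lines; the paths are placeholders and this is not the actual export pipeline used here:

import tensorflow as tf

# Standard TFLite float16 post-training conversion; whether ops then actually
# run in fp16 on the device depends on the delegate (e.g. the GPU delegate).
converter = tf.lite.TFLiteConverter.from_saved_model("saved_model_dir")  # placeholder path
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.float16]
with open("model_fp16.tflite", "wb") as f:
    f.write(converter.convert())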

Also, how well do you think it would work to fine-tune the mish activation network to convert it to a new activation function?

I think your approach is very promising. In my tests I replaced mish by several approximations in the current network, primarily to compare performance, but I also glanced at the deviation of the results. And I observed (with fp32 computation) that the results vary pretty smoothly with the activation function. So small steps in alpha should lead to small errors, which the network should learn to compensate for pretty quickly. Moreover, this shifting approach should quickly show how fast the network can "follow". So if it doesn't work, you would recognize that early, and even in the worst case not much training time would be wasted.

lightvector commented 2 years ago

That makes sense. Can you explicitly confirm that the 7% deviation in winrate between FP32 and FP16 (and/or the average deviation across the set of positions you were using to test) is much smaller if you run the network with your proposed relu5 replacement? Even without re-training the net for other activations, I think testing the deviation for other activations would clarify which activation functions have enough numerical precision to work in the framework you're using. It would be sad to train the replacement, only to discover that the new activation still has very bad deviation.

If you can confirm, then I will investigate alternatives for the activation.

lightvector commented 2 years ago

In particular, how much smaller is the FP16-FP32 difference with the alternative activation functions? And how much smaller is it with the original ReLU in the original 40b or 60b nets?

aki65 commented 2 years ago

I checked the following cases:

1) original 40b (so with ReLU)
2) original 18b (so with mish)
3) 18b with x * relu5(x+3)/5 replacing mish (the results make no go sense, but are still comparable)

As a measure of deviation between FP32 and FP16, I looked at the average relative difference in White's score lead over a set of sample positions:

1) 0.8%
2) 6.1%
3) 1.5%
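The exact formula behind these percentages isn't spelled out in the thread (and lightvector asks about exactly this below); one plausible reading of "average relative difference" would be something like:

import numpy as np

def avg_relative_score_diff(score_fp32, score_fp16):
    # Mean of |fp16 - fp32| / |fp32| over the score leads of the sample positions.
    score_fp32 = np.asarray(score_fp32, dtype=np.float64)
    score_fp16 = np.asarray(score_fp16, dtype=np.float64)
    return np.mean(np.abs(score_fp16 - score_fp32) / np.abs(score_fp32))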

Apart from the FP16 issue, I'm hoping that a simpler activation will make int8 quantization possible again. Unfortunately I can't check that in advance, because the errors in int8 quantization heavily depend on the value ranges of the tensors flowing through the network during evaluation. In a trained network these intermediate values have some "go meaning", so they usually don't vary too wildly. But in construction (3) these values are nonsense, so I can't predict the behaviour of int8 quantization from (3).
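For context, int8 quantization in tensorflow lite calibrates value ranges from a representative dataset, which is why the ranges of the intermediate tensors matter so much. A typical post-training flow looks roughly like this (paths, shapes, and the calibration data are placeholders):

import numpy as np
import tensorflow as tf

def representative_positions():
    # Placeholder calibration data; in practice these would be real game
    # positions so the observed tensor ranges have "go meaning".
    for _ in range(100):
        yield [np.random.rand(1, 19, 19, 22).astype(np.float32)]  # input shape is a guess

converter = tf.lite.TFLiteConverter.from_saved_model("saved_model_dir")  # placeholder path
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_positions
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
int8_model = converter.convert()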

lightvector commented 1 year ago

Initial experiments migrating the b18 network from Mish to other activation functions don't appear to be working well so far. I tried migrating to hardswish, since hardswish is still a pretty close match and is directly supported in more APIs, and also to relu. As expected, relu is even worse because it doesn't have the negative component, but even hardswish is not so great. This is the loss as of about 90% of the way through the migration for hardswish, and about 50% for relu. Policy and value loss curves (p0loss, vloss) were attached as plots in the original comment.

lightvector commented 1 year ago

@aki65 - I'm a bit more surprised now than before that you're seeing such large errors. I recently implemented a command in KataGo, in https://github.com/lightvector/KataGo/tree/gpu-backend-cleanup, to test FP16 errors:

./katago testgpuerror -config GTP_CONFIG_FILE.cfg -model MODEL.bin.gz

And most users report errors that are a bit higher than I would like, but still probably within the acceptable range. Mostly, the absolute difference in winrates appears to be less than 1% for almost everyone at the 99th percentile of errors, and for score the 99th percentile absolute error is less than 0.5 points for pretty much everyone. For the policy mass on the top move, the 99th percentile absolute difference is less than 1% for most people as well.

How are you measuring your statistics? For example, when you say average relative difference, do you mean that if the FP32 score were 0.05 (because the position is a relatively even position) and the FP16 score were 0.15, that would be a 200% error because it is +200% larger than it should be? I would consider that case not a very worrying difference, since in absolute terms it would be only 0.1 points, unlikely to harm the search much especially once averaged out through MCTS. Or are you calculating it some other way?

Additionally, does the error go down a lot if you force your model implementation to compute the activation functions in FP32 instead of in FP16, while still leaving all the convolutions and other operations in FP16? I'm not familiar with mobile, but on PC, memory access tends to be a big part of the compute cost on GPUs, so even though activation functions like Mish are expensive, adding an extra FP32 <-> FP16 round-trip conversion doesn't add any memory access pressure, and the marginal cost isn't as high as one might expect.

The CUDA backend actually does this right now, computing the activation function in FP32 even though the rest of the net runs in FP16, and for most users the new net still seems to achieve performance comparable to 40b. See here for the computation: https://github.com/lightvector/KataGo/blob/gpu-backend-cleanup/cpp/neuralnet/cudahelpers.cu#L38-L43
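At the TensorFlow level (a sketch of the same idea, not the linked CUDA code), this amounts to casting up just for the activation:

import tensorflow as tf

def mish_fp32(x):
    # x arrives in fp16; compute only the activation in fp32, then cast back,
    # leaving convolutions and everything else in fp16.
    x32 = tf.cast(x, tf.float32)
    y32 = x32 * tf.math.tanh(tf.math.softplus(x32))
    return tf.cast(y32, x.dtype)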

aki65 commented 1 year ago

Thank you very much for investigating this issue so thoroughly. In an earlier post I said:

This might generalize to other backends (if it's due to fp16 limitations) or not (if it's due to over-aggressive optimization or a bad implementation in tensorflow lite), I don't know.

With all the results that you gathered from other backends, it's now obvious that the latter is true. I was using the following formulation in my tensorflow model:

activation_x = x * tf.math.tanh(tf.math.softplus(x))

As softplus is a built-in tensorflow operation, I trusted that all numerical issues would be handled properly behind the scenes by the implementation. But since your last post I've been suspicious about that. So, looking at the discussion about HALF_MAX on discord, I introduced a cap at the model level:

capped_x = CAP - tf.nn.relu(CAP - x)
activation_x = x * tf.math.tanh(tf.math.softplus(capped_x))

This looks a little awkward, but conditionals are not supported by most NPUs on android. I started with CAP=10, which didn't change much, but when I reduced CAP, the deviation between fp16 and fp32 went down, and at CAP=5.5 it was in the range reported for OpenCL from other platforms. So, although hard to believe, there is probably something fishy in the tensorflow implementation here.

The spin-off from this experiment is that I can probably live with the capped formula as a workaround, since the error it introduces (leaving numerics aside) is way below fp16 precision.

lukaszlew commented 1 year ago

Looking at the loss curves, do I understand correctly that the migration to alternative nonlinearities did not cause the model to immediately lose quality, i.e. that it was some additional training that caused that?