google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

Difference between "gradient_descent_mse" and "gradient_descent_mse_ensemble" #72

Closed lionelmessi6410 closed 3 years ago

lionelmessi6410 commented 4 years ago

I am truly confused about the difference between gradient_descent_mse and gradient_descent_mse_ensemble.

In the original NNGP and NTK papers, the authors mention that we can use Gaussian process methods to make predictions on test data, i.e. the mean prediction K(x_test, x_train) K(x_train, x_train)^{-1} y. In addition to the mean prediction, we can also calculate the covariance matrix. Based on my understanding, this corresponds to gradient_descent_mse_ensemble.
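
Concretely, I mean something like the following sketch (my own notation; I assume the kernel matrices k_train_train, k_test_train, k_test_test and the labels y_train are precomputed):

import jax.numpy as jnp

# Rough GP-style posterior, not the library's implementation.
k_inv_y = jnp.linalg.solve(k_train_train, y_train)   # K(train, train)^{-1} y
mean = k_test_train @ k_inv_y                         # predictive mean
cov = k_test_test - k_test_train @ jnp.linalg.solve(
    k_train_train, k_test_train.T)                    # predictive covariance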

In contrast, on the GitHub page you describe gradient_descent_mse as inference with a single infinite-width / linearized network trained on MSE loss with continuous gradient descent for an arbitrary finite or infinite amount of time. I also traced the source code, and the description looks similar to the ensemble method. I cannot understand how to create a single infinite-width network, since both NTK and NNGP require marginalizing out the initialization. Can you explain how it works and the difference between the abovementioned methods?

If possible, could you also point out which equations in the paper these two methods correspond to?

Many thanks for your kind reply.

sschoenholz commented 4 years ago

Great question! The difference between gradient_descent_mse and gradient_descent_mse_ensemble is that the former function only computes the mean predictions whereas the latter function computes both the mean and the covariance of the outputs during training.

The calling conventions between the two are somewhat similar, for example:

from neural_tangents import predict

# kernel_fn here takes params, i.e. it is an empirical (finite-width) kernel
# function such as one returned by nt.empirical_ntk_fn.
k_train_train = kernel_fn(x_train, None, params)
k_test_train = kernel_fn(x_test, x_train, params)

predict_fn = predict.gradient_descent_mse(k_train_train, y_train)
fx_train_t, fx_test_t = predict_fn(t=1.0, k_test_train=k_test_train)

This will predict the mean of the function on the training and test points after time t = 1.0 has passed.

By contrast,

# Here kernel_fn is the analytic (infinite-width) kernel function from stax.
predict_fn = predict.gradient_descent_mse_ensemble(kernel_fn, x_train, y_train)
fx_ensemble_train_t, fx_ensemble_test_t = predict_fn(t=1.0, x_test=x_test, get='ntk', compute_cov=True)

will compute the mean and covariance of the outputs following gradient descent training for time t = 1.0. So,

fx_ensemble_train_t.mean == fx_train_t  # True!

Let me know if this is helpful or if you have any further questions about the prediction functions!

lionelmessi6410 commented 4 years ago

Thanks for your kind reply. It seems that the only difference between them is the covariance matrix. Does this mean that if I set compute_cov=False in the predict_fn returned by gradient_descent_mse_ensemble, gradient_descent_mse_ensemble will be the same as gradient_descent_mse? If so, which API has better performance, i.e. shorter computation time?

I also noticed the arguments fx_train_0 and fx_test_0 in gradient_descent_mse, representing the outputs of the network at t = 0 on the training and test data, respectively. Based on my understanding, in linearized neural networks, to get a precise approximation of the original network, both the tangent kernel and the outputs at initialization, fx_train_0 and fx_test_0, are required.

However, in the infinite-width limit, the tangent kernel converges to a deterministic kernel, so there seems to be no need to provide those values at initialization. I noticed that in the infinite-width limit you set the default value of the outputs on both training and test data to 0., as follows:

import neural_tangents as nt

t = 1.0
k_train_train = kernel_fn(x_train, None, 'ntk')
k_test_train = kernel_fn(x_test, x_train, 'ntk')

predict_fn = nt.predict.gradient_descent_mse(k_train_train, y_train)
fx_train_t, fx_test_t = predict_fn(t=t, fx_train_0=0., fx_test_0=0., k_test_train=k_test_train)

These settings also appear in another work, Disentangling Trainability and Generalization in Deep Neural Networks. However, if I instead set them to the values at initialization,

y_train_0 = apply_fn(params, x_train)
y_test_0 = apply_fn(params, x_test)
fx_train_t_0, fx_test_t_0 = predict_fn(t=t, fx_train_0=y_train_0, fx_test_0=y_test_0, k_test_train=k_test_train)

and

fx_train_t != fx_train_t_0
fx_test_t != fx_test_t_0

Why is this the case? Moreover, the outputs (mean) of gradient_descent_mse_ensemble are very close to those of gradient_descent_mse with fx_train_0=0. and fx_test_0=0.. Does this imply that in the infinite-width limit we should set those values to 0, while for a finite-width network we should provide the values at initialization?

romanngg commented 4 years ago

1) predict.gradient_descent_mse corresponds to Eq. 9 - 11 in https://arxiv.org/pdf/1902.06720.pdf, and predict.gradient_descent_mse_ensemble to Eq. 14 - 16.

2) Re the predict.gradient_descent_mse docs: perhaps that was a poor choice of wording, and "a single very wide network" would be more appropriate. The message was intended to be that as the network gets wider, the outputs of this function will approximate continuous-time GD of the (non-linearized) network better and better. However, I think you can "create" a single infinite network by e.g. defining _, _, kernel_fn = stax.serial(stax.Dense(1), stax.Relu(), stax.Dense(1)) and sampling once [fx_train_0, fx_test_0] ~ N(0, kernel_fn([x_train, x_test], None, 'nngp')) - IIUC this gives you the outputs of a single infinite network on the train and test data.
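
A minimal sketch of that sampling step (my reading of the construction above; x_train and x_test are assumed to be defined, and the concatenation replaces the informal [x_train, x_test] notation):

import jax.numpy as jnp
from jax import random
from neural_tangents import stax

_, _, kernel_fn = stax.serial(stax.Dense(1), stax.Relu(), stax.Dense(1))

x_all = jnp.concatenate([x_train, x_test], axis=0)
nngp = kernel_fn(x_all, None, 'nngp')  # (n_train + n_test) x (n_train + n_test)

# One draw from N(0, nngp) plays the role of a single infinite network's
# outputs at initialization; a small jitter keeps the Cholesky factorization stable.
key = random.PRNGKey(0)
fx_all_0 = random.multivariate_normal(
    key, jnp.zeros(x_all.shape[0]), nngp + 1e-6 * jnp.eye(x_all.shape[0]))
fx_train_0, fx_test_0 = fx_all_0[:x_train.shape[0]], fx_all_0[x_train.shape[0]:]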

3) Yes, for fx_train_0 == fx_test_0 == 0, both functions should give the same mean predictions. I think they should be very close compute-wise; if I had to guess I'd bet on gradient_descent_mse_ensemble being slightly faster due to not having to deal with a dummy all-zeros fx_test_0, but it's best to compare both under jit to be sure. If you find them dramatically different, please let us know! You also mentioned that gradient_descent_mse_ensemble and gradient_descent_mse with zero initial values give close predictions - are they very (numerical precision) close, or still substantially different? If they are substantially different, this could be a bug.
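
A rough sketch of such a comparison (predict_fn_mse and predict_fn_ensemble stand for the two predict functions constructed above; the actual timing, e.g. with timeit, is left out):

import jax

# The predict functions are built from JAX ops, so they can typically be
# jit-compiled and timed after a warm-up call that triggers compilation.
mse_jit = jax.jit(lambda t: predict_fn_mse(t=t, fx_train_0=0., fx_test_0=0.,
                                           k_test_train=k_test_train))
ens_jit = jax.jit(lambda t: predict_fn_ensemble(t=t, x_test=x_test, get='ntk'))

mse_jit(1.0)  # warm-up / compilation
ens_jit(1.0)
# Time subsequent calls, calling block_until_ready() on the outputs so that
# asynchronous dispatch does not skew the measurement.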

4) Yes, individual (finite or infinite) networks have non-zero initial outputs, hence for the best approximation you should pass the initial fx_train_0 and fx_test_0. This is why you get fx_train_t != fx_train_t_0 and fx_test_t != fx_test_t_0 (note that equations 9 and 11 depend on f_0(X) and f_0(x)). See an example of linearized training in https://github.com/google/neural-tangents/blob/master/examples/function_space.py.

5) gradient_descent_mse_ensemble makes inference with an infinite ensemble of infinite networks and, as you correctly mentioned, marginalizes over initial values, hence it does not accept them as parameters. Formally, [fx_train_0, fx_test_0] in this case is implicitly assumed to be mean-zero Gaussian with covariance kernel_fn([x_train, x_test], [x_train, x_test], 'nngp'). Note that in practice this requires the top layer of your network to be something like nt.stax.Dense or Conv. If your top layer is e.g. nt.stax.Relu, kernel_fn will compute the second moment of your outputs, but they won't be Gaussian, and gradient_descent_mse_ensemble will not apply.
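
For example, a sketch of this constraint on the readout layer (layer widths are arbitrary here):

from neural_tangents import stax

# Outputs are Gaussian in the infinite-width limit: affine readout layer,
# so gradient_descent_mse_ensemble applies.
_, _, kernel_fn_ok = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))

# Outputs are not Gaussian: the network ends in a nonlinearity. kernel_fn
# still returns a second moment, but the ensemble predictions are not justified.
_, _, kernel_fn_not_ok = stax.serial(stax.Dense(512), stax.Relu())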

6) To summarize:

If you are doing inference with an infinite ensemble of infinite networks: use gradient_descent_mse_ensemble; the initial outputs are marginalized over, so you do not pass them.

If you are doing inference with a single (finite or infinite) network: use gradient_descent_mse and pass that network's outputs at initialization as fx_train_0 and fx_test_0.

Lmk if this makes sense / if I missed any questions!

lionelmessi6410 commented 4 years ago

Thanks for your super clear and remarkable explanation! Now I finally understand what's going on with these methods. And thanks for pointing out the corresponding equations in the paper.

About point 3: I checked the predictions given by gradient_descent_mse_ensemble and gradient_descent_mse with zero initial values, and they are very close (down to numerical precision). Basically, they are the same.

romanngg commented 3 years ago

You're welcome! Will close this as answered, but please don't hesitate to open new issues if you have other questions!