Great question! The difference between `gradient_descent_mse` and `gradient_descent_mse_ensemble` is that the former only computes the mean predictions, whereas the latter computes both the mean and the covariance of the outputs during training. The calling conventions of the two are somewhat similar, for example:
```python
k_train_train = kernel_fn(x_train, None, params)
k_test_train = kernel_fn(x_test, x_train, params)
predict_fn = predict.gradient_descent_mse(k_train_train, y_train)
fx_train_t, fx_test_t = predict_fn(t=1.0, k_test_train=k_test_train)
```
will predict the mean of the function on the training and test points after t = 1.0 time has passed. By contrast,
```python
predict_fn = predict.gradient_descent_mse_ensemble(kernel_fn, x_train, y_train)
fx_ensemble_train_t, fx_ensemble_test_t = predict_fn(t=1.0, x_test=x_test, get='ntk', compute_cov=True)
```
will compute the mean and covariance of the outputs following gradient descent training for t = 1.0 time. So,
```python
fx_ensemble_train_t.mean == fx_train_t  # True!
```
Let me know if this is helpful or if you have any further questions about the prediction functions!
Thanks for your kind reply. It seems that the only difference between them is the covariance matrix. Does that mean that if I set `compute_cov=False` in the `predict_fn` returned by `gradient_descent_mse_ensemble`, then `gradient_descent_mse_ensemble` will be the same as `gradient_descent_mse`? If so, which API has better performance, i.e. shorter computing time?
I noticed the other arguments `fx_train_0` and `fx_test_0` in `gradient_descent_mse`, representing the output of the network at t = 0 on the training and test data, respectively. Based on my understanding, in linearized neural networks, to get a precise approximation of the original network, the tangent kernel and the outputs at initialization, `fx_train_0` and `fx_test_0`, are required. However, in the infinite-width limit the tangent kernel converges to a deterministic kernel, so there is no need to provide those values at initialization. I noticed that in the infinite-width limit you set the default value for both the training and test outputs to 0., as follows:
```python
t = 1.0
k_train_train = kernel_fn(x_train, None, 'ntk')
k_test_train = kernel_fn(x_test, x_train, 'ntk')
predict_fn = nt.predict.gradient_descent_mse(k_train_train, y_train)
fx_train_t, fx_test_t = predict_fn(t=t, fx_train_0=0., fx_test_0=0., k_test_train=k_test_train)
```
This setting also appears in another work, Disentangling Trainability and Generalization in Deep Neural Networks. However, if I set them to the values at initialization,
```python
y_train_0 = apply_fn(params, x_train)
y_test_0 = apply_fn(params, x_test)
fx_train_t_0, fx_test_t_0 = predict_fn(t=t, fx_train_0=y_train_0, fx_test_0=y_test_0, k_test_train=k_test_train)
```
I find that
```python
fx_train_t != fx_train_t_0
fx_test_t != fx_test_t_0
```
Why is this the case? Moreover, the outputs (mean) of `gradient_descent_mse_ensemble` are very close to those of `gradient_descent_mse` with `fx_train_0=0.` and `fx_test_0=0.`. Does this imply that in the infinite-width case we should set those values to 0, while for a finite-width network we should provide the values at initialization?
1) `predict.gradient_descent_mse` corresponds to Eq. 9 - 11 in https://arxiv.org/pdf/1902.06720.pdf, and `predict.gradient_descent_mse_ensemble` to Eq. 14 - 16.
2) Re the `predict.gradient_descent_mse` docs: perhaps that was a poor choice of wording, and "a single very wide network" would be more appropriate. The message was intended to be that as the network gets wider, the outputs of this function will approximate continuous-time GD of the (non-linearized) network better and better. However, I think you can "create" a single infinite network by e.g. defining `_, _, kernel_fn = stax.serial(stax.Dense(1), stax.Relu(), stax.Dense(1))` and sampling once `[fx_train_0, fx_test_0] ~ N(0, kernel_fn([x_train, x_test], None, 'nngp'))` - IIUC this gives you the outputs of a single infinite network on the train and test data.
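A minimal sketch of this sampling step (assuming `x_train` and `x_test` are already defined; the single-output top layer `stax.Dense(1)` is just for illustration):

```python
import jax.numpy as jnp
import jax.random as random
from neural_tangents import stax

_, _, kernel_fn = stax.serial(stax.Dense(1), stax.Relu(), stax.Dense(1))

# NNGP covariance over the concatenated train + test inputs.
x_all = jnp.concatenate([x_train, x_test], axis=0)
nngp = kernel_fn(x_all, None, 'nngp')

# One draw from N(0, nngp) = the outputs of one infinite network on all points.
# A small diagonal jitter keeps the Cholesky factorization numerically stable.
key = random.PRNGKey(0)
fx_all_0 = random.multivariate_normal(
    key, jnp.zeros(nngp.shape[0]), nngp + 1e-6 * jnp.eye(nngp.shape[0]))

fx_train_0 = fx_all_0[:x_train.shape[0], None]
fx_test_0 = fx_all_0[x_train.shape[0]:, None]
```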
3) Yes, for `fx_train_0 == fx_test_0 == 0`, both functions should give the same mean predictions. I think they should also be very close compute-wise; if I had to guess, I'd bet on `gradient_descent_mse_ensemble` being slightly faster due to not having to deal with a dummy all-zeros `fx_test_0`, but it's best to compare both under `jit` to be sure. If you find them dramatically different, please let us know! You also mentioned that `gradient_descent_mse_ensemble` and `gradient_descent_mse` with zero initial values give close predictions - are they very (numerical precision) close, or still substantially different? If they are substantially different, this could be a bug.
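For reference, a minimal sketch of checking whether the mean predictions agree, following the calling conventions shown earlier in this thread (assuming `x_train`, `x_test`, `y_train` and an `nt.stax` `kernel_fn` are already defined):

```python
import jax.numpy as jnp
from neural_tangents import predict

# Single-network API with zero initial outputs.
k_train_train = kernel_fn(x_train, None, 'ntk')
k_test_train = kernel_fn(x_test, x_train, 'ntk')
mse_fn = predict.gradient_descent_mse(k_train_train, y_train)
fx_train_t, fx_test_t = mse_fn(
    t=1.0, fx_train_0=0., fx_test_0=0., k_test_train=k_test_train)

# Ensemble API (means only, compute_cov left at its default of False).
ensemble_fn = predict.gradient_descent_mse_ensemble(kernel_fn, x_train, y_train)
fx_ens_train_t, fx_ens_test_t = ensemble_fn(t=1.0, x_test=x_test, get='ntk')

print(jnp.allclose(fx_train_t, fx_ens_train_t, atol=1e-4))  # Expect True.
print(jnp.allclose(fx_test_t, fx_ens_test_t, atol=1e-4))    # Expect True.
```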
4) Yes, individual (finite or infinite) networks have non-zero initial outputs, hence for the best approximation you should pass the initial `fx_train_0` and `fx_test_0`. This is why you get `fx_train_t != fx_train_t_0` and `fx_test_t != fx_test_t_0` (i.e. equations 9 and 11 depend on `f_0(X)` and `f_0(x)`). See an example of linearized training in https://github.com/google/neural-tangents/blob/master/examples/function_space.py.
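For reference, the continuous-time solutions these equations describe have the following form (up to the paper's exact notation and learning-rate conventions; here $\hat\Theta_0$ is the NTK on the training set $X$, $\hat\Theta_0(x, X)$ the NTK between a test point and the training set, $\eta$ the learning rate, and $Y$ the training labels), which makes the dependence on $f_0(X)$ and $f_0(x)$ explicit:

```latex
f_t^{\mathrm{lin}}(X) = e^{-\eta \hat\Theta_0 t} f_0(X) + \left(I - e^{-\eta \hat\Theta_0 t}\right) Y,
\qquad
f_t^{\mathrm{lin}}(x) = f_0(x) - \hat\Theta_0(x, X)\, \hat\Theta_0^{-1} \left(I - e^{-\eta \hat\Theta_0 t}\right) \left(f_0(X) - Y\right).
```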
5) `gradient_descent_mse_ensemble` makes inference with an infinite ensemble of infinite networks and, as you correctly mentioned, marginalizes over initial values, hence does not accept them as parameters. Formally, `[fx_train_0, fx_test_0]` in this case is implicitly assumed to be mean-zero Gaussian with covariance `kernel_fn([x_train, x_test], [x_train, x_test], 'nngp')`. Note that in practice this requires that the top layer of your network is something like `nt.stax.Dense` or `Conv`. If your top layer is e.g. `nt.stax.Relu`, `kernel_fn` will compute the second moment of your outputs, but they won't be Gaussian, and `gradient_descent_mse_ensemble` will not apply.
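A quick way to check this (also mentioned in the summary below) is the `is_gaussian` attribute of the `Kernel` object that an `nt.stax` `kernel_fn` returns when called without a `get` argument. A minimal sketch, assuming `x_train` and `x_test` exist:

```python
from neural_tangents import stax

_, _, kernel_fn_dense = stax.serial(stax.Dense(512), stax.Relu(), stax.Dense(1))
_, _, kernel_fn_relu = stax.serial(stax.Dense(512), stax.Relu())

# Dense top layer -> infinite-width outputs are mean-zero Gaussian.
print(kernel_fn_dense(x_train, x_test).is_gaussian)  # True
# Relu top layer -> second moments exist, but the outputs are not Gaussian.
print(kernel_fn_relu(x_train, x_test).is_gaussian)   # False
```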
6) To summarize:

- If you are doing inference with an infinite ensemble of infinite networks: `gradient_descent_mse_ensemble` and `gradient_descent_mse` with `fx_train_0 = fx_test_0 = 0` should give the same predictions. Note that if you're computing the kernel using a `kernel_fn` from `nt.stax`, you can check the attribute `kernel_fn(x1, x2).is_gaussian` to know whether the outputs are mean-zero Gaussians or not.
- If your outputs are Gaussian but have a non-zero mean `mu(.)`: `gradient_descent_mse_ensemble` will be wrong, and I think you could use `gradient_descent_mse` and pass `fx_train_0 = mu(x_train), fx_test_0 = mu(x_test)`. FYI we usually work only with mean-zero Gaussians in this library, so I may be wrong here / there might be bugs we haven't discovered.
- If you are doing inference with a single network: use `gradient_descent_mse` and pass `fx_train_0 = f_0(x_train), fx_test_0 = f_0(x_test)` for the best possible linearization approximation. If you also want this approximation to match the nonlinear dynamics, make the network wide and make sure the top layer is `nt.stax.Dense` etc., since I'm not sure how well the theory/practice works for non-Gaussian outputs, and perhaps someone else could comment more here.
- If you are doing inference with a single infinite network: sample `fx_train_0` and `fx_test_0` from a mean-zero Gaussian with covariance given by the NNGP of the network; then `gradient_descent_mse` should give you linearized = non-linearized training dynamics.

Let me know if this makes sense / if I missed any questions!
Thanks for your super clear and remarkable explanation! Now I finally understand what's going on with these methods. And thanks for pointing out the corresponding equations in the paper.
About point 3, I checked the predictions given by `gradient_descent_mse_ensemble` and `gradient_descent_mse` with zero initial values, and they are very (numerical precision) close. Basically, they are the same.
You're welcome! Will close this as answered, but please don't hesitate to open new issues if you have other questions!
I am truly confused about the difference between `gradient_descent_mse` and `gradient_descent_mse_ensemble`. In the original NNGP and NTK papers, the authors mentioned that we can use Gaussian process methods to make predictions on test data, i.e. kernel * kernel inverse * y. In addition to the mean prediction, we can also calculate the covariance matrix. Based on my understanding, this corresponds to `gradient_descent_mse_ensemble`. In contrast, on the GitHub page you describe `gradient_descent_mse` as inference with a single infinite width / linearized network trained on MSE loss with continuous gradient descent for an arbitrary finite or infinite time, and when I traced the source code the description looked similar to the ensemble method. I cannot understand how to create a single infinite width network, since both NTK and NNGP require marginalizing out the initialization. Can you explain how it works and the difference between the abovementioned methods? If possible, can you also point out which equations in this paper these two methods correspond to?
Many thanks for your kind reply.