google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0
2.27k stars, 225 forks

Weight Evolution and Predictions from Weights #34

Open shannon63 opened 4 years ago

shannon63 commented 4 years ago

Hi all,

I would like to (analytically) compute the evolution of the weights under the linearized dynamics (i.e., Eqn. (8) in https://arxiv.org/pdf/1902.06720.pdf) and use the resulting weights after t "steps" of gradient flow to make predictions on the training data. More specifically, I would like these predictions to match the predictions obtained by solving the function-space dynamics on the training data (Eqn. (9) in the paper).
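
For reference, here are the equations the thread refers to, restated (not quoted) from arXiv:1902.06720 in the following notation: $\omega_t = \theta_t - \theta_0$, $J_0 = \nabla_\theta f_0(X)$ is the $n \times P$ Jacobian at initialization on the training inputs, $\Theta_0 = J_0 J_0^T$ is the empirical NTK, and $\eta$ is the learning rate. The last lines check that (8) and (9) agree for the linearized model whenever $\Theta_0$ is invertible. The predict.py code discussed below writes the exponent as $\Theta\,dt/n$, which corresponds to averaging the loss over the $n$ training points.

```latex
f^{\mathrm{lin}}_t(x) = f_0(x) + \nabla_\theta f_0(x)\big|_{\theta=\theta_0}\,\omega_t \qquad (5)

\omega_t = -J_0^{T}\,\Theta_0^{-1}\bigl(I - e^{-\eta\Theta_0 t}\bigr)\bigl(f_0(X) - Y\bigr) \qquad (8)

f^{\mathrm{lin}}_t(X) = \bigl(I - e^{-\eta\Theta_0 t}\bigr)\,Y + e^{-\eta\Theta_0 t}\,f_0(X) \qquad (9)

% Substituting (8) into (5) at x = X and using J_0 J_0^T = \Theta_0:
f_0(X) + J_0\,\omega_t
  = f_0(X) - \bigl(I - e^{-\eta\Theta_0 t}\bigr)\bigl(f_0(X) - Y\bigr)
  = \bigl(I - e^{-\eta\Theta_0 t}\bigr)\,Y + e^{-\eta\Theta_0 t}\,f_0(X)
```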

To do this, I modified gradient_descent_mse() in predict.py to implement Eqn. (8). Specifically, I added the function

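```python
def predict_params_using_kernel(dt, fx_train=0.):
  # Residual on the training set, f(X) - Y, flattened
  gx_train = fl(fx_train - y_train)
  # Kernel-space factor applied to the residual (helper already defined inside gradient_descent_mse())
  dfx = inv_expm1_dot_vec(gx_train, dt)
  # Map the function-space vector to parameter space with the Jacobian at initialization
  dfx = np.dot(Jacobian_f0, dfx)
  return params0 - dfx
```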

where Jacobian_f0 is the Jacobian with respect to the parameters of the NN at initialization, evaluated on the training data.

With the resulting parameters, params_t, converted back to the appropriate pytree structure, I then compute predictions on the training data by calling apply_fn(params_t, x_train).
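
A minimal, self-contained sketch of this parameter-space route, assuming a hypothetical two-layer tanh network written directly in JAX in place of a stax model. It flattens the parameter pytree with jax.flatten_util.ravel_pytree, builds the explicit n × P Jacobian, applies Eq. (8) in the eigenbasis of the empirical NTK, and unflattens back to a pytree. As a check, the linearized predictions at params_t, f_0(X) + J_0 ω_t, should coincide with the closed-form function-space solution of Eq. (9). Sizes, t, and the learning rate are arbitrary, and no 1/n normalization is used here.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

jax.config.update("jax_enable_x64", True)  # float64 keeps the Theta^{-1} solve accurate

# Hypothetical toy network standing in for stax's apply_fn.
def apply_fn(params, x):
    w1, b1, w2, b2 = params
    return jnp.tanh(x @ w1 + b1) @ w2 + b2  # shape (n, 1)

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
n, d, width = 8, 3, 64
x_train = jax.random.normal(k1, (n, d))
y = jnp.sin(x_train[:, 0])  # flattened targets, shape (n,)
params0 = (jax.random.normal(k2, (d, width)) / jnp.sqrt(d), jnp.zeros(width),
           jax.random.normal(k3, (width, 1)) / jnp.sqrt(width), jnp.zeros(1))

# Flatten the pytree so the Jacobian is an ordinary (n, P) matrix.
flat0, unravel = ravel_pytree(params0)

def f_flat(flat):
    return apply_fn(unravel(flat), x_train).reshape(-1)

J0 = jax.jacobian(f_flat)(flat0)        # (n, P)
fx0 = f_flat(flat0)                     # f_0(X), shape (n,)
evals, q = jnp.linalg.eigh(J0 @ J0.T)   # empirical NTK Theta_0 = q diag(evals) q^T

def omega(t, lr):
    """Eq. (8): -J0^T Theta0^{-1} (I - exp(-lr Theta0 t)) (f0(X) - Y)."""
    decay = -jnp.expm1(-lr * evals * t)  # 1 - exp(-lr * lambda * t)
    coeffs = q @ ((decay / evals) * (q.T @ (fx0 - y)))
    return -J0.T @ coeffs

def f_lin_closed_form(t, lr):
    """Eq. (9): (I - exp(-lr Theta0 t)) Y + exp(-lr Theta0 t) f0(X)."""
    e = q @ jnp.diag(jnp.exp(-lr * evals * t)) @ q.T
    return (jnp.eye(n) - e) @ y + e @ fx0

t, lr = 50.0, 0.1
flat_t = flat0 + omega(t, lr)
params_t = unravel(flat_t)                        # back to the original pytree structure
f_lin_from_params = fx0 + J0 @ (flat_t - flat0)   # Eq. (5) on the training inputs
print(jnp.abs(f_lin_from_params - f_lin_closed_form(t, lr)).max())
```

Note that the check uses the linearized model; feeding params_t into the nonlinear apply_fn is a different question.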

Unfortunately, this does not seem to produce sensible predictions: the parameters explode (i.e., become large in magnitude) even for small t, and the resulting predictions don't match those obtained by solving the function-space dynamics, even on the training data. I am aware that the mapping between parameter states and function predictions is not bijective, but shouldn't the parameters obtained from Eqn. (8) lead to the same predictions as Eqn. (9)?

NB: I did confirm that pre-multiplying dfx = np.dot(Jacobian_f0, dfx) by the transpose of Jacobian_f0 does yield the same matrix as calling the built-in function predict_using_kernel().

EDIT: I forgot to mention that I also changed the signature of gradient_descent_mse() to gradient_descent_mse(g_dd, y_train, params0, Jacobian_f0, g_td=None, diag_reg=0.), i.e., I added params0 and Jacobian_f0.

Any help would be much appreciated!

Thank you!

sschoenholz commented 4 years ago

Hey! I looked into it a bit and I think the approach you're describing should work. However, there might be some subtlety involved in making sure all of the shapes work out. Also, in general, I think it should be a lot more efficient to use JAX's vjp function rather than instantiating the Jacobian directly.
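
A minimal sketch of the vjp suggestion, assuming a hypothetical two-layer tanh model with a dict of parameters in place of a stax network. jax.vjp returns a function that computes $J^T v$ directly as a pytree shaped like the parameters, so neither the explicit Jacobian nor a flatten/unflatten round trip is needed. Here v stands for whatever function-space vector the kernel-space solve produces, e.g. $\Theta^{-1}(e^{-\eta\Theta t} - I)(f_0(X) - Y)$, so that params_t = params0 + $J^T v$ matches Eq. (8).

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Hypothetical toy model standing in for apply_fn; params is a pytree (a dict here).
def apply_fn(params, x):
    h = jnp.tanh(x @ params["w1"])
    return (h @ params["w2"]).reshape(-1)  # (n,)

key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
x_train = jax.random.normal(key1, (5, 3))
params0 = {"w1": jax.random.normal(key2, (3, 16)),
           "w2": jax.random.normal(key3, (16, 1))}

# Placeholder function-space vector, one entry per training output.
v = jnp.arange(5.0)

# vjp returns f(params0) and a function computing J^T v as a pytree with the
# same structure as params0; the (n, P) Jacobian is never materialized.
fx0, vjp_fn = jax.vjp(lambda p: apply_fn(p, x_train), params0)
(jt_v,) = vjp_fn(v)

# Parameter update as a pytree: params_t = params0 + J^T v.
params_t = jax.tree_util.tree_map(lambda p, dp: p + dp, params0, jt_v)
```

Each call to vjp_fn costs roughly one backward pass, versus n of them (plus O(nP) memory) for the full Jacobian.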

Here's a colab notebook where I added the weight computation to the gradient_descent_mse function and confirmed that for a simple FC network the analytical weights agree with the empirical weights. Let me know if this helps or if you have any questions that I can help resolve. Perhaps we should build this functionality into the prediction functions directly...

shannon63 commented 4 years ago

Thank you so much! I really appreciate the worked example.

I did come across something odd though. Specifically, I noticed that if you make the NN less wide, for example, by letting

```python
from neural_tangents import stax

init_fn, apply_fn, kernel_fn = stax.serial(stax.Dense(256), stax.Erf(),
                                           stax.Dense(256), stax.Erf(),
                                           stax.Dense(256), stax.Erf(), stax.Dense(1))
```

then the (linearized) parameter-space and (linearized) function-space predictions diverge after about 1000 steps of gradient descent (with learning rate 1e-1). I don't see how this would follow from the math in the paper. Am I missing something, or could there be a numerical/implementation reason for this?

Please see this colab notebook for a worked example that demonstrates the phenomenon.

sschoenholz commented 4 years ago

Good question! I'm not sure I have a totally satisfactory answer, but my guess is that the function-space → parameter-space map induced by the Jacobian breaks down when the number of examples becomes comparable to the width of the network (for that map to be well behaved, it seems like all of the singular values of $J^T$ would need to be nonzero). One way to check whether something like this is going on is to reduce the number of training examples (to maybe 10); you'll see that the two curves end up much closer together.
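
To see the rank argument in isolation, here is a sketch in which a random Gaussian matrix stands in for the n × P Jacobian (an assumption: a real network's Jacobian is structured, so this says nothing quantitative about the stax model above). The conditioning of $J J^T$ degrades as the number of rows n approaches the number of columns P, and $J J^T$ is singular once n > P, at which point $\Theta^{-1}$ in Eq. (8) is no longer defined.

```python
import numpy as np

# Stand-in: a random Gaussian (n, P) matrix plays the role of the Jacobian J.
rng = np.random.default_rng(0)
P = 50  # hypothetical number of parameters (columns)
ratios = {}
for n in (10, 40, 80):
    J = rng.standard_normal((n, P))
    evals = np.linalg.eigvalsh(J @ J.T)    # spectrum of the empirical NTK J J^T
    ratios[n] = evals.min() / evals.max()  # inverse condition number
    print(n, ratios[n])
```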

I haven't thought about this too carefully, though, so this is a low-medium confidence answer :).

shannon63 commented 4 years ago

Thanks for the suggestion! I tried reducing the number of samples while keeping the width constant, but the predictions from the weight-space evolution still don't follow the function-space predictions.

A potentially related observation I made is that the expm1_dot_vec() function may not behave the way it should. As I understand it, expm1_dot_vec(v, dt) implements the operation

$$\left[e^{-\Theta\,dt/n} - I\right] v = \left[Q\,e^{-\Lambda\,dt/n}\,Q^{-1} - I\right] v$$

where $Q \Lambda Q^{-1}$ is the eigendecomposition of the NTK $\Theta$, $v$ is some vector (in the code, $v$ would be gx_train = f(X) - Y), $dt = \text{lr} \cdot t$, and $n$ is the number of samples in the training set. It should then be true that as $dt \to \infty$, $e^{-\Theta\,dt/n} - I \to -I$, and hence expm1_dot_vec(v, dt) $\to -v$. However, in the notebook you shared with me, this does not seem to be the case: adding print(expm1_dot_vec(np.ones_like(gx_train), 1e26)) to train_predict() yields values significantly different from -1.
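
The limit only holds if every eigenvalue of $\Theta$ is strictly positive: components of $v$ along zero (or clipped negative) eigenvalues never decay. A small NumPy sketch of this, using a hypothetical re-implementation of expm1_dot_vec (not the library's code) that clips negative eigenvalues to zero and optionally adds a diagonal ridge scaled by $\mathrm{tr}(\Theta)/n$; both the clipping and that scaling of diag_reg are assumptions. A large but finite dt is used because with dt as large as 1e26 even round-off-sized positive eigenvalues would decay, so which components survive would then depend on the sign of the round-off.

```python
import numpy as np

def expm1_dot_vec(theta, v, dt, n, diag_reg=0.0):
    """Hypothetical re-implementation of [exp(-theta * dt / n) - I] v for symmetric theta."""
    # Assumption: diag_reg adds a ridge scaled by the mean eigenvalue, trace(theta) / n.
    theta = theta + diag_reg * np.trace(theta) / n * np.eye(n)
    evals, q = np.linalg.eigh(theta)   # theta = q diag(evals) q^T
    evals = np.maximum(evals, 0.0)     # assumption: clip negative round-off to zero
    return q @ (np.expm1(-evals * dt / n) * (q.T @ v))

rng = np.random.default_rng(0)
n = 6
a = rng.standard_normal((n, 3))
theta_pd = a @ a.T + np.eye(n)   # every eigenvalue >= 1
theta_low_rank = a @ a.T         # three eigenvalues are zero up to round-off
v = np.ones(n)
dt = 1e8

out_pd = expm1_dot_vec(theta_pd, v, dt, n)
out_low_rank = expm1_dot_vec(theta_low_rank, v, dt, n)
out_reg = expm1_dot_vec(theta_low_rank, v, dt, n, diag_reg=1e-4)
print(np.abs(out_pd + v).max())        # ~0: converges to -v
print(np.abs(out_low_rank + v).max())  # not small: null-space components of v never decay
print(np.abs(out_reg + v).max())       # ~0 again once the ridge lifts the zero eigenvalues
```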

Am I misunderstanding how the expm1_dot_vec(v, dt) function works? Shouldn't it converge to -v as dt becomes large? Could this be related to the convergence behavior of the parameters?

Thank you for your help!

romanngg commented 4 years ago

1) Re expm1, I think the way you write it is correct, but perhaps very small or slightly negative eigenvalues due to numerical errors mess up the outputs, and increasing the diagonal regularization would help?

2) Re the discrepancy, I think you can substitute the trained parameters of the linearized network into the linearized network, i.e. Eq. (5) from https://arxiv.org/pdf/1902.06720.pdf, and this should yield the same result as Eq. (9). But IIUC you are applying the nonlinear network f(params_t, ...), and AFAIK the two should only agree in the very wide limit (Eq. (17)); for narrow networks, f(params_t, ...) != f_lin(params_t, ...).

3) On a side note, https://github.com/google/neural-tangents/commit/a76bbb494f19af4f8c9c1a1b0904e91b105f769e has changed the predict API significantly - please see the new docs on it (https://github.com/google/neural-tangents#package-description and https://neural-tangents.readthedocs.io/en/latest/neural_tangents.predict.html).
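
A minimal sketch of point 2, assuming a hypothetical two-layer tanh model and arbitrary stand-in parameters in place of params_t. It builds the linearized network of Eq. (5) from apply_fn with jax.jvp (neural_tangents also provides nt.linearize for this purpose) and compares it with the nonlinear network. The two coincide at initialization, and the linearized model is exactly linear in params - params0, but away from initialization the nonlinear network generally gives different outputs.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def linearize(f, params0):
    """Eq. (5): f_lin(params, x) = f(params0, x) + J(params0, x) (params - params0), via jax.jvp."""
    def f_lin(params, x):
        dparams = jax.tree_util.tree_map(lambda p, p0: p - p0, params, params0)
        f0, df = jax.jvp(lambda p: f(p, x), (params0,), (dparams,))
        return f0 + df
    return f_lin

# Hypothetical toy network standing in for apply_fn.
def apply_fn(params, x):
    return jnp.tanh(x @ params["w1"]) @ params["w2"]

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (4, 3))
params0 = {"w1": jax.random.normal(k2, (3, 32)), "w2": jax.random.normal(k3, (32, 1))}

def step(c):
    # Arbitrary move away from initialization, standing in for trained params_t.
    return jax.tree_util.tree_map(lambda p: p + c * 0.5 * jnp.ones_like(p), params0)

params1, params2 = step(1.0), step(2.0)
f_lin = linearize(apply_fn, params0)
print(jnp.abs(f_lin(params0, x) - apply_fn(params0, x)).max())  # ~0 at initialization
print(jnp.abs(f_lin(params1, x) - apply_fn(params1, x)).max())  # nonzero away from initialization
```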

Let me know if any of this helps!