google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0
2.27k stars 226 forks

Possible bug in calculation of the value of the function at time t #126

Open mohamad-amin opened 3 years ago

mohamad-amin commented 3 years ago

Hey! I was reading through the code and I noticed that you're using an element-wise exponential of a matrix here:

https://github.com/google/neural-tangents/blob/5f286b7696364217aa4a2d92378aabd0203a791e/neural_tangents/predict.py#L1180

Does this correspond to equation (9) in the paper https://arxiv.org/pdf/1902.06720?

[Image: screenshot of equation (9) from the paper]

If yes, shouldn't it be matrix exponential instead of element-wise exponential?

Thanks in advance.

sschoenholz commented 3 years ago

Great question! We move into a diagonal basis before applying the exponential in which case elementwise exponential and matrix exponential agree. In other words, we write $\exp(At)v = U\exp(Dt)U^Tv$ where $A = UDU^T$ is the diagonalization of $A$.
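As a quick numerical check of this identity, here is a minimal NumPy sketch (not the library's code). The helper names `expm_via_eigh` and `expm_via_taylor` and the random test matrix are assumptions made for illustration; the Taylor series serves only as an independent reference for the matrix exponential.

```python
import numpy as np


def expm_via_eigh(a, t):
    """exp(a * t) for a symmetric matrix `a`, using a = U diag(d) U^T."""
    d, u = np.linalg.eigh(a)
    # In the eigenbasis the matrix is diagonal, so exponentiating the
    # eigenvalues element-wise gives the matrix exponential in that basis.
    return (u * np.exp(d * t)) @ u.T


def expm_via_taylor(a, t, terms=60):
    """Reference value: truncated power series sum_k (a * t)^k / k!."""
    result = np.eye(a.shape[0])
    term = np.eye(a.shape[0])
    for k in range(1, terms):
        term = term @ (a * t) / k
        result = result + term
    return result


rng = np.random.default_rng(0)
b = rng.normal(size=(4, 4))
a = -(b @ b.T) / 4.0  # symmetric, negative semi-definite (toy stand-in)
t = 0.5
v = rng.normal(size=4)

reference = expm_via_taylor(a, t) @ v
diagonal_basis = expm_via_eigh(a, t) @ v
naive = np.exp(a * t) @ v  # element-wise exp in the ORIGINAL basis

print("eigenbasis matches reference:", np.allclose(reference, diagonal_basis))
print("naive element-wise matches reference:", np.allclose(reference, naive))
```

The element-wise exponential is only safe after rotating into the eigenbasis; applied to the matrix in its original basis it gives a different result.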

Let me know if that helps, or if you have any other concerns.

All the best, Sam

mohamad-amin commented 3 years ago

Thanks! Sorry for not paying attention to this.

It makes sense to me now, but I'm a bit concerned about the PSD correction that you apply to the kernel matrix here. In practice, since we care about the converged function, it would be convenient to take the limit of t to infinity, which translates to choosing a really large t here. However, this correction, combined with multiplication by a really large t, might cause approximation issues, so that the final learned function is not as close to the finite network as it should be. Am I correct here? If yes, do you have any thoughts on that?

Sincerely, Amin

romanngg commented 3 years ago

IIUC, you'd want to pass t=None, which we treat as symbolic infinity; it selects a code path that evaluates the simplified closed-form expression, like Eq. 16 in https://arxiv.org/pdf/1902.06720.pdf

mohamad-amin commented 2 years ago

Thanks! Does this method (passing t=None) work for the finite width setup as well? In other words, does passing t=None when the width of the network is finite create a short-circuit for computing the neural network function? If not, are we able to approximate the neural network function with finite width using this repository? If yes, doesn't it suffer from the approximation problem that I mentioned in the last post?

romanngg commented 2 years ago

I believe the answer is yes, but for the linearization of the neural network function (not the original neural network function; the wider the network is, the closer they are though).

If you pass t=None to the predict_fn (https://github.com/google/neural-tangents/blob/7d01d6513bf7bce5d227aa9f223eb8353cc8c74b/neural_tangents/predict.py#L210) returned by nt.predict.gradient_descent_mse (https://github.com/google/neural-tangents/blob/7d01d6513bf7bce5d227aa9f223eb8353cc8c74b/neural_tangents/predict.py#L51), then it will return the outputs of the linearized model at infinite time. Just make sure to pass the correct starting conditions fx_train_0, fx_test_0, and the finite-width empirical kernels k_train_train and k_test_train computed on your network at initialization.

Specifically, here's the function it will call for t=None, using a Cholesky solver and no exponentiation: https://github.com/google/neural-tangents/blob/7d01d6513bf7bce5d227aa9f223eb8353cc8c74b/neural_tangents/predict.py#L143
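To illustrate the idea (this is a NumPy sketch, not the library's implementation), the infinite-time output of the linearized model under MSE loss can be written as fx_test_0 + k_test_train K^{-1} (y_train - fx_train_0), which needs only a linear solve. The function names, the toy RBF kernel standing in for an empirical NTK, and the unit learning-rate normalization below are all assumptions made for the example.

```python
import numpy as np


def predict_inf_time(k_train_train, k_test_train, y_train, fx_train_0, fx_test_0):
    """Linearized-model test outputs at t = infinity (MSE loss), no exponentials."""
    chol = np.linalg.cholesky(k_train_train)         # K = L L^T
    z = np.linalg.solve(chol, y_train - fx_train_0)  # solve L z = residual
    alpha = np.linalg.solve(chol.T, z)               # solve L^T alpha = z
    return fx_test_0 + k_test_train @ alpha


def predict_finite_time(k_train_train, k_test_train, y_train, fx_train_0,
                        fx_test_0, t, learning_rate=1.0):
    """Same model at finite t, via the eigendecomposition of the train kernel."""
    d, u = np.linalg.eigh(k_train_train)
    # (1 - exp(-lr * d * t)) / d, applied per eigenvalue in the eigenbasis.
    coeff = -np.expm1(-learning_rate * d * t) / d
    alpha = u @ (coeff[:, None] * (u.T @ (y_train - fx_train_0)))
    return fx_test_0 + k_test_train @ alpha


def rbf(x1, x2):
    """Toy kernel used here as a stand-in for an empirical NTK."""
    sq_dist = ((x1[:, None, :] - x2[None, :, :]) ** 2).sum(-1)
    return np.exp(-0.5 * sq_dist)


rng = np.random.default_rng(0)
x_train, x_test = rng.normal(size=(6, 2)), rng.normal(size=(3, 2))
y_train = rng.normal(size=(6, 1))
fx_train_0 = 0.1 * rng.normal(size=(6, 1))
fx_test_0 = 0.1 * rng.normal(size=(3, 1))

# The 0.1 * I term keeps the toy kernel well conditioned.
k_train_train = rbf(x_train, x_train) + 0.1 * np.eye(6)
k_test_train = rbf(x_test, x_train)

f_inf = predict_inf_time(k_train_train, k_test_train, y_train, fx_train_0, fx_test_0)
f_large_t = predict_finite_time(k_train_train, k_test_train, y_train,
                                fx_train_0, fx_test_0, t=1e3)
print("closed form agrees with large t:", np.allclose(f_inf, f_large_t))
```

On a well-conditioned kernel the two paths agree; the closed form simply avoids picking a "large enough" t by hand.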

For an arbitrary non-MSE loss, you can also use nt.predict.gradient_descent https://github.com/google/neural-tangents/blob/7d01d6513bf7bce5d227aa9f223eb8353cc8c74b/neural_tangents/predict.py#L266 but here there are indeed no shortcuts for t=None, and it will just run the ODE solver until convergence (which may not even happen).

mohamad-amin commented 2 years ago

Thanks!

One other thing that caught my eye is that Theorem 2.1 (which I think is the proof of your statement) only applies when $n \geq N$ for some $N \in \mathbb{N}$. If $n < N$, there is no guarantee of any convergence or asymptotic approximation. Is there an easy way to check whether our network satisfies this condition, and to figure out $N$ for the particular architecture being used?

Sincerely, Amin

romanngg commented 2 years ago

As far as I understand, for specific finite width n, there will always be a mismatch between your network and its linearization, and we don't have super practical bounds on the magnitude of the mismatch, so I would probably do binary search on width n in practice and see if it's large enough to give a close match.
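A width search of this kind could look roughly like the sketch below. It is purely illustrative: `smallest_adequate_width` and `toy_mismatch` are hypothetical names, and in practice `mismatch(width)` would train a network of that width and compare it with its linearization. Binary search also assumes the mismatch does not increase with width, which is a heuristic rather than a guarantee.

```python
def smallest_adequate_width(mismatch, lo, hi, tol):
    """Smallest width in [lo, hi] with mismatch(width) <= tol, or None.

    Assumes `mismatch` is non-increasing in width (a heuristic assumption).
    """
    if mismatch(hi) > tol:
        return None  # even the largest width tried is not close enough
    while lo < hi:
        mid = (lo + hi) // 2
        if mismatch(mid) <= tol:
            hi = mid      # mid is good enough; try narrower networks
        else:
            lo = mid + 1  # mid is too narrow
    return lo


def toy_mismatch(width):
    """Toy stand-in that decays like width^(-1/2); not a measured quantity."""
    return width ** -0.5


print(smallest_adequate_width(toy_mismatch, lo=1, hi=4096, tol=0.12))
```

Each call to `mismatch` is expensive in a real experiment, so the logarithmic number of evaluations is the main appeal over a full grid.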

In addition to this, you also need to satisfy the conditions of Theorem 2.1 itself to expect that, for large n, the original and linearized networks behave similarly. Beyond using full-batch gradient descent and checking that your inputs have the proper norm as stated in the theorem, you can also make sure your learning rate is not too big by calling nt.predict.max_learning_rate. Finally, the theorem also requires that your empirical NTK converges in the limit of large n. If you build your network out of nt.stax layers, it will; for arbitrary functions it's not guaranteed (although it should converge for a very wide range of architectures; see https://arxiv.org/pdf/1902.04760.pdf). So unfortunately this is one more way in which it's hard to certify convergence.
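To see why the learning rate matters, consider discrete gradient descent on the linearized model with MSE loss: the training residual evolves as r <- (I - lr * K) r, which only contracts when lr < 2 / lambda_max(K). The sketch below illustrates that threshold on a toy kernel. It is not the formula used by nt.predict.max_learning_rate, whose normalization may differ, and `residual_norm_after_gd` is a name made up for this example.

```python
import numpy as np


def residual_norm_after_gd(kernel, residual_0, learning_rate, steps):
    """Norm of the training residual after `steps` linearized GD updates."""
    residual = residual_0.copy()
    for _ in range(steps):
        # One full-batch step on 0.5 * ||f - y||^2 for the linearized model.
        residual = residual - learning_rate * (kernel @ residual)
    return np.linalg.norm(residual)


rng = np.random.default_rng(0)
b = rng.normal(size=(5, 5))
kernel = b @ b.T + 0.1 * np.eye(5)  # toy positive-definite kernel
residual_0 = rng.normal(size=5)

critical_lr = 2.0 / np.linalg.eigvalsh(kernel).max()
stable = residual_norm_after_gd(kernel, residual_0, 0.9 * critical_lr, 200)
unstable = residual_norm_after_gd(kernel, residual_0, 1.1 * critical_lr, 200)

print("initial residual norm:", np.linalg.norm(residual_0))
print("below threshold:", stable)
print("above threshold:", unstable)
```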

(CC @SiuMath who may know more about convergence and bounding the discrepancy)

mohamad-amin commented 2 years ago

Thanks!

Correct me if I'm wrong, but isn't it mentioned in the formal version of the theorem in the appendix that the $O(n^{-1/2})$ rate of mismatch between the linearized network and the original network is only achieved for $n > N$ for some $N$, and not for all $n$?

Also, doesn't Theorem 1 in https://arxiv.org/pdf/1806.07572.pdf imply that the NTK converges on any architecture (of course for continuously differentiable activation functions)? Update: I guess it doesn't cover "modern" layers. The most modern layer I know is pooling, as I'm out of date on deep learning, but I assume newer types of layers might not be representable as affine transformations, which is why they're not covered by this theorem. Am I right?

SiuMath commented 2 years ago

Hi Amin,

Regarding:

"One other thing that caught my eye is that Theorem 2.1 (which I think is the proof of your statement) only applies when $n \geq N$ for some $N \in \mathbb{N}$. If $n < N$, there is no guarantee of any convergence. Is there an easy way to check whether our network satisfies this condition, and to figure out $N$ for the particular architecture being used?"

The theoretical bound on $N$ depends on a lot of factors and is very weak; e.g., $N$ can depend polynomially on the number of training samples $m$ and on the inverse of the smallest eigenvalue of the infinite-width NTK, which can be very small. I haven't followed the line of research on sharp bounds for $N$; in some ideal cases, $N$ may depend linearly on $m$. As mentioned by Roman, if the learning rate is large, the so-called NTK dynamics may not hold.

The discrepancy bound $\sim n^{-1/2}$ between the linearized and the original non-linear dynamics also requires $N = \mathrm{poly}(m)$. I am not aware of any work that gives a bound on $N$ that is practically useful. A grid search over $n$ combined with a small learning rate may be your best bet if you would like to force your network into the linearized regime.

best, Lechao
