google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

Modelling Oscillatory Functions #77

Open Cyberface opened 4 years ago

Cyberface commented 4 years ago

Hi,

I'm not sure this is the right place to ask this kind of question, but I see the issues are filled with questions, so I hope this is OK. Let me know if I should post this on SO or something (I had a quick look but couldn't see a tag for this).

I think this package is very interesting, and the example notebooks you provide make it much easier to learn how to use it, so thanks!

I was trying to apply the neural_tangents_cookbook notebook (https://colab.research.google.com/github/google/neural-tangents/blob/master/notebooks/neural_tangents_cookbook.ipynb) to a more complicated dataset to see how well it worked, and I'm not sure if I'm doing something wrong, because it's not working as well as I hoped.

I have a colab notebook here (https://colab.research.google.com/drive/1CXyHLQ7vKSy-fJ7ZH0SYSREbfMnK23BX?usp=sharing), which is basically a copy of the cookbook but with a different target function that has more data points and is more oscillatory.

Here is an example of the target function and the data points

Screenshot: target function and data points (https://user-images.githubusercontent.com/13164315/95729141-59f68580-0c74-11eb-89d8-a92265c5e5cf.png)

And here is the network's prediction, superimposed on the data, after training using the exact Bayesian inference method.

Screenshot: NNGP prediction superimposed on the data (https://user-images.githubusercontent.com/13164315/95729241-84484300-0c74-11eb-8987-9ad30f1469c4.png)

I'm guessing that my network is just stuck in a local minimum or something. Or is there a notion of a length scale with NNGPs?

Thanks for any help!

SiuMath commented 4 years ago

Hi,

Thanks for reaching out.

I think this is due to the regularizer in

predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, train_xs, train_ys, diag_reg=1e-4)

It prevents predict_fn from interpolating the training data. If you set it to zero, it should be able to do that.
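
For example, a minimal sketch (reusing kernel_fn, train_xs, train_ys, and test_xs as defined in the cookbook):

# With diag_reg=0. there is no regularization, so the posterior mean
# interpolates the training points exactly (numerics permitting):
predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, train_xs, train_ys, diag_reg=0.)
nngp_mean, nngp_cov = predict_fn(x_test=test_xs, get='nngp', compute_cov=True)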

Best, Lechao

Cyberface commented 4 years ago

Hi Lechao,

Thanks for replying! I just tried reducing diag_reg; values smaller than 1e-6 give NaN. Using 1e-5 does give a more "noisy" mean function, but the function approximation is still not so great. I tried smaller values and even zero but also get NaNs.

Any ideas?

For example this is what the curve looks like with 1e-5

Screenshot: prediction with diag_reg=1e-5 (https://user-images.githubusercontent.com/13164315/95757863-97bcd380-0c9f-11eb-94d4-d19b1783d2d2.png)

SiuMath commented 4 years ago

I think this is due to floating-point precision. The smallest singular value of the NTK is <= 1e-7 (see attached). Increasing the precision to float64 may help. Also, using 1e-5, I got a slightly better plot.
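
A quick way to check this (a sketch, assuming kernel_fn and train_xs as defined in the cookbook):

import jax.numpy as jnp

# Evaluate the NTK on the training inputs and inspect its spectrum.
# A smallest singular value around 1e-7 sits at float32 machine epsilon,
# so inverting K + diag_reg * I is unstable in single precision.
ntk_train = kernel_fn(train_xs, train_xs, get='ntk')
print(jnp.linalg.svd(ntk_train, compute_uv=False).min())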

Cyberface commented 4 years ago

Ah! Great! It's a precision issue.

Setting JAX to use 64-bit precision:

from jax.config import config
config.update("jax_enable_x64", True)  # must run before any other JAX operations

Now with diag_reg=1e-8 I get the following!

Screenshot 2020-10-12 at 16 05 27
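
For completeness, roughly the setup that works (a sketch; train_xs, train_ys, and test_xs are as in the cookbook, and the architecture below is illustrative rather than the exact one from my notebook):

from jax.config import config
config.update("jax_enable_x64", True)  # enable float64 before any other JAX work

import neural_tangents as nt
from neural_tangents import stax

# Illustrative infinitely wide Erf network; substitute your own.
_, _, kernel_fn = stax.serial(
    stax.Dense(512, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(512, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(1, W_std=1.5, b_std=0.05)
)

# In float64 a much smaller regularizer is numerically stable.
predict_fn = nt.predict.gradient_descent_mse_ensemble(
    kernel_fn, train_xs, train_ys, diag_reg=1e-8)
nngp_mean, nngp_cov = predict_fn(x_test=test_xs, get='nngp', compute_cov=True)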

Thanks so much for your rapid help!

P.S. Did you mean to attach a figure in your last post? I don't see it.