google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

Why is the standard deviation always within [0, 1] and why do I get negative or NaN covariance values? #36

Open · Amir-Arsalan opened this issue 4 years ago

Amir-Arsalan commented 4 years ago

I am trying to use NNGP/NTK to fit the outputs of a black-box function. The y-axis of my data has a pretty wide range (e.g. [x, y], where x can be a large negative number and y can be as high as 20000). When I tried to use NNGP/NTK to find a suitable kernel, I got many NaN standard deviations. Looking at the [co]variance values, I noticed that 1) they are very small (e.g. 1e-6) and 2) they are sometimes negative, which produces NaN standard deviations. Also, if I set diag_reg to anything below 1e-3, I almost always get all-NaN covariances. Why is that?
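Why a slightly negative variance turns into a NaN can be shown in a few lines. The posterior variance is a difference of two nearly equal quantities, k(x*, x*) - k*^T K^-1 k*, so round-off can push it just below zero, and np.sqrt of a negative number is NaN. This is a minimal NumPy sketch; std_from_cov and its tol threshold are hypothetical helpers, not part of neural-tangents. Clipping only hides tiny round-off: a strongly negative variance still points to a numerically broken solve.

```python
import numpy as np


def std_from_cov(cov, tol=1e-6):
    """Per-point standard deviation from a posterior covariance matrix.

    Hypothetical helper: diagonal entries in [-tol, 0) are treated as
    round-off and clipped to 0; anything more negative raises an error.
    """
    var = np.diag(cov)
    if np.any(var < -tol):
        raise FloatingPointError(
            f"variance as low as {var.min():.3g}; try float64 or a larger diag_reg")
    return np.sqrt(np.clip(var, 0.0, None))


# Toy covariance whose second variance is slightly negative due to round-off.
cov = np.array([[1e-6, 0.0],
                [0.0, -3e-7]])

with np.errstate(invalid="ignore"):
    naive = np.sqrt(np.diag(cov))  # second entry is nan
safe = std_from_cov(cov)           # second entry clipped to 0
print(naive, safe)
```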

Additionally, I noticed that the std/covariance values always fall within [0, 1], which cannot be correct, while the means look correct. I think this is a bug (possibly related to this) and that the normalization/unnormalization steps may not have been implemented properly.

Below I wrote some code that shows these issues:

from jax import random
import jax.numpy as np
import neural_tangents as nt
from neural_tangents import stax

key = random.PRNGKey(10)

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(511, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(1, W_std=1.5, b_std=0.05)
)

train_xs = np.array([0.0000, 0.0200, 0.1000, 0.1200, 0.1400, 0.1600,
        0.1800, 0.2000, 0.2200, 0.2400, 0.2600, 0.3400,
        0.3600, 0.3800, 0.4000, 0.4200, 0.4400, 0.4600, 0.4800, 0.5000, 0.5200,
        0.5400, 0.5600, 0.5800, 0.6000, 0.6200, 0.6400, 0.6600, 0.6800, 0.7000,
        0.8000, 0.8200, 0.8400, 0.8600, 0.8800,
        0.9000, 0.9200, 0.9400, 0.9600, 0.9800, 1.0000, 1.0200, 1.0400, 1.0600,
        1.0800, 1.1000, 1.1200, 1.1400, 1.1600, 1.1800, 1.2000, 1.2200, 1.2400,
        1.2600, 1.2800, 1.3000, 1.3200, 1.3400, 1.3600, 1.3800, 1.4000, 1.4200,
        1.4400, 1.4600, 1.4800, 1.5000, 1.5200, 1.5400, 1.5600, 1.5800, 1.6000,
        1.6200, 1.6400, 1.6600, 1.6800, 1.7000, 1.7200, 1.7400, 1.7600, 1.7800,
        1.8000, 1.8200, 1.8400, 1.8600, 1.8800, 1.9000, 1.9200, 1.9400, 1.9600,
        1.9800, 2.0000, 2.0200, 2.0400, 2.0600, 2.0800, 2.1000, 2.1200, 2.1400]).reshape(-1, 1)

train_ys = np.array([0.1811, 0.1755, 0.0703, 0.0458, 0.0321, 0.0281,
        0.0314, 0.0574, 0.1113, 0.1680, 0.2007, 0.1864,
        0.1542, 0.1240, 0.1012, 0.0931, 0.0928, 0.0932, 0.0932, 0.0993, 0.1158,
        0.1359, 0.1524, 0.1587, 0.1610, 0.1610, 0.1610, 0.1610, 0.1610, 0.1610,
        0.1610, 0.1610, 0.1610, 0.1610, 0.1610,
        0.1610, 0.1610, 0.1610, 0.1610, 0.1705, 0.1995, 0.2493, 0.3048, 0.3482,
        0.3758, 0.3815, 0.3814, 0.3749, 0.3580, 0.3358, 0.3246, 0.3220, 0.3232,
        0.3352, 0.3619, 0.4008, 0.4347, 0.4507, 0.4541, 0.4534, 0.4461, 0.4272,
        0.4089, 0.4031, 0.4025, 0.4025, 0.4025, 0.4025, 0.4025, 0.4025, 0.4025,
        0.4025, 0.4025, 0.4025, 0.4025, 0.4025, 0.4025, 0.4025, 0.4025, 0.4025,
        0.4025, 0.4025, 0.4110, 0.4515, 0.5125, 0.5915, 0.6517, 0.6986, 0.7209,
        0.7261, 0.7246, 0.7246, 0.7232, 0.7122, 0.6844, 0.6524, 0.6344, 0.6308]).reshape(-1, 1)*1000
test_xs = np.linspace(0., 3.4, 70).reshape(-1, 1)

mean, covariance = nt.predict.gp_inference(
    kernel_fn, train_xs, train_ys, test_xs,
    get='ntk',  # you can also try get='nngp'
    diag_reg=1e-2, compute_cov=True)

mean = np.reshape(mean, (-1,))
std = np.sqrt(np.diag(covariance))
print(std)  # you will get some NaNs and all stds are within (0, 1)
print('\n')
print(mean)

And here's the output (the std array first, then the mean):

[0.02037075 1.0244607  0.071396   0.07778245 0.01721494 0.02377584
 0.2626345  0.01136238 0.01295557 0.00731608 0.00607905 0.00560992
        nan 0.01006025 0.01062503 0.06651297 0.02710511 0.01521527
        nan        nan 0.00918696 0.01167288 0.00146484 0.00718454
 0.00580829 0.0038602  0.00803071        nan 0.00358812        nan
        nan 0.00651448 0.00179406        nan 0.00851347        nan
 0.01051223 0.00838651        nan 0.00743728 0.00571519        nan
        nan        nan 0.02975312 0.08047054 0.1433592  0.22170994
 0.3043489  0.38526937 0.46138063 0.5305047  0.591389   0.6448372
 0.6911122  0.73078007 0.76483166 0.7942772  0.819495   0.8412284
 0.859998   0.8762823  0.89045817 0.90289694 0.9137593  0.9233606
 0.9318279  0.9393407  0.94605935 0.95205545]

[177.5949     8.138118  68.319824  29.864521  51.93844  181.89476
 188.30295  178.64008  106.595825  92.69229   96.36267  137.6497
 160.45253  160.79393  160.80162  159.1641   160.38309  160.61182
 162.00093  157.15422  181.85912  287.67047  375.76706  377.4702
 334.8078   321.11066  371.47156  436.5739   451.28027  424.98627
 402.5257   397.91467  401.8235   406.74292  405.56165  393.05548
 383.6278   421.41656  512.12476  630.0038   710.7755   736.6554
 702.9308   640.1339   573.10583  510.43176  456.05994  409.10168
 367.6361   332.77747  300.55188  272.0462   247.56     225.53345
 206.23566  189.16211  174.04999  160.7716   148.92923  138.42578
 129.2586   120.928314 113.518555 106.788605 100.86783   95.36664
  90.43851   86.002396  81.907715  78.196014]

If you set diag_reg to 1e-3 or lower, you'll get NaNs for everything:

[nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan]

[nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan]
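The effect of diag_reg on conditioning follows from a standard linear-algebra identity (general background, not a statement about how neural-tangents scales diag_reg internally). For a symmetric positive semi-definite kernel matrix K with eigenvalues \mu_{\min} \le \dots \le \mu_{\max}, adding \lambda I with \lambda = diag_reg shifts every eigenvalue by \lambda, so the 2-norm condition number becomes

$$
\kappa_2(K + \lambda I) = \frac{\mu_{\max} + \lambda}{\mu_{\min} + \lambda} \approx \frac{\mu_{\max}}{\lambda} \quad \text{when } \mu_{\min} \ll \lambda \ll \mu_{\max}.
$$

For a nearly singular kernel, each tenfold decrease in diag_reg therefore makes the linear solve roughly ten times more sensitive to round-off, while float32 carries only about 7 significant decimal digits.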

sschoenholz commented 4 years ago

Thanks for raising this question and for the clear repro. I haven't yet looked into the [0, 1] issue, but I have investigated the NaNs. Note that a deep Erf network with the parameters you used is in what is known as the "ordered" phase, where the kernel tends to be close to singular. At low numerical precision this can frequently lead to NaNs during inference (we discuss these issues a little bit here).
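To get a feel for how close to singular such a kernel is, this sketch re-derives the NNGP kernel of the repro's architecture in plain NumPy float64 and prints its condition number for a few depths. Assumptions: erf_nngp is a hypothetical helper built from the standard closed-form Erf expectation E[erf(u) erf(v)] = (2/pi) arcsin(2 k12 / sqrt((1 + 2 k11)(1 + 2 k22))); it is not neural-tangents' kernel_fn, though it targets the same Dense(W_std, b_std) / Erf parameterization, and the evenly spaced grid only approximates the repro's train_xs. A condition number far above 1 / eps32 (about 8.4e6) means a float32 solve can lose essentially all of its significant digits.

```python
import numpy as np


def erf_nngp(x1, x2, depth=13, w_std=1.5, b_std=0.05):
    """Hypothetical NNGP kernel for `depth` x (Dense -> Erf) plus a Dense readout."""
    d = x1.shape[1]
    k12 = x1 @ x2.T / d                 # input-layer kernel between x1 and x2
    k11 = np.sum(x1 ** 2, axis=1) / d   # diagonal for x1
    k22 = np.sum(x2 ** 2, axis=1) / d   # diagonal for x2
    for _ in range(depth):
        # Dense layer: k -> W_std^2 * k + b_std^2
        k12 = w_std ** 2 * k12 + b_std ** 2
        k11 = w_std ** 2 * k11 + b_std ** 2
        k22 = w_std ** 2 * k22 + b_std ** 2
        # Erf layer: closed-form Gaussian expectation of erf(u) * erf(v)
        denom = np.sqrt(np.outer(1 + 2 * k11, 1 + 2 * k22))
        k12 = (2 / np.pi) * np.arcsin(2 * k12 / denom)
        k11 = (2 / np.pi) * np.arcsin(2 * k11 / (1 + 2 * k11))
        k22 = (2 / np.pi) * np.arcsin(2 * k22 / (1 + 2 * k22))
    return w_std ** 2 * k12 + b_std ** 2  # readout Dense layer


x_train = np.linspace(0.0, 2.14, 108).reshape(-1, 1)  # 0.02 spacing, like the repro

conds = {}
for depth in (1, 4, 13):
    k = erf_nngp(x_train, x_train, depth=depth)
    conds[depth] = np.linalg.cond(k)
    print(f"depth {depth:2d}: cond = {conds[depth]:.3g}")
print(f"1 / float32 eps = {1 / np.finfo(np.float32).eps:.3g}")
```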

In any case, you can get rid of all the NaNs in this example by enabling float64 mode. To do this, call from jax.config import config; config.update('jax_enable_x64', True) before running any other JAX code.
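A minimal sketch of the same switch, written as jax.config.update on the jax module rather than through the older from jax.config import config path, which may not be available in recent JAX releases. Setting the environment variable JAX_ENABLE_X64=True before starting Python should have the same effect. The array xs is illustrative only.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit mode before any other JAX computation runs.
jax.config.update("jax_enable_x64", True)

# Arrays created from Python floats now default to float64, so inputs built
# like train_xs / test_xs in the repro (and, presumably, kernels computed
# from them) stay in double precision.
xs = jnp.array([0.0, 0.02, 0.1]).reshape(-1, 1)
print(xs.dtype)
```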

Let me know if this helps!

Best, Sam

Amir-Arsalan commented 4 years ago

@sschoenholz Thanks Sam, I didn't know that this particular choice of model would produce a kernel that is close to singular. Using float64 helps. However, the std issue is still there. The std/var values are not strictly within (0, 1): they sometimes go slightly above 1 (e.g. 1.05), but those values are certainly not the correct std/var values.

Sohl-Dickstein commented 4 years ago

Hi Amir-Arsalan. Thanks for your question! I wonder, are you sure the std/var values should be larger? Remember that the variance of a posterior is never larger than the variance of the prior (by the law of total variance: https://en.wikipedia.org/wiki/Law_of_total_variance ).

For the NNGP, where the predictions correspond to a Bayesian posterior, the variance of your posterior is upper bounded by the variance of your prior, which is in turn upper bounded by sigma_w^2 + sigma_b^2 of your readout layer (since the previous layer's activations are bounded in (-1, 1) by the erf). So the largest possible value the variance could take on is 1.5^2 + 0.05^2 = 2.2525. Because this is an upper bound on the prior variance, and because your training datapoints are fairly dense, which significantly reduces the posterior variance, posterior variance values between 0 and 1 seem reasonable to me.
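Two checks of the argument above, as a hedged NumPy sketch. Assumptions: erf_nngp and nngp_posterior are hypothetical helpers (not neural-tangents APIs) built from the closed-form Erf expectation, diag_reg is added directly to the diagonal, and the targets are synthetic. The first check is that the prior variance k(x, x) of this architecture stays below 1.5^2 + 0.05^2 = 2.2525. The second spells out a standard GP fact: the posterior covariance K_ss - K_st (K_tt + diag_reg I)^-1 K_ts never involves the targets, so multiplying train_ys by 1000 rescales the mean but leaves the covariance, and hence the std, unchanged.

```python
import numpy as np


def erf_nngp(x1, x2, depth=13, w_std=1.5, b_std=0.05):
    """Hypothetical NNGP kernel for `depth` x (Dense -> Erf) plus a Dense readout."""
    d = x1.shape[1]
    k12 = x1 @ x2.T / d
    k11 = np.sum(x1 ** 2, axis=1) / d
    k22 = np.sum(x2 ** 2, axis=1) / d
    for _ in range(depth):
        k12 = w_std ** 2 * k12 + b_std ** 2  # Dense
        k11 = w_std ** 2 * k11 + b_std ** 2
        k22 = w_std ** 2 * k22 + b_std ** 2
        denom = np.sqrt(np.outer(1 + 2 * k11, 1 + 2 * k22))  # Erf
        k12 = (2 / np.pi) * np.arcsin(2 * k12 / denom)
        k11 = (2 / np.pi) * np.arcsin(2 * k11 / (1 + 2 * k11))
        k22 = (2 / np.pi) * np.arcsin(2 * k22 / (1 + 2 * k22))
    return w_std ** 2 * k12 + b_std ** 2  # readout Dense


def nngp_posterior(x_train, y_train, x_test, diag_reg=1e-2):
    """Exact GP posterior mean and covariance under the erf_nngp prior."""
    k_tt = erf_nngp(x_train, x_train) + diag_reg * np.eye(len(x_train))
    k_st = erf_nngp(x_test, x_train)
    k_ss = erf_nngp(x_test, x_test)
    chol = np.linalg.cholesky(k_tt)               # k_tt = chol @ chol.T
    v = np.linalg.solve(chol, k_st.T)             # chol^-1 k_ts
    mean = v.T @ np.linalg.solve(chol, y_train)   # k_st k_tt^-1 y
    cov = k_ss - v.T @ v                          # k_ss - k_st k_tt^-1 k_ts
    return mean, cov


x_train = np.linspace(0.0, 2.14, 30).reshape(-1, 1)
y_train = 0.4 + 0.3 * np.sin(3.0 * x_train)      # synthetic targets
x_test = np.linspace(0.0, 3.4, 70).reshape(-1, 1)

mean1, cov1 = nngp_posterior(x_train, y_train, x_test)
mean2, cov2 = nngp_posterior(x_train, 1000 * y_train, x_test)
prior_var = np.diag(erf_nngp(x_test, x_test))

print("max prior variance:", prior_var.max())
print("covariance unchanged by scaling y:", np.allclose(cov1, cov2))
print("max posterior std:", np.sqrt(np.clip(np.diag(cov1), 0, None)).max())
```

Under these assumptions, getting stds in the units of the targets would require scaling the kernel itself; one common option is to standardize the targets before inference and multiply the resulting std by the targets' standard deviation afterwards.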

sung-max commented 4 years ago

Hello, I ran into a similar issue, so I enabled float64 mode as suggested above. This gets rid of the NaNs, but the empirical NTK computations are now about 8x slower (across several training set sizes).

Any ideas why this happens, or whether it is expected? Is it related to the convergence of some underlying matrix computation?

romanngg commented 3 years ago

However, note that per the last measurement even the plain forward pass of the network, apply_fn, is 10x faster in x32 precision. On a Tesla V100, the empirical NTK and the forward pass were both only 3x faster in x32, which seems reasonable given that multiplication cost scales non-linearly with the number of bits (https://en.wikipedia.org/wiki/Computational_complexity_of_mathematical_operations).

So while such a big difference in performance is unexpected to me, I am also not sure what the baseline expectation should be. Since apply_fn is also dramatically slower in x64 than in x32 on the P4, it may be better to file a bug with https://github.com/google/jax, as the slowdown does not appear specific to NT. But of course, if you have a setting where NT specifically is dramatically slower in x64 while other computations are not, please let us know!
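A hedged sketch of how one might measure the x32/x64 gap on one's own hardware before filing a bug. It times a jitted matrix multiply in both dtypes in a single process (x64 mode allows explicit float32 arrays alongside float64 ones). Two details matter for JAX timings: compile once outside the timed loop, and call block_until_ready(), because JAX dispatches work asynchronously. time_matmul and its sizes are illustrative choices, and the printed ratio will vary with the accelerator, since float64 throughput relative to float32 differs widely between devices.

```python
import time

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # allow float64 arrays


def time_matmul(dtype, n=1024, repeats=10):
    """Median wall-clock seconds for an n x n matmul in `dtype` (illustrative)."""
    a = jnp.ones((n, n), dtype=dtype)
    f = jax.jit(lambda m: m @ m)
    f(a).block_until_ready()  # compile and warm up outside the timed loop
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        f(a).block_until_ready()  # wait for the async result before stopping the clock
        times.append(time.perf_counter() - start)
    return sorted(times)[len(times) // 2]


t32 = time_matmul(jnp.float32)
t64 = time_matmul(jnp.float64)
print(f"float32: {t32 * 1e3:.2f} ms, float64: {t64 * 1e3:.2f} ms, ratio: {t64 / t32:.1f}x")
```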

A few other suggestions on performance: