Amir-Arsalan opened this issue 4 years ago
Thanks for raising this question and for the clear repro. I haven't yet looked into the [0, 1] issue, but I have investigated the NaNs. Note that a deep Erf network with the parameters you used is in what is known as the "ordered" phase, where the kernel tends to be close to singular. At low numerical precision this can frequently lead to NaNs during inference (we discuss these issues a little bit here).
In any case, you can get rid of all the NaNs in this example by enabling float64 mode. To do this, call `from jax.config import config; config.update('jax_enable_x64', True)` before you run any other JAX code.
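For reference, a minimal sketch of what this looks like at the top of a script (the import path of `config` has moved between JAX versions, so `jax.config.update` may be the spelling your version wants):

```python
# Enable float64 before any other JAX code runs.
from jax.config import config
config.update('jax_enable_x64', True)

import jax.numpy as jnp

# Sanity check: new arrays should now default to float64.
print(jnp.ones(1).dtype)  # float64
```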
Let me know if this helps!
Best, Sam
@sschoenholz Thanks Sam, I didn't know that this particular choice of model would give a kernel that is close to singular. Using float64 helps. However, the std issue is still there: the std/var values are not always between (0, 1). They sometimes go slightly above 1 (e.g. 1.05), but those values are certainly not the correct values for std/var.
Hi Amir-Arsalan. Thanks for your question! I wonder, are you sure the std/var values should be larger? Remember that the variance of a posterior is never larger than the variance of the prior (by the law of total variance: https://en.wikipedia.org/wiki/Law_of_total_variance).
For the NNGP, where the predictions correspond to a Bayesian posterior, the variance of your posterior is upper bounded by the variance of your prior, which in turn is upper bounded by the sum of sigma_w^2 and sigma_b^2 in your readout layer (since the previous layer's activations are bounded in (-1, 1) due to the erf). So the largest possible value the variance could take on is 1.5^2 + 0.05^2 ≈ 2.25. Because this is an upper bound on the prior variance, and because your training datapoints are fairly dense, which will significantly reduce the posterior variance, posterior variance values between 0 and 1 seem reasonable to me.
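To make the bound concrete, here is a minimal sketch that checks this, assuming an Erf network with `W_std=1.5` and `b_std=0.05` in every layer (the data here is made up; only the hyperparameters come from the numbers quoted above):

```python
import numpy as np
import jax
import jax.numpy as jnp
import neural_tangents as nt
from neural_tangents import stax

# Hypothetical data standing in for the original training set.
key = jax.random.PRNGKey(0)
x_train = jax.random.normal(key, (50, 1))
y_train = jnp.sin(x_train)
x_test = jnp.linspace(-3., 3., 20).reshape(-1, 1)

_, _, kernel_fn = stax.serial(
    stax.Dense(512, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(1, W_std=1.5, b_std=0.05))

# Prior variance of the NNGP at the test points.
prior_var = np.diag(kernel_fn(x_test, None, 'nngp'))

# Posterior variance after conditioning on the training data.
predict_fn = nt.predict.gradient_descent_mse_ensemble(
    kernel_fn, x_train, y_train, diag_reg=1e-4)
_, cov = predict_fn(x_test=x_test, get='nngp', compute_cov=True)
post_var = np.diag(cov)

bound = 1.5 ** 2 + 0.05 ** 2  # readout-layer bound, ~2.25
print(post_var.max() <= prior_var.max() <= bound)  # expected: True
```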
Hello, I ran into a similar issue, so I enabled float64 mode as suggested above. This gets rid of the NaNs; however, the empirical NTK computations now seem about 8x slower (across multiple scales of training set sizes).
Any ideas on why this might be / whether this is expected? Is this due to some underlying matrix computations' convergence?
Regarding precision issues with inference and why higher precision helps, we have some more discussion in https://arxiv.org/abs/2007.15801 section D.
Regarding the x64 NTK being ~8x slower than x32: I could reproduce this on a Tesla P4, where x32 is ~10-20x faster than x64:
x32: https://colab.research.google.com/gist/romanngg/944a300ce63de3db147ca45a67e70c6d/issue_36_x32.ipynb
x64: https://colab.research.google.com/gist/romanngg/ffdad170615058b91cc32ea08cf1d054/issue_36_x64.ipynb
However, note that per the last measurement, just the forward pass of the network (`apply_fn`) is also 10x faster in x32 precision. On a Tesla V100, the empirical NTK and the forward pass were both only 3x faster in x32, which seems reasonable given that multiplication cost scales non-linearly with the number of bits (https://en.wikipedia.org/wiki/Computational_complexity_of_mathematical_operations).
So while such a big difference in performance is unexpected to me, I am also not sure what the baseline expectation should be, and since `apply_fn` is also dramatically slower on the P4 in x64 than in x32, it may be better to submit a bug to https://github.com/google/jax, as it appears not to be specific to NT. But of course, if you have a setting where specifically NT is dramatically slow in x64 but other computations are not, please let us know!
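If it helps narrow things down, here is a rough sketch (made-up network and shapes, not the Colab benchmark linked above) of how one might time `apply_fn` alone in both precisions on the same device:

```python
import time
import jax
import jax.numpy as jnp
from neural_tangents import stax

jax.config.update('jax_enable_x64', True)  # needed so float64 arrays exist at all

init_fn, apply_fn, _ = stax.serial(
    stax.Dense(1024), stax.Erf(), stax.Dense(1024), stax.Erf(), stax.Dense(1))

key = jax.random.PRNGKey(0)
x64 = jax.random.normal(key, (256, 128))  # float64 under x64 mode
_, params64 = init_fn(key, x64.shape)

# Cast copies of the inputs and parameters down to float32 for comparison.
to_f32 = lambda tree: jax.tree_util.tree_map(lambda p: p.astype(jnp.float32), tree)
x32, params32 = x64.astype(jnp.float32), to_f32(params64)

apply_jit = jax.jit(apply_fn)

def bench(params, x, n=10):
  apply_jit(params, x).block_until_ready()  # compile / warm up
  t0 = time.perf_counter()
  for _ in range(n):
    apply_jit(params, x).block_until_ready()
  return (time.perf_counter() - t0) / n

print('x64 apply_fn:', bench(params64, x64), 's/call')
print('x32 apply_fn:', bench(params32, x32), 's/call')
```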
A few other suggestions on performance:
- The `vmap_axes=0` argument (https://neural-tangents.readthedocs.io/en/latest/neural_tangents.empirical.html): setting it may dramatically speed up the empirical NTK in general (especially for large training set sizes); see the sketch below.
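A minimal sketch of what setting it looks like (network, shapes, and data are made up):

```python
import jax
import neural_tangents as nt
from neural_tangents import stax

init_fn, apply_fn, _ = stax.serial(stax.Dense(512), stax.Erf(), stax.Dense(1))

key = jax.random.PRNGKey(0)
x1 = jax.random.normal(key, (32, 10))
x2 = jax.random.normal(key, (16, 10))
_, params = init_fn(key, x1.shape)

# vmap_axes=0 tells the empirical-kernel machinery that inputs and outputs are
# batched along axis 0, letting it vectorize the per-example computations.
ntk_fn = nt.empirical_ntk_fn(apply_fn, vmap_axes=0)
ntk = ntk_fn(x1, x2, params)
print(ntk.shape)  # (32, 16) with the default trace_axes
```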
I am trying to use NNGP/NTK to fit outputs of a black-box function. The y axis of my data has a pretty wide range (e.g. [x, y], where x could be a large negative number and y could be as high as 20000). When I tried to use NNGP/NTK to find a suitable kernel, I realized that I get lots of NaNs as standard deviations. When I looked at the [co]variance values I noticed that 1) they are super small (e.g. 1e-6) and 2) they are sometimes negative, which results in NaN standard deviation values. Also, it is very likely (or almost certain) that I will get all NaNs for the covariance if I set `diag_reg` to anything below `1e-3`. Why is that? Additionally, I noticed the range of std/covariance is [0, 1], which is not correct, though the means seem to be correct. I think this is a bug (relevant to this) and it's possible that the normalization/unnormalization steps have not been implemented properly.
Below I wrote some code that shows these issues:
And here's the output:
If you set `diag_reg` to anything like 1e-3 or lower you'll get NaNs for everything:
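As a rough illustration of the setup being described (made-up data and network, not the original snippet), here is a minimal sketch of where `diag_reg` enters and where the NaN standard deviations show up:

```python
import numpy as np
import jax
import jax.numpy as jnp
import neural_tangents as nt
from neural_tangents import stax

# Made-up stand-in for the black-box function: targets with a wide range.
key = jax.random.PRNGKey(0)
x_train = jax.random.uniform(key, (100, 1), minval=-5., maxval=5.)
y_train = 20000. * jnp.sin(x_train)  # values on the order of +/- 2e4
x_test = jnp.linspace(-5., 5., 50).reshape(-1, 1)

_, _, kernel_fn = stax.serial(
    stax.Dense(512, W_std=1.5, b_std=0.05), stax.Erf(),
    stax.Dense(1, W_std=1.5, b_std=0.05))

# diag_reg adds a diagonal regularizer to K(train, train) before solving;
# with a value that is too small, the near-singular kernel stays ill-conditioned.
predict_fn = nt.predict.gradient_descent_mse_ensemble(
    kernel_fn, x_train, y_train, diag_reg=1e-3)

mean, cov = predict_fn(x_test=x_test, get='nngp', compute_cov=True)
std = np.sqrt(np.diag(cov))  # NaN wherever diag(cov) dips below zero
print(np.isnan(std).any(), np.nanmin(std), np.nanmax(std))
```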