Closed: liutianlin0121 closed this issue 4 years ago.
Thanks for the report! The issue is that if the variance of an input (here meaning a single pixel of a single datapoint) is zero, it can cause NaNs. There is a simple fix: add a small stability term to our normalization. I expect us to push a fix today.
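For intuition only, the kind of stability term described might look like the sketch below. This is not the library's actual code; the helper name safe_cosine, the eps value, and the normalization shown are illustrative assumptions showing why a zero per-pixel variance produces 0/0 = nan and how a small stability term avoids it.

import jax.numpy as jnp

def safe_cosine(cov12, var1, var2, eps=1e-12):
    # Illustrative sketch only (hypothetical helper, not neural_tangents code).
    # Nonlinearity kernels divide a covariance by sqrt(var1 * var2); when a
    # pixel has zero variance this is 0 / 0 = nan.  A small stability term
    # keeps the quotient finite.
    return cov12 / jnp.sqrt(var1 * var2 + eps)

print(safe_cosine(jnp.array(0.0), jnp.array(0.0), jnp.array(0.0)))  # 0.0 instead of nan
print(jnp.array(0.0) / jnp.sqrt(jnp.array(0.0)))                    # nan without the stability term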
Thanks so much for your reply!
I am not sure whether the input variance plays a major role here. Indeed, in the example above, the sparse inputs are drawn from a Gaussian distribution and then truncated based on magnitude, so their values should be spread symmetrically around 0. But I also found a similar nan problem with non-negative sparse inputs. The script below shows this phenomenon with MNIST images:
import tensorflow as tf
import numpy as np
from neural_tangents import stax

mnist = tf.keras.datasets.mnist
(x_train, _), (_, _) = mnist.load_data()
x_train = x_train / 255.0  # normalize the input values to the range (0, 1)
x_train_subset_sparse = x_train[:3].reshape([-1, 28, 28, 1])  # sparse input samples

# standardize the data
mean = np.mean(x_train)
std = np.std(x_train)
x_train_dense = (x_train - mean) / std
x_train_subset_dense = x_train_dense[:3].reshape([-1, 28, 28, 1])  # dense input samples

# A CNN architecture
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Conv(128, (3, 3)),
    stax.Relu(),
    stax.Flatten(),
    stax.Dense(10))

print('NTK evaluated w/ sparse MNIST images: \n',
      kernel_fn(x_train_subset_sparse, x_train_subset_sparse, 'ntk'))  # the output contains nan
print('NTK evaluated w/ dense, standardized MNIST images: \n',
      kernel_fn(x_train_subset_dense, x_train_subset_dense, 'ntk'))  # the output looks fine
The output of the above script should look like:
NTK evaluated w/ sparse MNIST images:
[[nan nan nan]
[nan nan nan]
[nan nan nan]]
NTK evaluated w/ dense, standardized MNIST images:
[[1.1637697 0.6116184 0.21783468]
[0.6116184 1.3009455 0.213599 ]
[0.21783468 0.213599 0.79291606]]
So, without standardization, the sparse MNIST images seem to cause the nan problem, while the standardized MNIST images with zero mean seem to resolve it.
Thanks for your time!
Thanks Tianlin.
The problem is caused by the fact that the per-pixel variance is zero for many pixels in a single MNIST input (or in sparse inputs generally). When computing the NTK of a CNN, we need to keep track of the variance of each pixel in each input. You should be able to see this by computing the NTK right after the conv layer: there are many zero terms in it. Subtracting the mean eliminates the zero pixels and fixes this issue.
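To make the mechanism concrete, here is a small hand-rolled check one could run, reusing x_train_subset_sparse from the script above. It is an illustration, not the library's internal computation: it mimics the per-pixel quantity the conv layer's kernel tracks (the squared norm of the 3x3 patch around each pixel) and shows that, for a sparse MNIST image, many of these are exactly zero, which is what leads to 0/0 in the subsequent Relu kernel. The zero padding and 3x3 patch size are assumptions made to match stax.Conv(128, (3, 3)).

import numpy as np

# For each pixel of one sparse MNIST image, compute the squared norm of the
# surrounding 3x3 patch.  The first conv layer's per-pixel variance is
# proportional to this quantity, so an all-zero patch means zero variance.
img = x_train_subset_sparse[0, ..., 0]   # one sparse MNIST image, shape (28, 28)
padded = np.pad(img, 1)                  # zero padding so every pixel has a 3x3 patch
patch_sq_norm = np.zeros_like(img)
for i in range(img.shape[0]):
    for j in range(img.shape[1]):
        patch_sq_norm[i, j] = np.sum(padded[i:i + 3, j:j + 3] ** 2)

print('fraction of pixels with zero patch norm:', np.mean(patch_sq_norm == 0.0))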
@SiuMath @sschoenholz Many thanks for your explanations! Previously I had misunderstood the variance we are talking about here as the one defined across multiple samples for a single pixel :)
FYI, I believe Sam has fixed it in 0e92b0f!
@romanngg many thanks!!
Hi!
A bug seems to occur when I try to evaluate analytic NTKs on sparse input data: the evaluated kernel contains nan entries. This can be reproduced with a few lines of code, along the lines of the sketch below.
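The exact snippet is not reproduced here; based on the description earlier in the thread (Gaussian inputs truncated by magnitude to make them sparse), a reproduction might look roughly like the following sketch. The threshold 1.5, the input shape, and the small architecture are illustrative assumptions.

import jax.numpy as jnp
from jax import random
from neural_tangents import stax

# Sparse inputs: Gaussian draws with small-magnitude entries zeroed out
# (the 1.5 threshold is an arbitrary choice for illustration).
key = random.PRNGKey(0)
x = random.normal(key, (3, 28, 28, 1))
x_sparse = jnp.where(jnp.abs(x) > 1.5, x, 0.0)

# A small CNN; any conv + relu architecture shows the same behavior.
_, _, kernel_fn = stax.serial(
    stax.Conv(128, (3, 3)),
    stax.Relu(),
    stax.Flatten(),
    stax.Dense(10))

print(kernel_fn(x_sparse, x_sparse, 'ntk'))  # contains nan before the fix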
Thanks for your time in advance!