google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

Analytic kernel evaluated on sparse inputs #14

liutianlin0121 closed this issue 4 years ago

liutianlin0121 commented 4 years ago

Hi!

A bug seems to occur when evaluating analytic NTKs on sparse input data: the evaluated kernel contains nan entries. This can be reproduced with the following lines of code:

from jax import random
from neural_tangents import stax

key = random.PRNGKey(1)

# a batch of dense inputs 
x_dense = random.normal(key, (3, 32, 32, 3))

# a batch of sparse inputs 
x_sparse = x_dense * (abs(x_dense) > 1.2)

# A CNN architecture
init_fn, apply_fn, kernel_fn = stax.serial(
     stax.Conv(128, (3, 3)),
     stax.Relu(),
     stax.Flatten(),
     stax.Dense(10) )

# Evaluate the analytic NTK on dense inputs

print('NTK evaluated w/ dense inputs: \n', kernel_fn(x_dense, x_dense, 'ntk')) # the output looks fine

print('\n')

# Evaluate the analytic NTK on sparse inputs

print('NTK evaluated w/ sparse inputs: \n', kernel_fn(x_sparse, x_sparse, 'ntk')) # the output contains nan entries

The output of the above script should be:

NTK evaluated w/ dense inputs: 
 [[0.97102666 0.16131128 0.16714054]
 [0.16131128 0.9743941  0.17580226]
 [0.16714054 0.17580226 1.0097454 ]]

NTK evaluated w/ sparse inputs: 
 [[       nan        nan        nan]
 [       nan 0.66292834        nan]
 [       nan        nan        nan]]

Thanks for your time in advance!

sschoenholz commented 4 years ago

Thanks for the report! The issue is that if the variance of an input (here meaning a single pixel of a single datapoint) is zero, it can cause NaNs. The fix is simple: add a small stability term to our normalization. I expect we'll push a fix today.
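
For intuition, here is a minimal NumPy sketch (an illustration, not the library's actual internals) of how a zero per-pixel variance produces a NaN: the nonlinearity kernels normalize the covariance by the per-pixel standard deviations, so a zero variance yields 0/0. The eps value below is hypothetical and only stands in for the stability term mentioned above.

import numpy as np

var1  = np.array(0.0)   # variance of this pixel in input 1: exactly zero for a zeroed-out pixel
var2  = np.array(0.7)   # variance of the same pixel in input 2
cov12 = np.array(0.0)   # covariance between the two

corr = cov12 / np.sqrt(var1 * var2)        # 0 / 0 -> nan, which propagates through the rest of the kernel
print(corr)                                # nan

eps = 1e-12                                # hypothetical small stability term
print(cov12 / np.sqrt(var1 * var2 + eps))  # 0.0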

liutianlin0121 commented 4 years ago

Thanks so much for your reply!

I am not sure whether the input variance plays a major role here. In the example above, the sparse inputs are drawn from a Gaussian distribution and then thresholded by magnitude, so their values should be spread symmetrically around 0. But I found a similar nan problem with non-negative sparse inputs. The script below shows this phenomenon with MNIST images:


import tensorflow as tf
import numpy as np
from jax import random
from neural_tangents import stax

mnist = tf.keras.datasets.mnist

(x_train, _), (_, _) = mnist.load_data()

x_train = x_train / 255.0 # scale the pixel values to [0, 1]

x_train_subset_sparse = x_train[:3].reshape([-1, 28, 28, 1]) # sparse input samples. 

# standardize the data 
mean = np.mean(x_train)
std = np.std(x_train)
x_train_dense = (x_train - mean) / std

x_train_subset_dense = x_train_dense[:3].reshape([-1, 28, 28, 1])  # dense input samples

# A CNN architecture
init_fn, apply_fn, kernel_fn = stax.serial(
     stax.Conv(128, (3, 3)),
     stax.Relu(),
     stax.Flatten(),
     stax.Dense(10) )

print('NTK evaluated w/ sparse MNIST images: \n', kernel_fn(x_train_subset_sparse, x_train_subset_sparse, 'ntk')) # the output contains nan entries

print('NTK evaluated w/ dense, standardized MNIST images: \n', kernel_fn(x_train_subset_dense, x_train_subset_dense, 'ntk')) # the output looks fine

The output of the above script looks like this:

NTK evaluated w/ sparse MNIST images: 
 [[nan nan nan]
 [nan nan nan]
 [nan nan nan]]
NTK evaluated w/ dense, standardized MNIST images: 
 [[1.1637697  0.6116184  0.21783468]
 [0.6116184  1.3009455  0.213599  ]
 [0.21783468 0.213599   0.79291606]]

So without standardization, the sparse MNIST images cause the nan problem, while standardizing them to zero mean seems to resolve it.

Thanks for your time!

SiuMath commented 4 years ago

Thanks Tianlin.

The problem is caused by the fact that the per-pixel variance is zero for many pixels in a single MNIST input (or any sparse input). When computing the NTK of a CNN, we need to keep track of the variance of each pixel in each input. You can see this by computing the NTK right after the conv layer: it contains many zero terms. Subtracting the mean eliminates the zero pixels and fixes the issue.
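
To make this concrete, here is a small NumPy sketch (an illustration, not the library's internals). With i.i.d. Gaussian weights and no bias, the preactivation variance after the first 3x3 conv at a given pixel is proportional to the sum of squares of the corresponding input patch, so an all-zero patch gives exactly zero variance; for an unstandardized MNIST digit, most background patches are all zero:

import numpy as np
import tensorflow as tf

(x_train, _), _ = tf.keras.datasets.mnist.load_data()
x = x_train[0] / 255.0   # one unstandardized (sparse) MNIST digit, shape (28, 28)

# Count interior 3x3 patches that are entirely zero; each one corresponds to a pixel
# with exactly zero post-conv variance, which is what triggers the 0/0 described above.
zero_patches = sum(
    np.all(x[i:i + 3, j:j + 3] == 0)
    for i in range(26) for j in range(26)
)
print(zero_patches, 'of', 26 * 26, 'interior 3x3 patches are all zero')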

liutianlin0121 commented 4 years ago

@SiuMath @sschoenholz Many thanks for your explanations! I had previously misunderstood the variance in question to be the one defined across multiple samples for a single pixel :)
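
In case it helps other readers, here is a small NumPy sketch contrasting the two notions (an illustration only; the squared pixel values of a single image are just a stand-in, up to normalization, for the per-pixel variance that the kernel computation tracks for that input):

import numpy as np
import tensorflow as tf

(x_train, _), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train[:1000] / 255.0   # a subset is enough for the illustration

# (a) variance of each pixel across the dataset -- the quantity I originally had in mind.
across_samples = x_train.var(axis=0)   # shape (28, 28); nonzero for most pixels

# (b) a per-pixel statistic within a single input -- zero wherever that image is zero,
#     which is the kind of quantity tracked per input during the kernel computation.
within_one_image = x_train[0] ** 2     # shape (28, 28)

print('fraction of zero entries across samples:  ', (across_samples == 0).mean())
print('fraction of zero entries within one image:', (within_one_image == 0).mean())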

romanngg commented 4 years ago

FYI, I believe Sam has fixed this in 0e92b0f!

liutianlin0121 commented 4 years ago

@romanngg many thanks!!