google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

NTK/NNGP behavior in the infinite regime when weights are drawn from Gaussians with high standard deviation #197

Open tengandreaxu opened 11 months ago

tengandreaxu commented 11 months ago

Hi everyone, thank you so much for your exceptional work!

I'm encountering some numerical issues when weights are drawn from Gaussians with a high standard deviation. Please see the snippet below:

import numpy as np
from neural_tangents import stax
from jax import jit

W_stds = list(range(1, 17))
# W_stds.reverse()
layer_fn = []
for i in range(len(W_stds) - 1):
    layer_fn.append(stax.Dense(1, W_std=W_stds[i]))
    layer_fn.append(stax.Relu())

layer_fn.append(stax.Dense(1, 1.0, 0.0))
_, _, kernel_fn = stax.serial(*layer_fn)

kernel_fn = jit(kernel_fn, static_argnames="get")

x = np.random.rand(100, 100)

print(kernel_fn(x, x, "ntk"))

The result is:

[[nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 ...
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]
 [nan nan nan ... nan nan nan]]

With float64 precision enabled, the values are no longer NaN, but they blow up:

[[2.2293401e+18 9.3420067e+17 9.2034030e+17 ... 8.9008971e+17
  9.6801663e+17 9.6436509e+17]
 [9.3420067e+17 2.3730658e+18 9.4658846e+17 ... 9.6854199e+17
  9.6182735e+17 9.9944418e+17]
 [9.2034030e+17 9.4658846e+17 2.3106050e+18 ... 9.1702287e+17
  9.5415269e+17 9.9692925e+17]
 ...
 [8.9008971e+17 9.6854199e+17 9.1702300e+17 ... 2.2127619e+18
  9.2056034e+17 1.0147568e+18]
 [9.6801663e+17 9.6182728e+17 9.5415269e+17 ... 9.2056034e+17
  2.3979914e+18 9.9505658e+17]
 [9.6436488e+17 9.9944418e+17 9.9692925e+17 ... 1.0147568e+18
  9.9505658e+17 2.4954969e+18]]

Interestingly, the behavior seems to depend more on the depth than on the large weight standard deviations themselves. If the list of standard deviations is reversed (by uncommenting W_stds.reverse()), so that the first layer uses W_std=16, the second W_std=15, and so on, the results remain unchanged.
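A back-of-the-envelope sketch of why depth matters here, assuming the hidden Dense layers contribute no bias variance (the snippet never sets b_std) and that the Dense kernel is normalized by fan-in: each Dense(W_std=s) multiplies the diagonal NNGP variance by s**2, and each ReLU halves it, so after the hidden stack the variance carries the product of s**2 / 2 over all layers. That product ignores layer order and keeps growing with every layer added. `relu_nngp_scale` is a hypothetical helper, not part of Neural Tangents. The last line rests on one more assumption about the implementation, not verified here: that the ReLU kernel is computed from the product of the two inputs' variances. If so, that product alone exceeds the float32 range, which would be consistent with the NaNs above.

```python
import math

import numpy as np


def relu_nngp_scale(w_stds):
    """Factor by which a stack of Dense(W_std=s) -> Relu blocks (no biases)
    multiplies the diagonal NNGP variance: Dense multiplies it by s**2, and
    ReLU halves it, since E[relu(z)**2] = q / 2 for z ~ N(0, q)."""
    scale = 1.0
    for s in w_stds:
        scale *= s ** 2 / 2.0
    return scale


W_stds = list(range(1, 17))
used = W_stds[:-1]                 # the loop uses W_stds[0..14]: 1, 2, ..., 15
used_reversed = W_stds[::-1][:-1]  # after W_stds.reverse(): 16, 15, ..., 2

scale = relu_nngp_scale(used)
print(f"{scale:.2e}")  # 5.22e+19
print(math.isclose(scale, relu_nngp_scale(used[::-1])))  # True: order is irrelevant
print(relu_nngp_scale(used_reversed) / scale)  # approximately 256 (= 16**2 / 1**2)

# Multiplying two variances of this size leaves the float32 range in both settings.
f32_max = float(np.finfo(np.float32).max)
print(scale ** 2 > f32_max, relu_nngp_scale(used_reversed) ** 2 > f32_max)  # True True
```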

Thank you in advance, and happy new year!

zhangbububu commented 10 months ago

Hi, how can I enable float64 precision?

romanngg commented 10 months ago

Sorry for the late reply!

@zhangbububu see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
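A minimal sketch of what the linked page describes: switch JAX into 64-bit mode through its configuration, ideally at program start before any arrays are created. Setting the JAX_ENABLE_X64 environment variable is an alternative.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit mode; do this at startup, before creating any JAX arrays.
jax.config.update("jax_enable_x64", True)

x = jnp.linspace(0.0, 1.0, 5)
print(x.dtype)  # float64 (it would be float32 without the flag)
```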

@tengandreaxu could you try using Relu(do_stabilize=True)? https://neural-tangents.readthedocs.io/en/latest/_autosummary/neural_tangents.stax.Relu.html This parameter changes how the nonlinearity's kernel is computed, which helps prevent numerical overflow.

tengandreaxu commented 10 months ago

Thank you so much, Roman. It's no problem at all!

import numpy as np
from neural_tangents import stax
from jax import jit

W_stds = list(range(1, 17))
# W_stds.reverse()
layer_fn = []
for i in range(len(W_stds) - 1):
    layer_fn.append(stax.Dense(1, W_std=W_stds[i]))
    layer_fn.append(stax.Relu(do_stabilize=True))

layer_fn.append(stax.Dense(1, 1.0, 0.0))
_, _, kernel_fn = stax.serial(*layer_fn)

kernel_fn = jit(kernel_fn, static_argnames="get")

x = np.random.rand(100, 100)

print(kernel_fn(x, x, "ntk"))

This results in:

[[2.61008562e+20 1.12163820e+20 1.23732785e+20 ... 1.08229372e+20
  1.05533967e+20 1.10687273e+20]
 [1.12163820e+20 2.92078984e+20 1.31143308e+20 ... 1.16449180e+20
  1.15616286e+20 1.19062657e+20]
 [1.23732785e+20 1.31143308e+20 3.36093753e+20 ... 1.28641726e+20
  1.19473708e+20 1.28997387e+20]
 ...
 [1.08229363e+20 1.16449180e+20 1.28641726e+20 ... 2.74442324e+20
  1.07858132e+20 1.20695995e+20]
 [1.05533967e+20 1.15616286e+20 1.19473708e+20 ... 1.07858132e+20
  2.69344883e+20 1.11830439e+20]
 [1.10687273e+20 1.19062657e+20 1.28997387e+20 ... 1.20695995e+20
  1.11830439e+20 2.83645061e+20]]

Do you think there is no point in drawing weights with a higher standard deviation as we go deeper into the network, in the infinite-width limit?

zhangbububu commented 10 months ago

@romanngg @tengandreaxu

Hi, I've run into a confusing problem:

import jax
jax.config.update("jax_enable_x64", True)  # needed for the float64 casts below to take effect
import jax.numpy as jnp
import neural_tangents as nt
from neural_tangents import stax

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512, W_std=1.5, b_std=0.05), stax.Relu(do_stabilize=True),
    stax.Dense(512, W_std=1.5, b_std=0.05), stax.Relu(do_stabilize=True),
    stax.Dense(1, W_std=1.5, b_std=0.05)
)

s = 10
l = jnp.pi * -s
r = jnp.pi * s
N_tr = 100  # number of training points
N_te = 5    # number of test points
train_xs = jnp.linspace(l, r, N_tr).reshape(-1, 1).astype(jnp.float64)
train_ys = (jnp.sin(train_xs) + jnp.sin(2 * train_xs)).astype(jnp.float64)
test_xs = jnp.linspace(l, r, N_te).reshape(-1, 1).astype(jnp.float64)

predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, train_xs,
                                                      train_ys, diag_reg=1e-4)
ntk_mean, ntk_covariance = predict_fn(x_test=test_xs, get='ntk',
                                      compute_cov=True)
print(ntk_mean)

If I increase the number of training samples (N_tr), ntk_mean comes out as all NaN.

romanngg commented 9 months ago

@tengandreaxu

"Do you think there is no point in drawing weights with a higher standard deviation as we go deeper into the network, in the infinite-width limit?"

I think so. Ideally, you would want the mean and variance of your network outputs to match those of your training labels, as a sensible prior. But even if your training labels have a large variance, it's common practice to simply standardize them (together with the test labels) to mean 0 and variance 1 for the best numerical stability.
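A minimal sketch of label standardization with NumPy. It assumes the shift and scale are computed on the training labels only and then applied to the test labels as well (a common convention; the thread doesn't specify this). `standardize_labels` and the example labels are hypothetical.

```python
import numpy as np


def standardize_labels(y_train, y_test):
    """Rescale labels to mean 0, variance 1 using training-set statistics.

    Also returns (mean, std) so predictions made in standardized units can be
    mapped back to the original units: y = y_standardized * std + mean.
    """
    mean = y_train.mean(axis=0)
    std = y_train.std(axis=0)
    return (y_train - mean) / std, (y_test - mean) / std, mean, std


# Hypothetical labels with a large offset and spread.
rng = np.random.default_rng(0)
y_train = 1e3 + 50.0 * rng.standard_normal((100, 1))
y_test = 1e3 + 50.0 * rng.standard_normal((5, 1))

y_train_s, y_test_s, mean, std = standardize_labels(y_train, y_test)
print(y_train_s.mean(), y_train_s.std())  # approximately 0.0 and 1.0
```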

Then, in a ReLU network, to get outputs with mean zero and variance one (given inputs with mean zero and variance one), you would set W_std=2**0.5 for all intermediate layers preceding ReLUs, and W_std=1 for the top layer.

@zhangbububu I replied in your separate thread; let's continue there.

tengandreaxu commented 9 months ago

Thank you for your prompt help, Roman!