Open tengandreaxu opened 11 months ago
Hi, how can I enable float64 precision?
Sorry for the late reply!
@zhangbububu see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
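In case it helps, a minimal sketch of what the linked page describes (the flag must be set before any JAX arrays are created):

# Enable 64-bit precision in JAX; run this before creating any JAX arrays,
# e.g. at the very top of the script.
import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
print(jnp.zeros(3).dtype)  # float64 once x64 is enabled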
@tengandreaxu could you try using Relu(do_stabilize=True)?
https://neural-tangents.readthedocs.io/en/latest/_autosummary/neural_tangents.stax.Relu.html This parameter switches to a numerically safer way of computing the nonlinearity kernel, which helps prevent overflow.
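For illustration, a minimal sketch (not from the thread; the depth and W_std values below are placeholders) showing that do_stabilize changes only how the kernel is computed, not its value:

# Build the same deep ReLU stack with and without do_stabilize. Both paths
# compute the same kernel; the stabilized path just reduces the risk of
# float overflow for very deep stacks or large W_std.
import numpy as np
from neural_tangents import stax

def make_kernel_fn(do_stabilize):
    layers = []
    for _ in range(10):
        layers += [stax.Dense(1, W_std=2.0), stax.Relu(do_stabilize=do_stabilize)]
    layers += [stax.Dense(1, W_std=1.0)]
    _, _, kernel_fn = stax.serial(*layers)
    return kernel_fn

x = np.random.rand(4, 8)
k_plain = make_kernel_fn(False)(x, x, 'nngp')
k_stable = make_kernel_fn(True)(x, x, 'nngp')
print(np.max(np.abs(k_plain - k_stable)))  # ~0, up to floating-point rounding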
Thank you so much, Roman. It's no problem at all!
import numpy as np
from neural_tangents import stax
from jax import jit

# Per-layer weight standard deviations, increasing with depth (1, 2, ..., 16).
W_stds = list(range(1, 17))
# W_stds.reverse()

layer_fn = []
for i in range(len(W_stds) - 1):
    layer_fn.append(stax.Dense(1, W_std=W_stds[i]))
    layer_fn.append(stax.Relu(do_stabilize=True))
layer_fn.append(stax.Dense(1, 1.0, 0.0))  # readout layer: W_std=1, b_std=0

_, _, kernel_fn = stax.serial(*layer_fn)
kernel_fn = jit(kernel_fn, static_argnames="get")

x = np.random.rand(100, 100)
print(kernel_fn(x, x, "ntk"))
results in
[[2.61008562e+20 1.12163820e+20 1.23732785e+20 ... 1.08229372e+20
1.05533967e+20 1.10687273e+20]
[1.12163820e+20 2.92078984e+20 1.31143308e+20 ... 1.16449180e+20
1.15616286e+20 1.19062657e+20]
[1.23732785e+20 1.31143308e+20 3.36093753e+20 ... 1.28641726e+20
1.19473708e+20 1.28997387e+20]
...
[1.08229363e+20 1.16449180e+20 1.28641726e+20 ... 2.74442324e+20
1.07858132e+20 1.20695995e+20]
[1.05533967e+20 1.15616286e+20 1.19473708e+20 ... 1.07858132e+20
2.69344883e+20 1.11830439e+20]
[1.10687273e+20 1.19062657e+20 1.28997387e+20 ... 1.20695995e+20
1.11830439e+20 2.83645061e+20]]
Do you think there is no point in drawing the weights from a larger standard deviation as we go deeper into the network, in the infinite-width limit?
@romanngg @tengandreaxu
Hi, I've run into a confusing problem:
import jax.numpy as jnp
import neural_tangents as nt
from neural_tangents import stax

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512, W_std=1.5, b_std=0.05), stax.Relu(do_stabilize=True),
    stax.Dense(512, W_std=1.5, b_std=0.05), stax.Relu(do_stabilize=True),
    stax.Dense(1, W_std=1.5, b_std=0.05)
)

s = 10
l = jnp.pi * -s
r = jnp.pi * s
N_tr = 100
N_te = 5

train_xs = jnp.linspace(l, r, N_tr).reshape(-1, 1).astype(jnp.float64)
train_ys = (jnp.sin(train_xs) + jnp.sin(2 * train_xs)).astype(jnp.float64)
test_xs = jnp.linspace(l, r, N_te).reshape(-1, 1).astype(jnp.float64)

predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, train_xs,
                                                      train_ys, diag_reg=1e-4)
nkt_mean, nkt_covariance = predict_fn(x_test=test_xs, get='ntk',
                                      compute_cov=True)
print(nkt_mean)
If I increase the number of training samples (N_tr), I get an all-NaN nkt_mean.
@tengandreaxu
Do you think there is no point in drawing the weights from a larger standard deviation as we go deeper into the network, in the infinite-width limit?
I think so, ideally you would want the mean and variance of your network outputs to match the mean and variance of your training labels, as a sensible prior. But even if your training labels have a large variance, it's common practice to just standardize them (together with test labels) to have mean 0 and variance 1 for best numerical stability.
Then, in a Relu network, to get mean-zero / variance-one outputs (given mean-zero, variance-one inputs), you would want to set W_std=2**0.5 for all intermediate layers preceding Relus, and W_std=1 for the top layer.
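Concretely, a minimal sketch of that recipe (the depth and the placeholder data below are mine, not from the thread):

import numpy as np
from neural_tangents import stax

# He-style scaling: W_std = sqrt(2) before every Relu, W_std = 1 on the
# readout, so unit-variance inputs give roughly unit-variance outputs.
layers = []
for _ in range(15):
    layers += [stax.Dense(1, W_std=2 ** 0.5), stax.Relu(do_stabilize=True)]
layers += [stax.Dense(1, W_std=1.0)]
_, _, kernel_fn = stax.serial(*layers)

# Standardize the (placeholder) labels to mean 0 / variance 1.
train_ys = np.random.rand(100, 1)
train_ys = (train_ys - train_ys.mean()) / train_ys.std()
print(train_ys.mean(), train_ys.std())  # ~0.0, 1.0

x = np.random.rand(100, 100)
print(kernel_fn(x, x, 'ntk'))  # entries grow only ~linearly with depth, not to ~1e20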
@zhangbububu replied in your separate thread, let's continue there.
Thank you for your prompt help Roman!
Hi everyone, thank you so much for your exceptional work!
I'm encountering some numerical issues when weights are drawn from Gaussians with a high standard deviation. Please see the snippet below:
The result is:
By enabling float64 precision, the results show the values blowing up:
Interestingly, the behavior appears to depend more on the depth than on how large the weights' standard deviations are: if the order of the standard deviations is reversed (by uncommenting the W_stds.reverse() line), so that layer 1 draws its weights from the widest Gaussian, $w_{ij} \sim \mathcal{N}(0, 16^2)$, and so on, the results remain unchanged.
Thank you in advance, and happy new year!