google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

Erf function goes beyond [-1,1] #191

Open bangxiangyong opened 1 year ago

bangxiangyong commented 1 year ago

An NN with an erf output activation can occasionally produce values well outside the range [-1, 1]:

from jax import random
from neural_tangents import stax
import neural_tangents as nt
import numpy as np
import random as rd

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(1),
    stax.Relu(),
    stax.Dense(1),
    stax.Relu(),
    stax.Dense(1),
    stax.Relu(),
    stax.Dense(1),
    stax.Erf(),
)

key1, key2 = random.split(random.PRNGKey(777))
x1 = random.normal(key1, (100, 10))
x2 = random.normal(key2, (100, 10))

x_train, x_test = x1, x2
y_train = [rd.choice([-1, 1]) for i in range(100)]
y_train = np.array(y_train)[:, np.newaxis]

predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train, y_train)
y_test_nngp = predict_fn(x_test=x_test, get="nngp")

print(y_test_nngp.max()) ## 1.6560178
print(y_test_nngp.min())  ## -2.244388

Is this intended, or have I missed something?

romanngg commented 1 year ago

On the phone, but one quick thought - could it be related to the labels being exactly -1/1, which lie outside the image (-1, 1) of Erf / apply_fn?
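
For reference, a minimal sketch (new code, reusing the variables from the snippet above) of that image constraint: the finite-width apply_fn passes its output through the final Erf, so it stays strictly inside (-1, 1), unlike the posterior-mean predictions reported above.

# Sketch: finite-width forward pass through the same stax model.
# The final Erf keeps these outputs strictly inside (-1, 1).
_, params = init_fn(key1, x_train.shape)
fx_test = apply_fn(params, x_test)
print(fx_test.min(), fx_test.max())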

romanngg commented 1 year ago

I don't have a good answer yet, but it appears that bad conditioning of the kernel_fn(x_train, x_train).nngp matrix (which is inverted to make predictions) is causing the numerical issues. Two ways to improve it: use higher-dimensional inputs (e.g. 1000 features instead of 10; right now the input covariance is rank-10, which appears to produce a badly conditioned output covariance), and/or pass diag_reg=1e-3 (or another value) when calling gradient_descent_mse_ensemble to add a small diagonal matrix to kernel_fn(x_train, x_train).nngp before inversion.
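
In code, both suggestions look roughly like this (a sketch; diag_reg=1e-3 and the 1000-feature shape are just the illustrative values mentioned above):

import jax.numpy as jnp

# Inspect how badly conditioned the train-train NNGP matrix is.
k_train_train = kernel_fn(x_train, x_train).nngp
print(jnp.linalg.cond(k_train_train))

# Mitigation 1: add a small diagonal regularizer before the kernel is inverted.
predict_fn = nt.predict.gradient_descent_mse_ensemble(
    kernel_fn, x_train, y_train, diag_reg=1e-3
)
y_test_nngp = predict_fn(x_test=x_test, get="nngp")
print(y_test_nngp.min(), y_test_nngp.max())

# Mitigation 2: use higher-dimensional inputs so the input covariance is not rank-10.
x1 = random.normal(key1, (100, 1000))
x2 = random.normal(key2, (100, 1000))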