bangxiangyong opened this issue 1 year ago
On the phone, but one quick thought - could it be related to labels being exactly 1/-1, which are outside of the image of ERF / apply_fn (-1; 1)?
On Wed, Nov 15, 2023, 17:20 bangxiangyong wrote:
The NN with an erf output activation can occasionally produce predictions well outside the range [-1, 1]:
```python
from jax import random
from neural_tangents import stax
import neural_tangents as nt
import numpy as np  # needed for np.array / np.newaxis below
import random as rd

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(1), stax.Relu(),
    stax.Dense(1), stax.Relu(),
    stax.Dense(1), stax.Relu(),
    stax.Dense(1), stax.Erf(),
)

key1, key2 = random.split(random.PRNGKey(777))
x1 = random.normal(key1, (100, 10))
x2 = random.normal(key2, (100, 10))

x_train, x_test = x1, x2
y_train = [rd.choice([-1, 1]) for i in range(100)]
y_train = np.array(y_train)[:, np.newaxis]

predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train, y_train)
y_test_nngp = predict_fn(x_test=x_test, get="nngp")

print(y_test_nngp.max())  # 1.6560178
print(y_test_nngp.min())  # -2.244388
```
Is this intended, or have I missed something?
I don't have a good answer yet, but it appears that bad conditioning of the `kernel_fn(x_train, x_train).nngp` matrix (which is inverted to make predictions) is causing the numerical issues. One way to improve it is to use higher-dimensional inputs (e.g. 1000 features instead of 10; right now the input covariance is rank-10, which appears to result in a badly conditioned output covariance), and/or to pass `diag_reg=1e-3` (or another value) when calling `gradient_descent_mse_ensemble`, which adds a small diagonal matrix to `kernel_fn(x_train, x_train).nngp` before inversion.
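A minimal sketch of both workarounds, assuming the setup from the snippet above (the `1e-3` regularizer and the 1000-feature input width are just the illustrative values mentioned here, not tuned recommendations):

```python
import random as rd
import numpy as np
from jax import random
import neural_tangents as nt
from neural_tangents import stax

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(1), stax.Relu(),
    stax.Dense(1), stax.Relu(),
    stax.Dense(1), stax.Relu(),
    stax.Dense(1), stax.Erf(),
)

key1, key2 = random.split(random.PRNGKey(777))
n_features = 1000  # wider inputs -> input covariance is no longer rank-10
x_train = random.normal(key1, (100, n_features))
x_test = random.normal(key2, (100, n_features))
y_train = np.array([rd.choice([-1, 1]) for _ in range(100)])[:, np.newaxis]

# Inspect conditioning of the matrix that gets inverted.
k_train_nngp = kernel_fn(x_train, x_train).nngp
print("condition number:", np.linalg.cond(np.asarray(k_train_nngp)))

# diag_reg adds a small multiple of the identity to the train-train
# kernel before inversion.
predict_fn = nt.predict.gradient_descent_mse_ensemble(
    kernel_fn, x_train, y_train, diag_reg=1e-3)
y_test_nngp = predict_fn(x_test=x_test, get="nngp")
print(y_test_nngp.min(), y_test_nngp.max())
```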