google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0
2.27k stars 226 forks

Question: simple example poor performance, what am I doing wrong? #163

Open mfouesneau opened 2 years ago

mfouesneau commented 2 years ago

Dear team, great package, I'm very excited to use it.

However, I tried a simple case, and I failed miserably to get a decent performance.

I generate a multi-dimensional dataset with a relatively simple feature

import numpy as np

# Create some fake data
np.random.seed(0)
m = 1000
n = 10
noise_std = 1.
X = 80 * np.random.uniform(size=(m, n)) - 40
y = np.abs(X[:, 6] - 4.0) + noise_std * np.random.normal(size=m)

And I followed your examples:

import neural_tangents as nt
from neural_tangents import stax
from sklearn.model_selection import train_test_split

x_train, x_test, y_train, y_test = train_test_split(
    X, y.reshape(-1, 1), test_size=0.4, random_state=42)

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(256), stax.Relu(),
    stax.Dense(1)
)
predict_fn = nt.predict.gradient_descent_mse_ensemble(
    kernel_fn, 
    x_train,
    y_train)

# With compute_cov=True, each prediction is a Gaussian(mean, covariance) namedtuple
y_test_nngp, y_test_ntk = predict_fn(x_test=x_test, get=('nngp', 'ntk'), compute_cov=True)

Visual inspection shows terrible predictions, and loss values are large:

import jax.numpy as jnp

loss = lambda ypred, y_hat: 0.5 * jnp.mean((ypred - y_hat) ** 2)
print("loss_nngp = {}".format(loss(y_test_nngp.mean, y_test)))
print("loss_ntk = {}".format(loss(y_test_ntk.mean, y_test)))
loss_nngp = 6.877374649047852
loss_ntk = 6.610106468200684

I varied the network in many ways and fiddled with learning_rate and diag_reg, but it hardly changed anything.
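(For reference, a sketch of how those two knobs are passed; the values here are just illustrative:)

predict_fn = nt.predict.gradient_descent_mse_ensemble(
    kernel_fn,
    x_train,
    y_train,
    learning_rate=1.,   # step size used by the closed-form gradient-descent ensemble
    diag_reg=1e-4)      # ridge-like regularization added to the train-train kernel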

I'm sure I am doing something wrong, but I cannot see what it is. Any obvious mistake?

Thanks for your help.

romanngg commented 2 years ago

At a glance, library usage seems good to me! One way to figure this out is to establish a baseline with some other method (a kernel method, a finite neural network, etc.) to see what loss values are expected. For example, y will have a mean of roughly 20 (the expectation of the absolute value of a uniform on [-44, 36] is about 1/2 * (40 + 4)/2 + 1/2 * (40 - 4)/2 = 20), so the scale of the outputs is pretty large, and it's not obvious to me that those loss values are abnormally large. Another angle is to increase the training set size - it's hard to say whether 600 training points are enough for the model to learn well.
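For instance, a trivial baseline (a minimal sketch, reusing the variables from your snippet) is the loss you get by always predicting the mean of the training targets; the NNGP/NTK losses are only meaningful relative to something like this:

import numpy as np

# Constant baseline: predict the training-set mean for every test point.
baseline_loss = 0.5 * np.mean((y_test - y_train.mean()) ** 2)
print(f"baseline (constant prediction) = {baseline_loss}")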

mfouesneau commented 2 years ago

import matplotlib.pyplot as plt

# Sort the test points along the informative feature and plot the NTK predictive band.
args = np.argsort(x_test[:, 6])
y_mean = np.reshape(y_test_ntk.mean, (-1,))[args]
y_std = np.sqrt(np.diag(y_test_ntk.covariance))[args]

plt.plot(X[:, 6], y, 'k.', alpha=0.1, rasterized=True)
plt.fill_between(
    np.reshape(x_test[args, 6], (-1)),
    y_mean - 3 * y_std,
    y_mean + 3 * y_std,
    color='red', alpha=0.2)
plt.xlabel('x_6')
plt.ylabel('y')
[image: data and the NTK predictive mean ± 3σ band along x_6]

The thing is that changing the layer width from 50 to 5,000 nodes hardly changes the output. I would expect at least some change.

I tried 10,000 points, and I only gained a factor of 2 on the loss:

[image]

Is there any guidance on what an appropriate training set size would be?

mfouesneau commented 2 years ago

I get memory errors if I try 100,000 points in my dataset, even with the batching trick:

kernel_fn = nt.batch(kernel_fn,
                     device_count=0,
                     batch_size=1_000)

romanngg commented 2 years ago

Note that in your example you are doing inference with an infinitely-wide neural network (kernel_fn), so the width doesn't matter in this case. Also, the plot does look like the learned function mimics |x_6 - 4| (at least it's not doing anything obviously wrong; it has the right shape and kink location), so I'm inclined to think it's working as intended?...
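If you want to convince yourself of the width point, a quick check (a sketch, using a small probe batch) is that kernel_fn comes out identical no matter what width you put in the hidden stax.Dense layer:

import numpy as np
import jax.numpy as jnp
from neural_tangents import stax

# The infinite-width kernel does not depend on the hidden-layer width;
# the width argument only matters for the finite network (init_fn / apply_fn).
_, _, kernel_fn_50 = stax.serial(stax.Dense(50), stax.Relu(), stax.Dense(1))
_, _, kernel_fn_5000 = stax.serial(stax.Dense(5000), stax.Relu(), stax.Dense(1))

x_probe = np.random.normal(size=(5, 10))
print(jnp.allclose(kernel_fn_50(x_probe, x_probe, 'ntk'),
                   kernel_fn_5000(x_probe, x_probe, 'ntk')))  # True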

Re the training set, I think it's constructed correctly; I'm just not sure how to reason about the generalization we should expect from it (per your plot, it seems to be at least OK-ish?...).

And yes, 100K is too much for most GPUs.
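If you do want to push the training set size further, one thing worth trying (a sketch; check that your neural-tangents version exposes the store_on_device argument of nt.batch) is to keep the assembled kernel blocks in host RAM:

# Store assembled kernel blocks in host (CPU) memory instead of on the GPU,
# at the cost of extra host<->device transfers. The train-train kernel is
# still O(m^2): 100k x 100k in float32 is ~40 GB, which may not fit in RAM either.
kernel_fn_batched = nt.batch(kernel_fn,
                             device_count=0,
                             batch_size=250,
                             store_on_device=False)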

mfouesneau commented 2 years ago

You're right; it seems to be doing OK, but with serious overfitting.

Is there a paper to read to get a feel for appropriate network architectures? My understanding is that stacking more layers will not change anything unless a "layer" is already something complex, right?

[image]
[image]
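For concreteness, here is a sketch of what I mean by stacking more layers (purely illustrative; the width inside each stax.Dense is a separate knob from the depth):

# A deeper variant of the same model: more Dense + Relu blocks before the readout.
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512), stax.Relu(),
    stax.Dense(512), stax.Relu(),
    stax.Dense(512), stax.Relu(),
    stax.Dense(1)
)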