google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

Prediction generates nan values #76

Open Kangfei opened 4 years ago

Kangfei commented 4 years ago

I use a simple NN model and gradient_descent_mse_ensemble to train the kernel as follows:

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512), stax.Relu(),
    stax.Dense(512), stax.Relu(),
    stax.Dense(1),
)
predict_fn = nt.predict.gradient_descent_mse_ensemble(
    kernel_fn, X_train, Y_train, diag_reg=1e-3)

On a regression task with a simple dataset, e.g., iris, the prediction looks normal for the nngp kernel but is all nan for the ntk kernel. On a more complex dataset, the nngp prediction also becomes nan if I add one more Dense and Relu layer to the NN. I'm curious when/why the prediction can contain nan values and how to debug it.

By the way, can I save a pre-trained model and then load it from disk, as tf or torch do? Thanks in advance, and looking forward to your replies.

ybj14 commented 4 years ago

In my opinion, changing Relu to Erf might help. For an infinite-width model you can only save the kernel matrix or intermediate results (like Cholesky factors); for a finite NN you can of course save the parameters.
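
For example, caching the kernel matrices to disk with numpy might look like the sketch below; the file names are illustrative, and kernel_fn / X_train are the objects from the original post:

import numpy as np

# The infinite-width "model" is fully described by its kernel, so caching the
# train-train kernel matrices plays the role of saving a pre-trained model.
kdd_nngp = kernel_fn(X_train, X_train, 'nngp')
kdd_ntk = kernel_fn(X_train, X_train, 'ntk')
np.save('kdd_nngp.npy', kdd_nngp)
np.save('kdd_ntk.npy', kdd_ntk)

# Later: reload the cached kernels instead of recomputing them.
kdd_nngp = np.load('kdd_nngp.npy')
kdd_ntk = np.load('kdd_ntk.npy')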

SiuMath commented 4 years ago

When the kernel matrix kdd (the training-training kernel) has small/zero eigenvalues, the prediction may be all nan. This usually happens when (1) CNN + global average pooling is used, (2) NNGP is used, and (3) the dataset is large (>=10k). In your case, changing stax.Dense(512), stax.Relu() to stax.Dense(512, W_std=np.sqrt(2.), b_std=0.1), stax.Relu() may help. The default W_std is 1., and for Relu this choice makes the norm of the activations decay by a factor of 2 in each Dense layer. As @ybj14 pointed out, changing Relu to Erf may also help in this case. Increasing the float precision may help as well; a sketch of these changes follows.
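
A minimal sketch of these suggestions (the exact hyperparameters are illustrative, not a prescription; X_train and Y_train are from the original post):

from jax import config
config.update('jax_enable_x64', True)  # increase float precision to float64

import numpy as np
import neural_tangents as nt
from neural_tangents import stax

# Critical initialization for Relu so activation norms do not decay with depth;
# stax.Erf() can be swapped in for stax.Relu() as an alternative.
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512, W_std=np.sqrt(2.), b_std=0.1), stax.Relu(),
    stax.Dense(512, W_std=np.sqrt(2.), b_std=0.1), stax.Relu(),
    stax.Dense(1, W_std=np.sqrt(2.), b_std=0.1),
)

# Diagnose ill-conditioning directly: the smallest eigenvalue of the
# training-training kernel should not be (numerically) zero or negative.
kdd = kernel_fn(X_train, X_train, 'ntk')
print(np.linalg.eigvalsh(kdd).min())

predict_fn = nt.predict.gradient_descent_mse_ensemble(
    kernel_fn, X_train, Y_train, diag_reg=1e-3)

If the smallest eigenvalue is close to zero, increasing diag_reg (already used in the original snippet) is another way to regularize kdd.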

Let me know if this helps solve the problem.

Best, Lechao

Kangfei commented 4 years ago

Thanks for your reply. I have tried these tricks and each one helps to some degree; using all of them together solved my problem!

Best, Kangfei