Kangfei opened this issue 4 years ago
In my opinion, changing `Relu` to `Erf` might help.

For an infinite-width model you can only save the kernel matrix or intermediate results (like Cholesky factors), but for a finite NN you can of course save the parameters.
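Concretely, caching a Cholesky factor looks like the following sketch (pure numpy; the matrix and filename here are synthetic stand-ins, not produced by `kernel_fn`):

```python
import numpy as np

# Toy stand-in for the precomputed training-training kernel kdd
# (shape and values are illustrative only).
rng = np.random.default_rng(0)
A = rng.standard_normal((100, 200))
kdd = A @ A.T / 200 + 1e-3 * np.eye(100)  # well-conditioned PSD matrix

# Cache the Cholesky factor so later predictions skip the O(n^3) factorization.
chol = np.linalg.cholesky(kdd)
np.save("kdd_chol.npy", chol)

# Later session: reload the factor and reuse it via two triangular solves.
chol_loaded = np.load("kdd_chol.npy")
y = rng.standard_normal(100)
alpha = np.linalg.solve(chol_loaded.T, np.linalg.solve(chol_loaded, y))
```

The factorization is the expensive part; once it is on disk, predictions against new targets only cost the two triangular solves.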
When the kernel matrix `kdd` (the training-training kernel) has small/zero eigenvalues, the prediction may be all-NaN. This usually happens when (1) a CNN with global average pooling is used, (2) the NNGP kernel is used, and (3) the dataset is large (>=10k examples).
In your case, changing `stax.Dense(512), stax.Relu()` to `stax.Dense(512, W_std=np.sqrt(2.), b_std=0.1), stax.Relu()` may help. The default `W_std` is 1., and with `Relu` this makes the norm of the activations decay by a factor of 2 in each Dense layer. As @ybj14 pointed out, changing `Relu` to `Erf` may help in this case. Increasing the float precision may also help.
Let me know if this helps solve the problem.
Best, Lechao
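The decay Lechao describes is easy to check with a finite-width simulation (pure numpy, with width and depth chosen only for illustration): at `W_std=1` the mean squared activation roughly halves per Dense+Relu block, while `W_std=np.sqrt(2.)` keeps it near 1.

```python
import numpy as np

rng = np.random.default_rng(0)

def relu_net_sq_norm(w_std, depth=5, width=2048):
    """Mean squared activation after `depth` Dense+Relu blocks,
    with weights drawn with std w_std / sqrt(fan_in)."""
    x = rng.standard_normal(width)
    for _ in range(depth):
        W = rng.standard_normal((width, width)) * w_std / np.sqrt(width)
        x = np.maximum(0.0, W @ x)  # Dense + Relu
    return float(np.mean(x ** 2))

print(relu_net_sq_norm(1.0))          # shrinks roughly like 2**-depth
print(relu_net_sq_norm(np.sqrt(2.)))  # stays near 1
```

With many layers this exponential shrinkage pushes the kernel entries toward the float underflow/ill-conditioning regime, which is why the `W_std=np.sqrt(2.)` initialization (He-style scaling for ReLU) helps.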
Thanks for your reply. I have tried these tricks and each one is helpful to some degree. Using all of them already solves my problem!
Best, Kangfei
I use a simple NN model and `gradient_descent_mse_ensemble` to train the kernel as follows:

```python
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512), stax.Relu(),
    stax.Dense(512), stax.Relu(),
    stax.Dense(1)
)
predict_fn = nt.predict.gradient_descent_mse_ensemble(
    kernel_fn, X_train, Y_train, diag_reg=1e-3
)
```
On a regression task with a simple dataset, e.g., Iris, the prediction looks normal for the `nngp` kernel but is all-NaN for the `ntk` kernel. On a more complex dataset, `nngp` also generates NaN predictions if I add one more Dense and Relu layer to the NN. I'm curious about when/why the prediction can contain NaN values and how to debug it.
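One way to see the failure mode is to inspect the spectrum of the training-training kernel; here is a numpy sketch with a synthetic rank-deficient matrix standing in for `kdd` (I believe `diag_reg` in neural_tangents is scaled relative to the mean of the kernel's diagonal by default, which is what the jitter line below mimics):

```python
import numpy as np

# Synthetic stand-in for kdd: a rank-5 PSD matrix of size 50x50,
# mimicking a kernel with many (near-)zero eigenvalues.
rng = np.random.default_rng(0)
A = rng.standard_normal((50, 5))
kdd = A @ A.T

eigs = np.linalg.eigvalsh(kdd)
print(eigs.min())  # ~0 up to float error: solving kdd @ x = y is ill-posed

# Jittering the diagonal (roughly what diag_reg=1e-3 does) restores invertibility.
diag_reg = 1e-3
kdd_reg = kdd + diag_reg * np.mean(np.diag(kdd)) * np.eye(len(kdd))
print(np.linalg.eigvalsh(kdd_reg).min())  # bounded away from zero
```

If the smallest eigenvalue of your real `kdd` is at or below float epsilon times the largest, the downstream solve amplifies noise into NaN/inf, and either regularization or higher precision is needed.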
By the way, can I save a pre-trained model and then load it from disk, as tf or torch do? Thanks in advance, and looking forward to your replies.