google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

NaN occurs when backpropagating through a kernel that includes a ReLU layer #88

ZuowenWang0000 closed this issue 3 years ago

ZuowenWang0000 commented 3 years ago

Hello, I have run into a problem: when I compute the gradient of a loss function with respect to an input variable x, the gradient becomes NaN after several iterations. This only happens when the NTK includes a ReLU layer; with Erf or Sigmoid instead, I don't see the problem.
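One plausible source of such NaNs, offered as an assumption rather than something this thread confirms: the NNGP kernel of a ReLU layer is the degree-1 arc-cosine kernel, which applies arccos to the cosine similarity between two inputs. The derivative of arccos is infinite at 1, so differentiating through it when two inputs are parallel (for example, a diagonal entry k(x, x)) can produce 0 * inf = NaN even though the analytic gradient is finite. The minimal JAX sketch below uses a hypothetical, deliberately naive `relu_nngp_naive` helper; it is not the neural_tangents implementation.

```python
import jax
import jax.numpy as jnp


def relu_nngp_naive(x1, x2):
    # Hypothetical, deliberately naive ReLU NNGP kernel (degree-1 arc-cosine
    # kernel): k = |x1||x2| / (2*pi) * (sin(theta) + (pi - theta) * cos(theta)).
    n1 = jnp.linalg.norm(x1)
    n2 = jnp.linalg.norm(x2)
    cos_theta = jnp.dot(x1, x2) / (n1 * n2)
    # d/dc arccos(c) = -1 / sqrt(1 - c^2), which is -inf at c = 1.
    theta = jnp.arccos(cos_theta)
    return n1 * n2 / (2 * jnp.pi) * (jnp.sin(theta) + (jnp.pi - theta) * cos_theta)


x = jnp.array([1.0, 0.0])

# Forward pass is finite: k(x, x) = |x|^2 / 2 = 0.5 (up to rounding).
k_xx = relu_nngp_naive(x, x)

# Backward pass w.r.t. the first argument: cos_theta is exactly 1, so the
# infinite arccos derivative meets zero factors and yields NaN/inf entries.
g = jax.grad(relu_nngp_naive)(x, x)
print(k_xx, g)
```

A common guard is to clamp the cosine to a range such as [-1 + eps, 1 - eps] before calling arccos, or to rewrite the expression so the infinite factor cancels symbolically; the thread does not say which approach, if either, the maintainer's fix uses.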

I build the kernel function as follows:

  from neural_tangents import stax as ntstax  # assumed alias for the stax module
  self.init_fn, self.f, self.kernel_fn = ntstax.serial(
      ntstax.Dense(1, parameterization='ntk'),
      ntstax.Relu(do_backprop=True, do_stabilize=True),
      ntstax.Dense(1, parameterization='ntk')
  )

I then compute the gradient with respect to x via: grads = grad(model_loss, argnums=1)(params, (x, y))[0]

where model_loss = lambda params, batch: loss_func(pred(params, batch[0]), batch[1])
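A self-contained JAX sketch of this gradient-with-respect-to-input pattern, using hypothetical linear-ReLU `pred` and mean-squared-error `loss_func` stand-ins (the issue does not show the real ones). It illustrates why `[0]` is needed: with `argnums=1`, `grad` returns a gradient for the whole `(x, y)` tuple, and its first element is the gradient with respect to `x`. Python 3 lambdas cannot unpack tuple parameters, so the batch is unpacked inside the function.

```python
import jax
import jax.numpy as jnp


def pred(params, x):
    # Hypothetical stand-in: one linear layer followed by a ReLU.
    w, b = params
    return jax.nn.relu(x @ w + b)


def loss_func(y_hat, y):
    # Hypothetical stand-in: mean squared error.
    return jnp.mean((y_hat - y) ** 2)


def model_loss(params, batch):
    # Unpack inside the body; Python 3 removed tuple parameters in signatures.
    x, y = batch
    return loss_func(pred(params, x), y)


params = (jnp.ones((3, 1)), jnp.zeros((1,)))
x = jnp.array([[1.0, 2.0, 3.0], [-1.0, 0.5, 0.0]])
y = jnp.array([[1.0], [0.0]])  # must be floating point, since grad covers the whole tuple

# argnums=1 differentiates w.r.t. the (x, y) tuple; [0] selects d(loss)/dx.
grads_x = jax.grad(model_loss, argnums=1)(params, (x, y))[0]
print(grads_x.shape)  # (2, 3)
```

Because `grad` differentiates every leaf of the selected argument, `y` also receives a gradient here; differentiating with respect to a function of `x` alone would avoid that extra work.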

By the way, is the _safe_sqrt function here: https://github.com/google/neural-tangents/blob/c6f759d116d0128db94d1024612e81eb56e77e7f/neural_tangents/stax.py#L3847 not backprop-safe? We might need to replace np.where with np.maximum, as in _sqrt.
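The general backprop pitfall can be shown with a minimal JAX sketch using three hypothetical helpers (none of them is the actual neural_tangents `_safe_sqrt` or `_sqrt`). Masking the *output* of `jnp.sqrt` with `jnp.where` still yields a NaN gradient at 0, because the backward pass differentiates the unselected branch and computes 0 * inf. Clamping the *input* to a small positive floor with `jnp.maximum`, or the "double where" pattern described in the JAX FAQ, keeps the gradient finite. The floor `tol=1e-12` is an arbitrary choice for illustration.

```python
import jax
import jax.numpy as jnp


def sqrt_where_output(x):
    # Hypothetical: masks the *output* of sqrt. The forward value is fine at 0,
    # but the backward pass still differentiates jnp.sqrt(x) at x = 0.
    return jnp.where(x > 0.0, jnp.sqrt(x), 0.0)


def sqrt_clamped(x, tol=1e-12):
    # Hypothetical: clamps the *input* to a small positive floor, so sqrt is
    # only ever evaluated and differentiated at values >= tol.
    return jnp.sqrt(jnp.maximum(x, tol))


def sqrt_double_where(x):
    # Hypothetical: "double where" pattern. Replace unsafe inputs with a
    # harmless value before sqrt, then mask the output as before.
    safe_x = jnp.where(x > 0.0, x, 1.0)
    return jnp.where(x > 0.0, jnp.sqrt(safe_x), 0.0)


g_where = jax.grad(sqrt_where_output)(0.0)   # nan: 0 * (0.5 / sqrt(0)) = 0 * inf
g_clamped = jax.grad(sqrt_clamped)(0.0)      # 0.0: x < tol, so maximum passes no gradient to x
g_double = jax.grad(sqrt_double_where)(0.0)  # 0.0: sqrt is differentiated at 1.0 instead of 0.0
print(g_where, g_clamped, g_double)
```

The clamped version trades a tiny forward bias (it returns sqrt(tol) rather than 0 below the floor) for simplicity; the double-where version keeps the exact forward value. With a floor of exactly 0, the gradient at x = 0 can still be infinite, which is why a positive floor is assumed here.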

Thanks!

romanngg commented 3 years ago

Thanks for the report! Could you check whether this^ commit helps? It should remove the do_backprop argument and fix the NaNs, as well as improve the numerical stability of differentiating nonlinearities. Let me know if this works!

ZuowenWang0000 commented 3 years ago

ReLU works now. Thanks for the timely update!