google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

NTK for Inputs on Hypersphere #143

Closed lkskstlr closed 2 years ago

lkskstlr commented 2 years ago

Thank you very much for this excellent library and your research, both are highly useful!

Tancik et al. [1, Remark below Eqn. (2)] and Jacot et al. [2, proof of Proposition 2] both mention that the neural tangent kernel k(x, y) simplifies to a scalar function h(<x, y>) when |x| = |y| = 1. I thus wanted to ask whether neural-tangents already supports this internally, or whether it would be possible to add this feature. A scalar function might offer runtime benefits and is easier to deal with theoretically. I checked experimentally that the current kernel_fn indeed fulfills this property. It would also be super awesome to get somewhat symbolic access to h(<x, y>).
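For reference, here is a minimal sketch of how such a check can be done (the two-layer Erf network, seeds, and tolerance are illustrative, not a fixed recipe): a rotation Q preserves both the unit norms and <x, y>, so if the kernel depends only on h(<x, y>) it must be unchanged under x ↦ Qx, y ↦ Qy.

import jax.numpy as jnp
from jax import random
from neural_tangents import stax

# Illustrative fully connected network; any stax MLP should behave the same.
_, _, kernel_fn = stax.serial(stax.Dense(512), stax.Erf(), stax.Dense(1))

d = 8
x, y = random.normal(random.PRNGKey(0), (2, d))
x /= jnp.linalg.norm(x)  # project both inputs onto the unit sphere
y /= jnp.linalg.norm(y)

# Random rotation Q (orthogonal matrix from a QR decomposition).
q, _ = jnp.linalg.qr(random.normal(random.PRNGKey(1), (d, d)))

k1 = kernel_fn(x[None], y[None], 'ntk')
k2 = kernel_fn((x @ q.T)[None], (y @ q.T)[None], 'ntk')
print(jnp.allclose(k1, k2, atol=1e-4))  # expected: True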

Thanks a lot Lukas

[1] Tancik et al.: Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains.
[2] Jacot et al.: Neural Tangent Kernel: Convergence and Generalization in Neural Networks.

romanngg commented 2 years ago

In our codebase, <x, y> corresponds to Kernel.nngp, and |x| and |y| are Kernel.cov1 and Kernel.cov2: https://github.com/google/neural-tangents/blob/6d5cbc8328afa63e5d17b773bd68b48026463a97/neural_tangents/_src/utils/kernel.py#L32
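For concreteness, here is a quick (hypothetical) way to look at these fields on unit-norm inputs; IIRC the input nngp is normalized by the input dimension, so treat the exact scaling below as an assumption:

import jax.numpy as jnp
from jax import random
from neural_tangents import stax

_, _, kernel_fn_id = stax.Identity()

x = random.normal(random.PRNGKey(0), (3, 8))
x /= jnp.linalg.norm(x, axis=1, keepdims=True)  # unit-norm rows

k = kernel_fn_id(x, None)
print(k.nngp)  # pairwise dot products <x_i, x_j>, up to normalization
print(k.cov1)  # squared norms |x_i|^2, up to the same normalization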

So perhaps what you could do is something like this:

import neural_tangents as nt

_, _, kernel_fn_id = nt.stax.Identity()

# x, y are batches of inputs of shapes (N_x, ...), (N_y, ...)
kernel_input = kernel_fn_id(x, y)

_, _, kernel_fn = nt.stax.serial(<your network definition here>)

# xy must be a matrix of dot products of shape (N_x, N_y, ...)
def kernel_fn_xy(xy):
  kernel_input_new = kernel_input.replace(nngp=xy)
  return kernel_fn(kernel_input_new)

Now kernel_fn_xy is a function of xy only (you can differentiate it, etc.), and cov1 and cov2 are fixed constants from computing kernel_input. Note that this doesn't give you much time or memory savings: AFAIK, asymptotically (in N_x, N_y, and input shapes) the cost of evaluating the "dot-product only" kernel is equivalent to evaluating the whole kernel (but lmk if you have examples where you expect large memory savings; if N_x = N_y = 1, the difference can be ~3X).
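To illustrate the differentiation point, here is a hedged sketch of extracting a scalar h and its derivative, assuming kernel_input above was computed from a single pair (N_x = N_y = 1) and that the returned Kernel has its ntk populated:

import jax
import jax.numpy as jnp

def h(xy):
  # xy: scalar dot product; wrap it into the (1, 1) nngp matrix.
  k_in = kernel_input.replace(nngp=jnp.array([[xy]]))
  return kernel_fn(k_in).ntk[0, 0]

h_prime = jax.grad(h)  # dh/d<x, y> at a given scalar dot product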

Also, I am a bit in a hurry and omitting some details now; in some cases you may need to construct kernel_input explicitly as kernel_input = nt._src.utils.Kernel(nngp=xy, ...) for more flexibility. But lmk if the above idea sounds reasonable to you in general, and if you have further questions!

lkskstlr commented 2 years ago

@romanngg Thanks a lot for the helpful reply! I will try to see if this works for us :)