Closed lkskstlr closed 2 years ago
In our codebase `<x, y>` corresponds to `Kernel.nngp`, and `|x|` and `|y|` correspond to `Kernel.cov1` and `Kernel.cov2`: https://github.com/google/neural-tangents/blob/6d5cbc8328afa63e5d17b773bd68b48026463a97/neural_tangents/_src/utils/kernel.py#L32
So perhaps what you could do is something like this:
```python
_, _, kernel_fn_id = nt.stax.Identity()
# x, y: batches of inputs of shapes (N_x, ...), (N_y, ...)
kernel_input = kernel_fn_id(x, y)

_, _, kernel_fn = nt.stax.serial(<your network definition here>)

# xy must be a matrix of dot products of shape (N_x, N_y, ...)
def kernel_fn_xy(xy):
    kernel_input_new = kernel_input.replace(nngp=xy)
    return kernel_fn(kernel_input_new)
```
Now `kernel_fn_xy` is a function of `xy` only, so you can differentiate it, etc., while `cov1` and `cov2` are fixed constants obtained from computing `kernel_input`. Note that this doesn't give you much time or memory savings: AFAIK, asymptotically (in `N_x`, `N_y`, and input shapes) the cost of evaluating the "dot-product-only" kernel would be equivalent to evaluating the whole kernel (but lmk if you have examples where you expect large memory savings; if `N_x = N_y = 1`, then the difference can be ~3X).
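To illustrate the "you can differentiate it" part on a concrete scalar kernel (independent of the snippet above): the `h` below is the closed-form NNGP of a single ReLU layer restricted to unit-norm inputs (the arc-cosine kernel of Cho & Saul, up to the weight-variance normalization), used purely as a stand-in for whatever `h(<x, y>)` a given network induces.

```python
import jax
import jax.numpy as jnp

def h(t):
    # Arc-cosine kernel restricted to unit-norm inputs, t = <x, y>.
    # (Illustrative stand-in, not the kernel of any particular network.)
    theta = jnp.arccos(jnp.clip(t, -1.0, 1.0))
    return (jnp.sin(theta) + (jnp.pi - theta) * t) / (2 * jnp.pi)

# Once the kernel is a scalar function of the dot product,
# jax.grad gives its derivative directly:
dh = jax.grad(h)
val, slope = h(0.5), dh(0.5)  # slope = (pi - arccos t) / (2 pi) = 1/3 at t = 0.5
```

Analytically, `h'(t) = (pi - arccos(t)) / (2 pi)`, so the autodiff result can be checked in closed form.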
Also, I am a bit in a hurry and am omitting some details now; in some cases you may need to construct `kernel_input` explicitly as `kernel_input = nt._src.utils.Kernel(nngp=xy, ...)` for more flexibility. But lmk if the above idea sounds reasonable to you in general, and if you have further questions!
@romanngg Thanks a lot for the helpful reply! I will try to see if this works for us :)
Thank you very much for this excellent library and your research, both are highly useful!
Tancik et al. [1, Remark below Eqn. (2)] and Jacot et al. [2, Proof of Proposition 2] both mention that the neural tangent kernel k(x, y) can be simplified as h(<x, y>) if |x| = |y| = 1 holds. I thus wanted to ask if `neural-tangents` maybe already supports this internally, or if it would be possible to support this feature. A scalar function might offer runtime benefits and is easier to deal with theoretically. I checked experimentally that the current `kernel_fn` indeed fulfills this property. It would also be super awesome to get somewhat symbolic access to h(<x, y>).

Thanks a lot,
Lukas
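As a quick sanity check of the h(<x, y>) property itself (not tied to `neural-tangents`): the closed-form NNGP of a single ReLU layer, the arc-cosine kernel of Cho & Saul, depends only on <x, y> once |x| = |y| = 1, which a short NumPy check confirms.

```python
import numpy as np

def relu_nngp(x, y):
    # Closed-form NNGP of one ReLU layer (arc-cosine kernel):
    # k(x, y) = |x||y| (sin t + (pi - t) cos t) / (2 pi), t = angle(x, y).
    nx, ny = np.linalg.norm(x), np.linalg.norm(y)
    cos = np.clip(x @ y / (nx * ny), -1.0, 1.0)
    t = np.arccos(cos)
    return nx * ny * (np.sin(t) + (np.pi - t) * cos) / (2 * np.pi)

# Two pairs of unit vectors with the same dot product (0.6) but
# otherwise different coordinates:
x1, y1 = np.array([1.0, 0.0, 0.0]), np.array([0.6, 0.8, 0.0])
x2, y2 = np.array([0.0, 1.0, 0.0]), np.array([0.0, 0.6, 0.8])
k1, k2 = relu_nngp(x1, y1), relu_nngp(x2, y2)
# k1 == k2: on the unit sphere the kernel is a function of <x, y> alone
```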
[1] Tancik et al.: Fourier Features Let Networks Learn High Frequency Functions in Low Dimensional Domains
[2] Jacot et al.: Neural Tangent Kernel: Convergence and Generalization in Neural Networks