google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

Defining Own Kernel Function in Neural Tangents #172

Open · snehjp2 opened 1 year ago

snehjp2 commented 1 year ago

Hello! I was wondering if it's possible to write your own kernel_fn in neural tangents to do regular kernel regression (for example, with a Gaussian kernel). Naively writing a new_kernel_fn and compiling it with jit raises an error when calling `predict_fn`. If this is possible, I'm sure I'm missing something important in making the function compatible with the nt and stax backends.
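For concreteness, the kind of kernel_fn I have in mind is a plain Gaussian (RBF) kernel, something like the sketch below (the name `gaussian_kernel_fn`, the `sigma` bandwidth, and the `(batch, features)` input shapes are all illustrative, not part of the library):

```python
import jax
import jax.numpy as jnp

@jax.jit
def gaussian_kernel_fn(x1, x2, sigma=1.0):
  """Gaussian (RBF) kernel matrix between two batches of flat inputs."""
  if x2 is None:
    x2 = x1  # mirror the nt convention that x2=None means x2 = x1
  # Pairwise squared Euclidean distances, shape (len(x1), len(x2)).
  sq_dists = jnp.sum((x1[:, None, :] - x2[None, :, :]) ** 2, axis=-1)
  return jnp.exp(-sq_dists / (2.0 * sigma ** 2))
```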

edit: I see now that this may be more appropriate for discussions. Can move this thread there if appropriate!

romanngg commented 1 year ago

AFAIK all but one of the nt.predict functions should work with any kernel_fn, since they accept kernels as inputs; you can compute these kernels with any function you want.

nt.predict.gradient_descent_mse_ensemble indeed requires kernel_fn to have a specific signature / return type, see the docs at https://neural-tangents.readthedocs.io/en/latest/_autosummary/neural_tangents.predict.gradient_descent_mse_ensemble.html#neural_tangents.predict.gradient_descent_mse_ensemble i.e. it must return a namedtuple with ntk and/or nngp attributes, or the Kernel dataclass object. It requires this specific structure since some settings (e.g. the covariance of infinite-width GD-trained network outputs) require both the NTK and the NNGP to be computed.

For regular kernel regression it's probably easiest to use nt.predict.gp_inference. Lmk if this helps!
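E.g. a minimal sketch of that route, assuming the gaussian_kernel_fn sketched above and placeholder x_train, y_train, x_test arrays (diag_reg=1e-4 is an arbitrary regularization strength, not a recommended value):

```python
import neural_tangents as nt

# Precompute the Gram matrices with the custom kernel; gp_inference only
# needs the matrices themselves, not the function that produced them.
k_train_train = gaussian_kernel_fn(x_train, None)   # (n_train, n_train)
k_test_train = gaussian_kernel_fn(x_test, x_train)  # (n_test, n_train)
k_test_test = gaussian_kernel_fn(x_test, None)      # (n_test, n_test)

# diag_reg adds diagonal regularization to k_train_train for numerical
# stability when solving the linear system.
predict_fn = nt.predict.gp_inference(k_train_train, y_train, diag_reg=1e-4)

# With plain ndarray kernels, request the 'nngp' (posterior GP) quantities;
# passing the test-test kernel yields both the mean and the covariance.
posterior = predict_fn('nngp', k_test_train, k_test_test)
mean, covariance = posterior.mean, posterior.covariance
```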