snehjp2 opened 1 year ago
AFAIK all but one `nt.predict` function should work with any `kernel_fn`, since they accept kernels as inputs, and you can compute those kernels with any function you want.

`nt.predict.gradient_descent_mse_ensemble` does require `kernel_fn` to have a specific signature / return type; see the docs at https://neural-tangents.readthedocs.io/en/latest/_autosummary/neural_tangents.predict.gradient_descent_mse_ensemble.html#neural_tangents.predict.gradient_descent_mse_ensemble. In short, it must return a namedtuple with `ntk` and/or `nngp` attributes, or the `Kernel` dataclass object. It requires this specific structure because some settings (e.g. the covariance of infinite-width GD-trained network outputs) need both the NTK and the NNGP to be computed. For regular kernel regression it's probably easiest to use `nt.predict.gp_inference`. Lmk if this helps!
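To make the suggestion concrete, here is a minimal NumPy sketch of what the mean prediction of exact GP / kernel regression with a custom Gaussian (RBF) kernel looks like — the same quantity `nt.predict.gp_inference` computes from the kernel matrices you pass it. The function names, the `sigma` bandwidth, and the `diag_reg` regularizer are illustrative choices, not part of the neural-tangents API:

```python
import numpy as np

def gaussian_kernel(x1, x2, sigma=1.0):
    # Pairwise Gaussian (RBF) kernel between rows of x1 and x2:
    # k(a, b) = exp(-||a - b||^2 / (2 * sigma^2)).
    sq_dists = ((x1[:, None, :] - x2[None, :, :]) ** 2).sum(-1)
    return np.exp(-sq_dists / (2.0 * sigma ** 2))

def kernel_regression(x_train, y_train, x_test, sigma=1.0, diag_reg=1e-6):
    # Posterior mean of GP regression / regularized kernel regression:
    #   mean = K_test,train @ (K_train,train + diag_reg * I)^-1 @ y_train
    k_train_train = gaussian_kernel(x_train, x_train, sigma)
    k_test_train = gaussian_kernel(x_test, x_train, sigma)
    reg = diag_reg * np.eye(len(x_train))
    return k_test_train @ np.linalg.solve(k_train_train + reg, y_train)

x_train = np.array([[0.0], [1.0], [2.0]])
y_train = np.array([[0.0], [1.0], [4.0]])
x_test = np.array([[0.5], [1.5]])
preds = kernel_regression(x_train, y_train, x_test)
```

If you precompute `k_train_train` and `k_test_train` with your own `kernel_fn` like this, you can hand those arrays directly to `nt.predict.gp_inference` instead of a stax-built `kernel_fn`.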
Hello! I was wondering if it's possible to write your own `kernel_fn` in neural tangents to do regular kernel regression (for example, with a Gaussian kernel). Naively writing a new `kernel_fn` and compiling it with `jit` raises an error when calling `predict_fn`. If this is possible, I'm sure I'm missing something important in making the function compatible with the nt and stax backend.

edit: I see now that this may be more appropriate for Discussions. Can move this thread there if appropriate!