JaxGaussianProcesses / JaxKern

Kernel functions in JAX.
MIT License

bug: identical hyperparameter init preventing heterogeneous kernels within combinations #43

Open bodin-e opened 1 year ago

bodin-e commented 1 year ago

Bug Report

JaxKern version: 0.0.5

When a combination kernel (e.g. a sum or product of kernels) contains multiple kernels of the same type (e.g. RBF), with the intent of learning a separate lengthscale for each, their lengthscales remain identical after optimization.

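This happens because their hyperparameters are initialised to exactly the same values. The components are then interchangeable, so each one receives the same gradient at every optimization step and the symmetry is never broken.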
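The symmetry argument can be checked with a minimal plain-JAX sketch. It does not use JaxKern's API: `rbf`, `sum_kernel`, `gram` and `neg_log_marginal_likelihood` are hypothetical helpers, and the toy data, noise level, learning rate and log-space parameterisation are assumptions chosen for illustration. Two RBF components start from identical hyperparameters and are trained with plain gradient descent on a GP negative log marginal likelihood.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve

# Hypothetical RBF kernel (not JaxKern's API), parameterised in log space so
# that plain gradient descent keeps the lengthscale and variance positive.
def rbf(params, x, y):
    lengthscale = jnp.exp(params["log_lengthscale"])
    variance = jnp.exp(params["log_variance"])
    sq_dist = jnp.sum((x - y) ** 2)
    return variance * jnp.exp(-0.5 * sq_dist / lengthscale**2)

# Hypothetical sum combination: each component has its own parameter dict.
def sum_kernel(params_list, x, y):
    return sum(rbf(p, x, y) for p in params_list)

def gram(params_list, X):
    # Evaluate the combined kernel on every pair of rows of X.
    return jax.vmap(lambda a: jax.vmap(lambda b: sum_kernel(params_list, a, b))(X))(X)

def neg_log_marginal_likelihood(params_list, X, y, noise=0.1):
    # GP negative log marginal likelihood, dropping the constant term.
    K = gram(params_list, X) + noise * jnp.eye(X.shape[0])
    L = jnp.linalg.cholesky(K)
    alpha = cho_solve((L, True), y)
    return 0.5 * y @ alpha + jnp.sum(jnp.log(jnp.diag(L)))

# Toy 1-D regression data (assumed for illustration).
X = jnp.linspace(0.0, 1.0, 20)[:, None]
y = jnp.sin(6.0 * X[:, 0])

# Two RBF components initialised to exactly the same values (the reported bug).
params = [
    {"log_lengthscale": 0.0, "log_variance": 0.0},
    {"log_lengthscale": 0.0, "log_variance": 0.0},
]

grad_fn = jax.jit(jax.grad(neg_log_marginal_likelihood))
for _ in range(100):
    grads = grad_fn(params, X, y)
    params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)

# Both components received the same gradient at every step,
# so their lengthscales are still equal after optimization.
print(params[0]["log_lengthscale"], params[1]["log_lengthscale"])
```

Because both components enter the sum in the same way, the gradient with respect to each one's parameters is computed from identical inputs, so gradient descent alone cannot separate them.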

To resolve this, I initialise them with slightly different values to break the symmetry: see PR https://github.com/JaxGaussianProcesses/JaxKern/pull/42

I also allow each kernel in the combination kernel to receive its own PRNG key: see PR https://github.com/JaxGaussianProcesses/JaxKern/pull/41
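A sketch of this kind of fix in plain JAX follows. It is an assumption about the approach, not the code in PRs #41 and #42: the hypothetical `init_params` splits one key into a subkey per component and adds a small random jitter (scale `0.1`, chosen arbitrarily) to each log-lengthscale. The helpers `rbf`, `sum_kernel`, `gram` and `neg_log_marginal_likelihood` are hypothetical stand-ins, not JaxKern's API, and the toy data is assumed for illustration.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve

# Hypothetical RBF kernel with log-space hyperparameters (not JaxKern's API).
def rbf(params, x, y):
    lengthscale = jnp.exp(params["log_lengthscale"])
    variance = jnp.exp(params["log_variance"])
    return variance * jnp.exp(-0.5 * jnp.sum((x - y) ** 2) / lengthscale**2)

def sum_kernel(params_list, x, y):
    return sum(rbf(p, x, y) for p in params_list)

def gram(params_list, X):
    return jax.vmap(lambda a: jax.vmap(lambda b: sum_kernel(params_list, a, b))(X))(X)

def neg_log_marginal_likelihood(params_list, X, y, noise=0.1):
    K = gram(params_list, X) + noise * jnp.eye(X.shape[0])
    L = jnp.linalg.cholesky(K)
    alpha = cho_solve((L, True), y)
    return 0.5 * y @ alpha + jnp.sum(jnp.log(jnp.diag(L)))

def init_params(key, num_kernels, jitter=0.1):
    # Hypothetical initialiser: give each component its own subkey, then
    # perturb its log-lengthscale by a small random amount to break symmetry.
    keys = jax.random.split(key, num_kernels)
    return [
        {"log_lengthscale": jitter * jax.random.normal(k), "log_variance": 0.0}
        for k in keys
    ]

X = jnp.linspace(0.0, 1.0, 20)[:, None]
y = jnp.sin(6.0 * X[:, 0])

params = init_params(jax.random.PRNGKey(0), num_kernels=2)
grads = jax.grad(neg_log_marginal_likelihood)(params, X, y)

# The components now start apart and receive different gradients,
# so optimization is free to learn distinct lengthscales.
print(params[0]["log_lengthscale"], params[1]["log_lengthscale"])
print(grads[0]["log_lengthscale"], grads[1]["log_lengthscale"])
```

Deriving every component's subkey from a single seed keeps initialisation reproducible while still giving each kernel distinct starting values.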

Nice work on this library (and gpjax)! Looks great :)