Bug Report
JaxKern version: 0.0.5
When a combination kernel is built from multiple kernels of the same type (e.g. RBF) with the intent of learning a separate lengthscale for each, their lengthscales remain identical after optimization.
This is a symmetry problem: every kernel's hyperparameters are initialised to exactly the same values, so each kernel receives the same gradient and the optimizer updates them all identically.
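A minimal sketch of the symmetry in plain JAX, not using the JaxKern API. The names `rbf`, `sum_of_rbfs`, `loss` and `gradient_descent` are hypothetical, and the toy squared-error loss is an assumption standing in for a GP marginal likelihood. With tied initial values the two lengthscale gradients are equal, so plain gradient descent never separates them; a small offset gives different gradients.

```python
import jax
import jax.numpy as jnp

def rbf(x, y, log_lengthscale):
    # Squared-exponential kernel with unit variance; the lengthscale is
    # stored in log space so it stays positive during optimisation.
    lengthscale = jnp.exp(log_lengthscale)
    return jnp.exp(-0.5 * ((x - y) / lengthscale) ** 2)

def sum_of_rbfs(params, x, y):
    # Hypothetical combination kernel: a sum of two RBF kernels,
    # each with its own (log) lengthscale.
    return rbf(x, y, params[0]) + rbf(x, y, params[1])

def loss(params, xs):
    # Toy objective (assumption, stands in for a marginal likelihood):
    # squared distance between the Gram matrix and a fixed target matrix.
    gram = sum_of_rbfs(params, xs[:, None], xs[None, :])
    target = jnp.exp(-jnp.abs(xs[:, None] - xs[None, :]))
    return jnp.sum((gram - target) ** 2)

xs = jnp.linspace(0.0, 1.0, 10)
grad_fn = jax.grad(loss)

def gradient_descent(params, steps=50, lr=1e-3):
    # Plain gradient descent: identical gradients give identical updates.
    for _ in range(steps):
        params = params - lr * grad_fn(params, xs)
    return params

tied_init = jnp.array([0.0, 0.0])    # identical initial lengthscales
untied_init = jnp.array([0.0, 0.1])  # slightly perturbed initial lengthscales

print(grad_fn(tied_init, xs))        # both entries equal
print(grad_fn(untied_init, xs))      # entries differ
print(gradient_descent(tied_init))   # lengthscales still tied after training
```

Because the loss is invariant to swapping the two kernels, a tied starting point can only ever produce tied updates; the perturbation is what lets the optimizer move the lengthscales apart.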
To resolve this, I initialise them with slightly different values to break the symmetry: see PR https://github.com/JaxGaussianProcesses/JaxKern/pull/42
I also allow each kernel in the combination kernel to receive its own PRNG key: see PR https://github.com/JaxGaussianProcesses/JaxKern/pull/41
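A rough sketch of how the two fixes fit together, assuming a perturbed-initialisation scheme; it is not the code from either PR, and `init_rbf_params` and `init_combination_params` are hypothetical helpers rather than JaxKern functions. The parent key is split with `jax.random.split` so each child kernel draws its own perturbation; reusing one key for every child would reproduce the symmetric initialisation.

```python
import jax

def init_rbf_params(key, base_lengthscale=1.0, jitter=0.1):
    # Hypothetical per-kernel initialiser: perturb the default lengthscale
    # multiplicatively using this kernel's own PRNG key.
    noise = jax.random.uniform(key, minval=-jitter, maxval=jitter)
    return {"lengthscale": base_lengthscale * (1.0 + noise)}

def init_combination_params(key, n_kernels):
    # Split the parent key so every child kernel gets an independent key,
    # and therefore a different perturbation.
    keys = jax.random.split(key, n_kernels)
    return [init_rbf_params(k) for k in keys]

params = init_combination_params(jax.random.PRNGKey(42), n_kernels=3)
print([float(p["lengthscale"]) for p in params])  # three distinct values near 1.0
```

Splitting keys rather than passing the same key to every child is the usual JAX pattern: identical keys yield identical random draws, which would leave the kernels tied again.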
Nice work on this library (and gpjax)! Looks great :)