JaxGaussianProcesses / JaxKern

Kernel functions in JAX.
MIT License

Added symmetry-breaking init in rbf #42

Open · bodin-e opened 1 year ago

bodin-e commented 1 year ago

The init_params method currently initializes kernels such as the RBF deterministically. As a result, when a combination kernel is built from several kernels of the same type, those kernels are initialized identically. This is problematic: the symmetry between them forces them to stay identical under gradient-based local optimization, which prevents, e.g., a sum of kernels with different lengthscales from being learnt (unless the symmetry is broken in some other way).
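A minimal reproduction of the symmetry problem, as a sketch: it uses hypothetical plain functions (rbf_gram, nlml) rather than JaxKern's kernel classes, a 1-D toy dataset, a fixed noise of 0.1 and plain gradient descent, all of which are assumptions for illustration. Two RBF components start from identical values, and because the sum k1 + k2 treats them interchangeably, every gradient step updates them identically.

```python
import jax
import jax.numpy as jnp


def rbf_gram(X, lengthscale, variance):
    # RBF (squared-exponential) Gram matrix for 1-D inputs X of shape (n,).
    diff = X[:, None] - X[None, :]
    return variance * jnp.exp(-0.5 * (diff / lengthscale) ** 2)


def nlml(params, X, y, noise=0.1):
    # GP negative log marginal likelihood (up to a constant) for the kernel k1 + k2.
    K = sum(rbf_gram(X, p["lengthscale"], p["variance"]) for p in params)
    K = K + noise * jnp.eye(X.shape[0])
    _, logdet = jnp.linalg.slogdet(K)
    return 0.5 * y @ jnp.linalg.solve(K, y) + 0.5 * logdet


# Deterministic initialisation: both RBF components start from identical values.
params = [
    {"lengthscale": jnp.array(1.0), "variance": jnp.array(1.0)},
    {"lengthscale": jnp.array(1.0), "variance": jnp.array(1.0)},
]
X = jnp.linspace(0.0, 1.0, 20)
y = jnp.sin(6.0 * X)

grad_fn = jax.jit(jax.grad(nlml))
for _ in range(10):
    grads = grad_fn(params, X, y)
    # Plain gradient descent on every parameter of both components.
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)

# Both components received the same gradients at every step, so they are still equal.
print(params[0]["lengthscale"], params[1]["lengthscale"])
```

The lengthscales move away from 1.0 together, so the optimiser can never discover two different lengthscales from this starting point.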

This PR adds a small random jitter to the RBF kernel's lengthscale to break the symmetry. Resolving the same issue for the other kernels would require a similar change in each of them.
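A sketch of what a jittered RBF initialisation could look like, assuming a hypothetical function init_rbf_params and a jitter scale eps=1e-3 (the PR's actual scale is not stated in this thread). The same key always reproduces the same parameters, while different keys give lengthscales that are close to 1 but not equal.

```python
import jax
import jax.numpy as jnp


def init_rbf_params(key, num_dims=1, eps=1e-3):
    # Hypothetical jittered initialiser: lengthscale = 1 + eps * N(0, 1) per dimension.
    jitter = eps * jax.random.normal(key, (num_dims,))
    return {"lengthscale": jnp.ones(num_dims) + jitter, "variance": jnp.array(1.0)}


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
params_a = init_rbf_params(key_a, num_dims=2)
params_b = init_rbf_params(key_b, num_dims=2)
print(params_a["lengthscale"], params_b["lengthscale"])  # both near 1.0, not identical
```

Because two components initialised this way start at different lengthscales, their gradients are no longer forced to coincide.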

The following PR passes a unique key to each kernel within the combination kernel: https://github.com/JaxGaussianProcesses/JaxKern/pull/41
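Jitter alone is not enough if every component receives the same key. A sketch of the key handling, assuming hypothetical functions init_rbf_params and init_combination_params rather than the code in the linked PR: splitting one key into a subkey per component is what makes the jitter differ between components.

```python
import jax
import jax.numpy as jnp


def init_rbf_params(key, num_dims=1, eps=1e-3):
    # Hypothetical jittered RBF initialiser: lengthscale = 1 + eps * N(0, 1).
    jitter = eps * jax.random.normal(key, (num_dims,))
    return {"lengthscale": jnp.ones(num_dims) + jitter, "variance": jnp.array(1.0)}


def init_combination_params(key, component_inits):
    # Split the key so every component gets its own, distinct subkey.
    keys = jax.random.split(key, len(component_inits))
    return [init(k) for init, k in zip(component_inits, keys)]


key = jax.random.PRNGKey(42)
components = [init_rbf_params, init_rbf_params]

# Reusing one key for every component reproduces the symmetry despite the jitter...
shared = [init(key) for init in components]
# ...whereas one subkey per component yields distinct lengthscales.
split = init_combination_params(key, components)
```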

thomaspinder commented 1 year ago

Hi @bodin-e, thanks for the PR. This is a good catch that hadn't occurred to me. I wonder if it only makes sense to do this when initialising the lengthscale(s) for a combination kernel? If so, then I'd prefer this change to be moved upstream into the init_params() of the Combination kernel class.

This could be handled through a simple if 'lengthscale' in params.keys() check, which would propagate the stochastic initialisation to any kernel that is parameterised by a lengthscale, not just the RBF.
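A sketch of this centralised variant, assuming hypothetical plain functions in place of the Combination class: components keep their deterministic defaults, and the combination's init adds jitter only when a component has a "lengthscale" entry. The hypothetical linear kernel shows what happens to a component without one.

```python
import jax
import jax.numpy as jnp


def init_rbf_params(key, num_dims=1):
    # Deterministic default initialisation; the key is accepted but unused.
    return {"lengthscale": jnp.ones(num_dims), "variance": jnp.array(1.0)}


def init_linear_params(key, num_dims=1):
    # Hypothetical kernel that has no lengthscale at all.
    return {"variance": jnp.array(1.0)}


def init_combination_params(key, component_inits, eps=1e-3):
    params = []
    for init, k in zip(component_inits, jax.random.split(key, len(component_inits))):
        init_key, jitter_key = jax.random.split(k)
        p = dict(init(init_key))
        # Jitter is added centrally, but only to kernels that have a lengthscale.
        if "lengthscale" in p.keys():
            shape = p["lengthscale"].shape
            p["lengthscale"] = p["lengthscale"] + eps * jax.random.normal(jitter_key, shape)
        params.append(p)
    return params


params = init_combination_params(
    jax.random.PRNGKey(0),
    [init_rbf_params, init_rbf_params, init_linear_params, init_linear_params],
)
```

The two RBF components now differ, but the two linear components remain identical, so kernels without a lengthscale are still left symmetric.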

bodin-e commented 1 year ago

Quoting @thomaspinder: "I wonder if it only makes sense to do this when initialising the lengthscale(s) for a combination kernel? If so, then I'd prefer this change to be moved upstream into the init_params() of the Combination kernel class."

Yes, I would expect it to be needed only when using a combination kernel. It is needed whenever two (or more) kernels are identical, receive identical inputs, and enter the model in an identical way. In practice, that is probably only the combination-kernel case (at least, I cannot think of another one right now).

However, this is not specific to the lengthscale parameter. Breaking the symmetry only requires initializing some parameter differently in each kernel, for example the lengthscale, but not every kernel used within a combination kernel has a lengthscale. I also think it could become cumbersome, although not impossible, to maintain a set of parameter names {"lengthscale", ...} that covers at least one parameter of each of the library's kernels.
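A sketch of the name-registry approach, to make the maintenance burden concrete. The registry JITTERABLE_NAMES, its contents and the helper jitter_one_parameter are all hypothetical. Any kernel whose parameters are missing from the registry either stays symmetric silently or, as here, has to raise so that someone updates the list.

```python
import jax
import jax.numpy as jnp

# Hypothetical registry: it must name at least one parameter of every library kernel.
JITTERABLE_NAMES = ("lengthscale", "period")


def jitter_one_parameter(key, params, names=JITTERABLE_NAMES, eps=1e-3):
    # Jitter the first registered parameter that this kernel actually has.
    for name in names:
        if name in params:
            value = params[name]
            noisy = value + eps * jax.random.normal(key, jnp.shape(value))
            return {**params, name: noisy}
    # No match: the kernel would silently stay symmetric, so fail loudly instead.
    raise KeyError(f"no jitterable parameter among {sorted(params)}")


rbf = {"lengthscale": jnp.ones(2), "variance": jnp.array(1.0)}
linear = {"variance": jnp.array(1.0)}  # hypothetical kernel missing from the registry

jittered = jitter_one_parameter(jax.random.PRNGKey(0), rbf)
try:
    jitter_one_parameter(jax.random.PRNGKey(1), linear)
except KeyError as err:
    print("registry needs updating:", err)
```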

I completely agree that we need a nicer and more maintainable solution than the one in this PR, though. My main purpose with this PR was to highlight the issue (it would have been better to include an example in an issue ticket instead).

Still, I think the easiest to maintain and most readable option may be to let the init_params of each kernel initialize itself sensibly, including a stochastic component. The jitter would be added to whichever parameter makes sense for the given kernel, with a jitter scale suited to that particular parameter.

For example, in the RBF case we (and any user of the library) can easily read within the function scope that the default lengthscale is 1, and we know that a small jitter (say 1e-3 or 1e-4) should not be semantically meaningful for the RBF lengthscale relative to a value of 1. In contrast, in some kernels or setups, some parameters can be much more sensitive to their values. A user of such a kernel could easily see how it is initialised, and could then override the init_params behaviour or initialise it some other way.
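To make the "not semantically meaningful" claim concrete, a first-order estimate, assuming the standard squared-exponential form with lengthscale ℓ and variance σ² (JaxKern's exact parameterisation is not shown in this thread). The relative change in a kernel value caused by a lengthscale jitter Δℓ is:

```latex
\begin{align*}
k_\ell(r) &= \sigma^2 \exp\!\left(-\frac{r^2}{2\ell^2}\right), \qquad r = \lVert x - x' \rVert,\\
\frac{\partial \log k_\ell(r)}{\partial \ell} &= \frac{r^2}{\ell^3},\\
\frac{\Delta k}{k} &\approx \frac{r^2}{\ell^2}\,\frac{\Delta \ell}{\ell}
  = 10^{-3} \quad \text{for } \ell = 1,\ \Delta\ell = 10^{-3},\ r = 1.
\end{align*}
```

So a jitter of 1e-3 perturbs kernel values by roughly 0.1% for points one lengthscale apart. The same absolute jitter added to a parameter whose typical value is of order 1e-3 would be a change of order 100%, which is why the jitter scale needs to be chosen per parameter.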

What do you think of the following? It does essentially what you propose, namely reusing the initialization logic for the lengthscale (and for other parameters where needed), but keeps this logic in separate functions that are called explicitly where needed. That keeps readability high and reduces the risk of users missing what is happening (and getting unexpected results because of it).

Specifically, keep a file or module of parameter initialization functions somewhere, with functions such as def initialize_lengthscale(key, num_dims, initial_value=1.0, eps=1e-4):, which takes care of both the current initialization logic for that parameter (e.g. jnp.array([1.0] * num_dims)) and the added stochastic jitter. These parameter initialization functions would then be used within each kernel's init_params where appropriate, such as in the Matern, rational quadratic and RBF kernels, with similar functions for the other parameters and kernels.

Although it would only be necessary to add stochastic jitter to one parameter in each kernel, adding it to more than one wouldn't hurt either. And if some kernel has a parameter for which it is important that no jitter is added, its init_params can simply omit it.
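A sketch of the proposed helper module, using the signature suggested above for initialize_lengthscale. The other names (initialize_variance, rbf_init_params, rational_quadratic_init_params) and the parameter dictionaries are hypothetical illustrations, not JaxKern's actual code. Each kernel calls the helpers explicitly, and a parameter that should stay deterministic (here "alpha") is simply initialised without them.

```python
import jax
import jax.numpy as jnp


def initialize_lengthscale(key, num_dims, initial_value=1.0, eps=1e-4):
    # Current deterministic default (a vector of initial_value) plus small Gaussian jitter.
    base = jnp.full((num_dims,), initial_value)
    return base + eps * jax.random.normal(key, (num_dims,))


def initialize_variance(initial_value=1.0):
    # Deliberately deterministic: jittering one parameter per kernel is enough.
    return jnp.array(initial_value)


def rbf_init_params(key, num_dims):
    # Hypothetical RBF init_params built from the shared helpers.
    return {
        "lengthscale": initialize_lengthscale(key, num_dims),
        "variance": initialize_variance(),
    }


def rational_quadratic_init_params(key, num_dims):
    # Hypothetical rational quadratic init_params; "alpha" is left unjittered.
    return {
        "lengthscale": initialize_lengthscale(key, num_dims),
        "variance": initialize_variance(),
        "alpha": jnp.array(1.0),
    }


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
rbf_a = rbf_init_params(k1, num_dims=3)
rbf_b = rbf_init_params(k2, num_dims=3)
rq = rational_quadratic_init_params(k3, num_dims=3)
```

A reader of rbf_init_params sees at a glance that the lengthscale defaults to 1 with a small jitter, and a kernel that needs different behaviour can pass its own initial_value or eps, or skip the helper entirely.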