A fast, pure Python implementation of the MuyGPs Gaussian process realization and training algorithm.

Need hierarchical nonstationary kernel implemented into `MuyGPyS` #71

Open bwpriest opened 1 year ago

bwpriest commented 1 year ago

We need to add Amanda and Imène's hierarchical nonstationary kernel into the library and integrate it into all workflows and tests. It will most likely need to be reworked so that it meshes with the existing framework.

bwpriest commented 1 year ago

We'll need to implement a hierarchical kernel class. Some of its functions will look like this:

import numpy as np
from scipy.stats.qmc import LatinHypercube


class HierarchicalRBF(KernelFn):
    # signature reconstructed from the names used in the body;
    # knot_count is an assumed argument
    def __init__(
        self, feature_count, knot_count, length_scales, metric, **meta_gp_kwargs
    ):
        # newer SciPy releases spell centered=True as scramble=False
        self.knot_locs = LatinHypercube(
            feature_count, centered=True
        ).random(knot_count)
        self.length_scales = [
            _init_hyperparameter(1.0, "fixed", **ls)
            for ls in length_scales
        ]
        for i, ls in enumerate(self.length_scales):
            self.hyperparameters["length_scale_" + str(i)] = ls
        self.metric = metric
        # this bit is notional atm
        self.metagp = MetaGP(**meta_gp_kwargs)

    def __call__(self, squared_dists, xs):
        return self._fn(
            squared_dists,
            xs,
            length_scales=[ls() for ls in self.length_scales],
        )

    def _fn(self, squared_dists, xs, length_scales):
        batch_count = squared_dists.shape[0]
        nn_count = squared_dists.shape[1]
        # get a (batch_count,) list of parameter estimates for each
        # of the xs
        predicted_length_scales = krig_params(
            xs, self.metagp, self.knot_locs, length_scales
        )
        # this will need to be more clever, since we need to support both
        # 2- and 3-D tensors
        ret = np.zeros((batch_count, nn_count, nn_count))
        for i in range(batch_count):
            # write into ret; rebinding the loop variable would discard the result
            ret[i] = _rbf_fn(
                squared_dists[i, :, :],
                length_scale=predicted_length_scales[i],
            )
        return ret

We will also need the accounting code in get_*_opt_fn and get_optim_params.

bwpriest commented 1 year ago

The discussion in #88 is relevant to this topic as well.

bwpriest commented 1 year ago

We will need infrastructure similar to MuyGPyS.gp.make_noise_tensor() that takes train_features (or test_features) and batch_indices and returns a (batch_count, feature_count) tensor batch_features containing the feature vectors for each batch element. We will also need a HierarchicalNonstationaryHyperparameter class that accepts a (knot_count, feature_count) knot feature matrix (we'll need code to sample these), initial guesses for the hyperparameter values at each of those knots, and hyperparameters of the lower-level GP.
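A minimal NumPy sketch of the two helpers described above. The names make_batch_features and sample_knot_features are hypothetical and not part of MuyGPyS, and the uniform sampler is only a stand-in for whatever knot-sampling scheme we settle on.

```python
import numpy as np


def make_batch_features(features: np.ndarray, batch_indices: np.ndarray) -> np.ndarray:
    """Gather the feature vector of every batch element.

    features:      (data_count, feature_count) train or test features
    batch_indices: (batch_count,) integer row indices into `features`
    returns:       (batch_count, feature_count)
    """
    return features[batch_indices]


def sample_knot_features(knot_count: int, feature_count: int, seed: int = 0) -> np.ndarray:
    """Draw knots uniformly from the unit hypercube (stand-in sampler)."""
    rng = np.random.default_rng(seed)
    return rng.uniform(size=(knot_count, feature_count))


train_features = np.arange(12.0).reshape(6, 2)
batch_indices = np.array([4, 0, 2])
batch_features = make_batch_features(train_features, batch_indices)
knot_features = sample_knot_features(knot_count=5, feature_count=2)
```

Fancy indexing copies the selected rows, so batch_features can be modified without touching train_features.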

I think that we should start with the lower-level GP as a simple RBF kernel. I think that we will not expose its parameters to optimization for the initial prototype. The __init__ function will need to look something like

def __init__(self, knot_features, knot_values, **kernel_kwargs):
    self.kernel = RBF(**kernel_kwargs)
    self.knot_features = knot_features
    self.knot_values = knot_values
    # kernel among the knots, and the solve that every __call__ reuses
    self.lower_K = self.kernel(pairwise_distances(self.knot_features))
    self.solve = mm.linalg.solve(self.lower_K, self.knot_values)

I am also not sure whether we will want a separate HierarchicalNonstationaryHyperparameter instance for each hyperparameter, or one to manage all of the free hyperparameters. We'll probably want to start with one that just manipulates the length_scale for RBF for simplicity. So I'll assume that we are only learning one hyperparameter in the rest of this discussion.

The HierarchicalNonstationaryHyperparameter class will need a __call__ function that looks something like

def __call__(self, batch_features) -> mm.ndarray:
    lower_Kcross = self.kernel(crosswise_distances(batch_features, self.knot_features))
    return lower_Kcross @ self.solve

This function returns a (batch_count, 1) vector of hyperparameter values.
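Here is a self-contained NumPy sketch of the constructor and __call__ described above, to check the shapes and the interpolation behavior. The class name HierarchicalParamSketch, the helper functions, and the fixed length_scale of the lower-level RBF are all assumptions for illustration; the real class would use the library's RBF and distance functions.

```python
import numpy as np


def rbf(squared_dists, length_scale=1.0):
    """RBF kernel on squared distances (assumed form: exp(-d^2 / (2 l^2)))."""
    return np.exp(-squared_dists / (2.0 * length_scale**2))


def squared_distances(a, b):
    """(n, m) squared Euclidean distances between rows of a and rows of b."""
    diffs = a[:, None, :] - b[None, :, :]
    return np.sum(diffs**2, axis=-1)


class HierarchicalParamSketch:
    def __init__(self, knot_features, knot_values, length_scale=1.0):
        self.knot_features = np.asarray(knot_features, dtype=float)
        # store the knot values as a (knot_count, 1) column
        self.knot_values = np.asarray(knot_values, dtype=float).reshape(-1, 1)
        self.length_scale = length_scale
        lower_K = rbf(
            squared_distances(self.knot_features, self.knot_features),
            length_scale,
        )
        # computed once and reused by every call
        self.solve = np.linalg.solve(lower_K, self.knot_values)

    def __call__(self, batch_features):
        lower_Kcross = rbf(
            squared_distances(batch_features, self.knot_features),
            self.length_scale,
        )
        return lower_Kcross @ self.solve  # (batch_count, 1)


knot_features = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
knot_values = np.array([0.5, 1.0, 1.5, 2.0])
param = HierarchicalParamSketch(knot_features, knot_values, length_scale=0.5)
at_knots = param(knot_features)
```

Evaluating at the knot locations reproduces the knot values (up to floating point error), since the cross-kernel then equals the knot kernel. Nothing in this sketch keeps the interpolated values positive, which a length scale would need.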

We will need to modify MuyGPyS._src.gp.kernels._rbf_fn so that it can accept a (batch_count, 1) length_scale parameter in addition to the (batch_count, ...) squared_dists parameter and apply each length scale to the corresponding block of the squared_dists tensor. There is probably an elegant way to do this but it will require a little investigation.

We'll probably want a function modifier that then collapses this new form into the scalar form that we are currently using for simplicity.
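One possible way to do both is NumPy broadcasting: reshape the per-batch length scales so that they line up with the leading batch axis, and let scalars pass through unchanged. This is only a sketch; rbf_batched is a hypothetical name, and the kernel form exp(-d^2 / (2 l^2)) is assumed rather than copied from _rbf_fn.

```python
import numpy as np


def rbf_batched(squared_dists, length_scale):
    """RBF kernel accepting a scalar or a per-batch length scale.

    squared_dists: (batch_count, ...) tensor, 2-D or 3-D
    length_scale:  scalar, (batch_count,), or (batch_count, 1)
    """
    squared_dists = np.asarray(squared_dists, dtype=float)
    ls = np.asarray(length_scale, dtype=float)
    if ls.ndim > 0:
        # reshape to (batch_count, 1, ..., 1) so that each length scale
        # broadcasts over its own block of squared_dists
        trailing = (1,) * (squared_dists.ndim - 1)
        ls = ls.reshape((squared_dists.shape[0],) + trailing)
    return np.exp(-squared_dists / (2.0 * ls**2))


rng = np.random.default_rng(0)
pairwise = rng.uniform(size=(3, 4, 4))  # (batch_count, nn_count, nn_count)
crosswise = rng.uniform(size=(3, 4))  # (batch_count, nn_count)
length_scales = np.array([[0.5], [1.0], [2.0]])  # (batch_count, 1)

K = rbf_batched(pairwise, length_scales)
Kcross = rbf_batched(crosswise, length_scales)
K_scalar = rbf_batched(pairwise, 1.0)  # the existing scalar behavior
```

The same function covers both the 2-D crosswise and 3-D pairwise tensors, so no explicit loop over the batch is needed.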

If all of that is in place, we should be able to modify the RBF.__call__ function to

def __call__(self, diffs, batch_features=None) -> mm.ndarray:
    return self._fn(diffs, length_scale=self.length_scale(batch_features))

(Here we will need to have modified the Hyperparameter.__call__ signature to include **kwargs.) We will also have to modify several of our __init__ functions to support the creation of HierarchicalNonstationaryHyperparameter instances. If I am correct, all of this infrastructure will allow us to create nonstationary kernels using our current infrastructure. One of the nice features here is that we get nonstationary + isotropic and nonstationary + anisotropic distortions "for free", since the nonstationarity is relegated to the hyperparameter object and is independent of the distortion model.

Instrumenting the knot hyperparameter values for optimization is another issue that will probably require some refactoring.

bwpriest commented 1 year ago

> We will need to modify MuyGPyS._src.gp.kernels._rbf_fn so that it can accept a (batch_count, 1) length_scale parameter in addition to the (batch_count, ...) squared_dists parameter and apply each length scale to the corresponding block of the squared_dists tensor.

This is out of date. length_scale is now handled by the DistortionModel, which is very different from the workflow at the time that I wrote this comment.

We need to modify apply_distortion and _optional_invoke_param in a manner similar to the following:

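from typing import Callable, Optional, Union

def _optional_invoke_param(
    param: Union[ScalarHyperparameter, HierarchicalNonstationaryHyperparameter, float],
    batch_features: Optional[mm.ndarray] = None,
) -> Union[float, mm.ndarray]:
    if isinstance(param, ScalarHyperparameter):
        return param()
    if isinstance(param, HierarchicalNonstationaryHyperparameter):
        return param(batch_features)
    return param

def apply_distortion(distortion_fn: Callable, **length_scales):
    def distortion_applier(fn: Callable):
        def distorted_fn(diffs, *args, batch_features=None, **kwargs):
            # assumed split: length_scale* kwargs go to the distortion,
            # everything else goes to the wrapped function
            inner_kwargs = {
                k: v for k, v in kwargs.items() if k.startswith("length_scale")
            }
            outer_kwargs = {
                k: v for k, v in kwargs.items() if not k.startswith("length_scale")
            }
            for ls in length_scales:
                inner_kwargs.setdefault(
                    ls,
                    _optional_invoke_param(
                        length_scales[ls], batch_features=batch_features
                    ),
                )
            return fn(
                distortion_fn(diffs, **inner_kwargs), *args, **outer_kwargs
            )
        return distorted_fn
    return distortion_applier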

These changes will allow the returned function to accept an additional kwarg, batch_features. We'll then need to modify RBF.__call__ so that it can accept this additional kwarg and pass it to its RBF._fn, which is the output of apply_distortion. I think that will allow us to solve this problem for isotropic kernels.
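To see the pattern working end to end without the library, here is a self-contained toy version. ScalarParam, PerBatchParam, scaled_l2, and exp_kernel are hypothetical stand-ins for the real hyperparameter classes, distortion function, and kernel; only the decorator structure mirrors the proposal.

```python
import numpy as np


class ScalarParam:
    """Stand-in for a scalar hyperparameter."""

    def __init__(self, value):
        self._value = value

    def __call__(self):
        return self._value


class PerBatchParam:
    """Stand-in for a hierarchical hyperparameter (toy rule, not kriging)."""

    def __call__(self, batch_features):
        return 1.0 + batch_features.sum(axis=1)  # (batch_count,)


def optional_invoke_param(param, batch_features=None):
    if isinstance(param, ScalarParam):
        return param()
    if isinstance(param, PerBatchParam):
        return param(batch_features)
    return param


def apply_distortion(distortion_fn, **length_scales):
    def distortion_applier(fn):
        def distorted_fn(diffs, *args, batch_features=None, **kwargs):
            inner = {k: v for k, v in kwargs.items() if k.startswith("length_scale")}
            outer = {k: v for k, v in kwargs.items() if not k.startswith("length_scale")}
            for name, param in length_scales.items():
                inner.setdefault(name, optional_invoke_param(param, batch_features))
            return fn(distortion_fn(diffs, **inner), *args, **outer)

        return distorted_fn

    return distortion_applier


def scaled_l2(diffs, length_scale):
    """Squared distances over the feature axis, scaled per batch element."""
    ls = np.asarray(length_scale, dtype=float)
    if ls.ndim > 0:
        ls = ls.reshape(-1, 1)
    return np.sum(diffs**2, axis=-1) / ls**2


def exp_kernel(dists):
    return np.exp(-dists / 2.0)


diffs = np.ones((2, 3, 2))  # (batch_count, nn_count, feature_count)
batch_features = np.array([[0.0, 0.0], [0.5, 0.5]])

stationary = apply_distortion(scaled_l2, length_scale=ScalarParam(2.0))(exp_kernel)
nonstationary = apply_distortion(scaled_l2, length_scale=PerBatchParam())(exp_kernel)

K_stat = stationary(diffs)
K_nonstat = nonstationary(diffs, batch_features=batch_features)
```

The stationary call never looks at batch_features, so existing call sites keep working; only kernels built with a hierarchical parameter need the extra kwarg.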

> Instrumenting the knot hyperparameter values for optimization is another issue that will probably require some refactoring.

This should be easier now. We just need to add get_optim_params and get_opt_fn methods for HierarchicalNonstationaryHyperparameter that get called by the functions of the same name in IsotropicDistortion, which will need to be modified to handle both the scalar (current behavior) case and this new case. We'll probably also need similar functions to apply_scalar_hyperparameter and append_scalar_optim_params_list. That will most likely get complicated.

bwpriest commented 1 year ago

Here's the general outline of how to add get_optim_params to the hierarchical parameter. The trick is to treat the knot_values as a Dict[str, ScalarHyperparameter], similar to how AnisotropicDistortion handles its length_scale. This dict should probably just have the keys "0", "1", et cetera (we'll see why below). So first, we need to modify the constructor so that it expects the knot_values to be this dict, and throws errors otherwise (see AnisotropicDistortion). We'll have to do something very similar to:


to convert this dict to a vector for use in the solve. We will not need to include the batch_features part here though, as the knot_values are guaranteed to be scalar hyperparameters.

Once this is done, we need HierarchicalNonstationaryHyperparameter.append_lists to add each free knot parameter to the opt lists. This will look something like

    def append_lists(
        self,
        name: str,
        names: List[str],
        params: List[float],
        bounds: List[Tuple[float, float]],
    ):
        # delegate to each knot's scalar hyperparameter under a unique name
        for index, param in self.knot_values.items():
            param.append_lists(name + "_knot" + index, names, params, bounds)

This will add hyperparameter entries for each knot value. For a hierarchical length_scale, the entries will have names like length_scale_knot0, length_scale_knot1, et cetera.
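A runnable toy version of this accounting, covering both the dict-to-vector conversion and append_lists. ScalarParam is a hypothetical stand-in for ScalarHyperparameter, and I am assuming that fixed parameters are skipped when building the optimization lists.

```python
from typing import Dict, List, Tuple, Union


class ScalarParam:
    """Stand-in for a scalar hyperparameter with optional bounds."""

    def __init__(self, value: float, bounds: Union[str, Tuple[float, float]]):
        self.value = value
        self.bounds = bounds  # "fixed", or a (lower, upper) pair

    def __call__(self) -> float:
        return self.value

    def append_lists(self, name, names, params, bounds) -> None:
        # assumption: only free (non-fixed) parameters reach the optimizer
        if self.bounds != "fixed":
            names.append(name)
            params.append(self.value)
            bounds.append(self.bounds)


class HierarchicalParamSketch:
    def __init__(self, knot_values: Dict[str, ScalarParam]):
        if not isinstance(knot_values, dict):
            raise ValueError("knot_values must be a dict of scalar parameters")
        if set(knot_values) != {str(i) for i in range(len(knot_values))}:
            raise ValueError('knot_values keys must be "0", "1", ...')
        self.knot_values = knot_values

    def knot_vector(self) -> List[float]:
        """Evaluate the knots in index order, for use in the solve."""
        return [self.knot_values[str(i)]() for i in range(len(self.knot_values))]

    def append_lists(self, name, names, params, bounds) -> None:
        for index, param in self.knot_values.items():
            param.append_lists(name + "_knot" + index, names, params, bounds)


length_scale = HierarchicalParamSketch(
    {
        "0": ScalarParam(1.0, (0.1, 5.0)),
        "1": ScalarParam(2.0, "fixed"),
        "2": ScalarParam(0.5, (0.1, 5.0)),
    }
)
names: List[str] = []
params: List[float] = []
bounds: List[Tuple[float, float]] = []
length_scale.append_lists("length_scale", names, params, bounds)
```

With knot 1 fixed, only knots 0 and 2 are exposed to the optimizer, while knot_vector still returns all three values for the solve.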

bwpriest commented 1 year ago

We will also need to modify how the __call__ method works, because it must be able to accept new values for the knots. This can be achieved by adding **kwargs to the signature, checking the kwargs dict for any key that ends with _knot#, and updating the corresponding knot value. I don't think that this will work if we are optimizing multiple hierarchical hyperparameters at once, so we will need to make that more general in the future. We will also need to recompute self.solve if any of the knot values get updated.
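A sketch of that __call__ in plain NumPy. UpdatableKnotParam and its name argument are hypothetical. Matching keys on the hyperparameter name prefix is one possible way to keep several hierarchical hyperparameters apart, not something we have settled on.

```python
import re

import numpy as np


def _rbf(a, b, length_scale):
    sq = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq / (2.0 * length_scale**2))


class UpdatableKnotParam:
    def __init__(self, name, knot_features, knot_values, length_scale=1.0):
        self.name = name
        self.knot_features = np.asarray(knot_features, dtype=float)
        self.knot_values = np.array(knot_values, dtype=float)
        self.length_scale = length_scale
        self.lower_K = _rbf(self.knot_features, self.knot_features, length_scale)
        # matches e.g. "length_scale_knot3" and captures the knot index
        self._pattern = re.compile(re.escape(name) + r"_knot(\d+)")
        self._refresh_solve()

    def _refresh_solve(self):
        self.solve = np.linalg.solve(self.lower_K, self.knot_values)

    def __call__(self, batch_features, **kwargs):
        updated = False
        for key, value in kwargs.items():
            match = self._pattern.fullmatch(key)
            if match is not None:
                self.knot_values[int(match.group(1))] = value
                updated = True
        if updated:
            # the cached solve is stale once any knot value changes
            self._refresh_solve()
        lower_Kcross = _rbf(batch_features, self.knot_features, self.length_scale)
        return lower_Kcross @ self.solve


knots = np.array([[0.0], [1.0], [2.0]])
param = UpdatableKnotParam("length_scale", knots, [1.0, 1.0, 1.0], length_scale=0.5)
before = param(knots)
after = param(knots, length_scale_knot1=3.0, noise_knot0=9.0)
```

The noise_knot0 kwarg is ignored because it does not carry this parameter's name, so only knot 1 changes between the two calls.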

bwpriest commented 1 year ago

KernelFn has a _hyperparameters member, to which all kernel parameters get added. We currently construct a new muygps object at the end of optimization with




That construction assumes that each hyperparameter has a scalar value, so we will need to check whether it is hierarchical and, if so, do something like

loc = key.find("_knot")
name = key[:loc]
index = key[loc + len("_knot"):]
param = ret.kernel._hyperparameters[name]
param.set_knot(index, val)

and we'll need to add the appropriate set_knot method to HierarchicalNonstationaryHyperparameter.
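A small sketch of the key handling and a possible set_knot. split_knot_key, KnotHolder, and apply_optimized are hypothetical names, and a plain dict stands in for the kernel's _hyperparameters member.

```python
def split_knot_key(key):
    """Split "length_scale_knot3" into ("length_scale", "3").

    Keys without a "_knot" suffix are returned as (key, None).
    """
    loc = key.find("_knot")
    if loc < 0:
        return key, None
    return key[:loc], key[loc + len("_knot"):]


class KnotHolder:
    """Stand-in for the hierarchical hyperparameter's knot storage."""

    def __init__(self, knot_values):
        self.knot_values = dict(knot_values)

    def set_knot(self, index, value):
        if index not in self.knot_values:
            raise KeyError(f"unknown knot index {index!r}")
        self.knot_values[index] = value


def apply_optimized(hyperparameters, optimized):
    """Write optimizer outputs back, routing knot keys to set_knot."""
    for key, val in optimized.items():
        name, index = split_knot_key(key)
        if index is None:
            hyperparameters[name] = val  # scalar case (current behavior)
        else:
            hyperparameters[name].set_knot(index, val)


hyperparameters = {
    "length_scale": KnotHolder({"0": 1.0, "1": 1.0}),
    "nu": 0.5,
}
apply_optimized(hyperparameters, {"length_scale_knot1": 2.5, "nu": 1.5})
```

A real set_knot would also need to refresh the cached solve after changing a knot value.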

bwpriest commented 1 year ago

I'm less immediately certain of what to do for get_opt_fn outside of what we have already discussed. We can revisit it once you've worked out the above. @igoumiri

bwpriest commented 1 year ago

I merged PR #154. It is a step in the right direction, although the obj_fn that it produces seems to not be sensitive to length_scale_knot# kwargs (see the tutorial notebook). The next step is to fix this, as the kernel tensors should be sensitive, which should also affect the outputs of obj_fn.

igoumiri commented 1 year ago

I fixed it in https://github.com/LLNL/MuyGPyS/pull/156 but some issues remain.