Open bwpriest opened 1 year ago
We'll need to implement a hierarchical kernel class. Some of its functions will look like this:
```python
class HierarchicalRBF(KernelFn):
    def __init__(
        self,
        length_scales,
        metric,
        feature_count,
        **meta_gp_kwargs,
    ):
        super().__init__()
        self.knot_locs = LatinHypercube(
            feature_count, centered=True
        ).random(len(length_scales))
        self.length_scales = [
            _init_hyperparameter(1.0, "fixed", **ls)
            for ls in length_scales
        ]
        for i, ls in enumerate(self.length_scales):
            self.hyperparameters["length_scale_" + str(i)] = ls
        self.metric = metric
        # this bit is notional atm
        self.metagp = MetaGP(**meta_gp_kwargs)

    def __call__(self, squared_dists, xs):
        return self._fn(
            squared_dists,
            xs,
            length_scales=[ls() for ls in self.length_scales],
        )

    def _fn(self, squared_dists, xs, length_scales):
        batch_count = squared_dists.shape[0]
        nn_count = squared_dists.shape[1]
        # get a (batch_count,) list of parameter estimates for each
        # of the xs
        predicted_length_scales = krig_params(
            xs, self.metagp, self.knot_locs, self.length_scales
        )
        # this will need to be more clever, since we need to support both
        # 2- and 3-D tensors
        ret = np.zeros((batch_count, nn_count, nn_count))
        for i in range(batch_count):
            ret[i, :, :] = _rbf_fn(
                squared_dists[i, :, :],
                predicted_length_scales[i],
            )
        return ret
```
We will also need the accounting code in `get_*_opt_fn` and `get_optim_params`.

This discussion in #88 is relevant to this topic as well.
@igoumiri
We will need infrastructure similar to `MuyGPyS.gp.make_noise_tensor()` that takes `train_features` (or `test_features`) and `batch_indices` and returns a `(batch_count, feature_count)` tensor `batch_features` containing the feature vectors for each batch element. We will also need a `HierarchicalNonstationaryHyperparameter` class that accepts a `(knot_count, feature_count)` knot feature matrix (we'll need code to sample these), initial guesses for the hyperparameter values at each of those knots, and hyperparameters of the lower-level GP.
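The batch feature gathering is essentially an index into the feature matrix. A minimal sketch, with a hypothetical helper name and numpy standing in for the `mm` backend:

```python
import numpy as np


def _make_batch_features(features: np.ndarray, batch_indices: np.ndarray) -> np.ndarray:
    """Gather the (batch_count, feature_count) batch_features tensor.

    features is the (train_count, feature_count) or (test_count, feature_count)
    matrix and batch_indices is a (batch_count,) integer vector of batch indices.
    """
    return features[batch_indices, :]
```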
I think that we should start with the lower-level GP as a simple `RBF` kernel. I think that we will not expose its parameters to optimization for the initial prototype. The `__init__` function will need to look something like
```python
def __init__(self, knot_features, knot_values, **kernel_kwargs):
    self.kernel = RBF(**kernel_kwargs)
    self.knot_features = knot_features
    self.knot_values = knot_values
    self.lower_K = self.kernel(pairwise_distances(self.knot_features))
    self.solve = mm.linalg.solve(self.lower_K, self.knot_values)
```
I am also not sure whether we will want a separate `HierarchicalNonstationaryHyperparameter` instance for each hyperparameter, or one to manage all of the free hyperparameters. We'll probably want to start with one that just manipulates the `length_scale` for `RBF` for simplicity, so I'll assume that we are only learning one hyperparameter in the rest of this discussion.
The `HierarchicalNonstationaryHyperparameter` class will need a `__call__` function that looks something like
```python
def __call__(self, batch_features) -> mm.ndarray:
    lower_Kcross = self.kernel(crosswise_distances(batch_features, self.knot_features))
    return lower_Kcross @ self.solve
```
This function returns a `(batch_count, 1)` vector of hyperparameter values.
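Taken together, the two snippets above amount to a small GP regression from the knots to the batch elements. The following self-contained sketch illustrates the idea with numpy and an explicit RBF standing in for the library's kernel and `mm` backend (all names here are stand-ins, not the MuyGPyS API):

```python
import numpy as np
from scipy.spatial.distance import cdist


class HierarchicalLengthScaleSketch:
    """Krige per-batch length_scale values from fixed values at a few knots."""

    def __init__(self, knot_features, knot_values, lower_length_scale=1.0):
        self.knot_features = knot_features
        self.knot_values = knot_values
        self.lower_length_scale = lower_length_scale
        # lower-level RBF covariance among the knots
        sq_dists = cdist(knot_features, knot_features, metric="sqeuclidean")
        self.lower_K = np.exp(-sq_dists / (2 * lower_length_scale**2))
        self.solve = np.linalg.solve(self.lower_K, knot_values)

    def __call__(self, batch_features):
        # cross-covariance between the batch elements and the knots
        sq_dists = cdist(batch_features, self.knot_features, metric="sqeuclidean")
        lower_Kcross = np.exp(-sq_dists / (2 * self.lower_length_scale**2))
        return lower_Kcross @ self.solve


# toy usage: 4 knots in 2 features, evaluated at 3 batch elements
knots = np.random.uniform(size=(4, 2))
values = np.array([[0.5], [1.0], [1.5], [2.0]])
length_scale = HierarchicalLengthScaleSketch(knots, values)
print(length_scale(np.random.uniform(size=(3, 2))).shape)  # (3, 1)
```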
We will need to modify `MuyGPyS._src.gp.kernels._rbf_fn` so that it can accept a `(batch_count, 1)` `length_scale` parameter in addition to the `(batch_count, ...)` `squared_dists` parameter and apply each length scale to the corresponding block of the `squared_dists` tensor. There is probably an elegant way to do this but it will require a little investigation. We'll probably want a function modifier that then collapses this new form into the scalar form that we are currently using for simplicity.
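One candidate is plain broadcasting: reshape a `(batch_count, 1)` `length_scale` to `(batch_count, 1, 1)` when the distance tensor is 3-D, so that the scalar case and the per-batch case share the same arithmetic. A sketch, using numpy in place of the backend math module and assuming the usual `exp(-squared_dists / (2 * length_scale**2))` form:

```python
import numpy as np


def _rbf_fn_sketch(squared_dists: np.ndarray, length_scale) -> np.ndarray:
    """RBF on precomputed squared distances with scalar or per-batch length scales."""
    length_scale = np.asarray(length_scale)
    if length_scale.ndim > 0 and squared_dists.ndim == 3:
        # (batch_count, 1) -> (batch_count, 1, 1) so it broadcasts against
        # the (batch_count, nn_count, nn_count) tensor block by block
        length_scale = length_scale.reshape(-1, 1, 1)
    return np.exp(-squared_dists / (2 * length_scale**2))


dists = np.random.uniform(size=(5, 10, 10))
print(_rbf_fn_sketch(dists, 1.0).shape)                   # scalar form: (5, 10, 10)
print(_rbf_fn_sketch(dists, np.full((5, 1), 0.5)).shape)  # per-batch form: (5, 10, 10)
```

The scalar path is untouched, which is the behavior the collapsing modifier would preserve.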
If all of that is in place, we should be able to modify the `RBF.__call__` function to

```python
def __call__(self, diffs, batch_features=None) -> mm.ndarray:
    return self._fn(diffs, length_scale=self.length_scale(batch_features))
```
(Here we will need to have modified the `Hyperparameter.__call__` signature to include `**kwargs`.) We will also have to modify several of our `__init__` functions to support the creation of `HierarchicalNonstationaryHyperparameter`s. If I am correct, all of this infrastructure will allow us to create nonstationary kernels using our current infrastructure. One of the nice features here is that we get nonstationary + isotropic and nonstationary + anisotropic distortions "for free", since the nonstationarity is relegated to the hyperparameter object and is independent of the distortion model.
Instrumenting the knot hyperparameter values for optimization is another issue that will probably require some refactoring.
> We will need to modify `MuyGPyS._src.gp.kernels._rbf_fn` so that it can accept a `(batch_count, 1)` `length_scale` parameter in addition to the `(batch_count, ...)` `squared_dists` parameter and apply each length scale to the corresponding block of the `squared_dists` tensor.
This is out of date. `length_scale` is now handled by the `DistortionModel`, which is very different from the workflow at the time that I wrote this comment.
We need to modify `apply_distortion` and `_optional_invoke_param` in a manner similar to the following:
```python
def _optional_invoke_param(
    param: Union[ScalarHyperparameter, HierarchicalNonstationaryHyperparameter, float],
    batch_features: Optional[mm.ndarray] = None,
) -> float:
    if isinstance(param, ScalarHyperparameter):
        return param()
    if isinstance(param, HierarchicalNonstationaryHyperparameter):
        return param(batch_features)
    return param


def apply_distortion(distortion_fn: Callable, **length_scales):
    def distortion_applier(fn: Callable):
        def distorted_fn(diffs, *args, batch_features=None, **kwargs):
            ...
            for ls in length_scales:
                inner_kwargs.setdefault(
                    ls,
                    _optional_invoke_param(
                        length_scales[ls], batch_features=batch_features
                    ),
                )
            ...
            return fn(
                distortion_fn(diffs, **inner_kwargs), *args, **outer_kwargs
            )

        return distorted_fn

    return distortion_applier
```
These changes will allow the returned function to accept an additional kwarg, `batch_features`. We'll then need to modify `RBF.__call__` so that it can accept this additional kwarg and pass it to its `RBF._fn`, which is the output of `apply_distortion`. I think that will allow us to solve this problem for isotropic kernels.
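A minimal sketch of that pass-through (the real `RBF.__call__` signature may differ; this only illustrates forwarding the new kwarg into the distorted function):

```python
# hypothetical shape of the modified RBF.__call__, where self._fn is the
# distorted_fn produced by apply_distortion above
def __call__(self, diffs, batch_features=None, **kwargs) -> mm.ndarray:
    # batch_features flows through distorted_fn, which uses it to invoke any
    # HierarchicalNonstationaryHyperparameter length_scales
    return self._fn(diffs, batch_features=batch_features, **kwargs)
```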
> Instrumenting the knot hyperparameter values for optimization is another issue that will probably require some refactoring.
This should be easier now. We just need to add `get_optim_params` and `get_opt_fn` methods for `HierarchicalNonstationaryHyperparameter` that get called by the functions of the same name in `IsotropicDistortion`, which will need to be modified to handle both the scalar (current behavior) case and this new case. We'll probably also need analogues of `apply_scalar_hyperparameter` and `append_scalar_optim_params_list`. That will most likely get complicated.
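One way the dispatch could look (a sketch only; the real method and helper signatures in `IsotropicDistortion` may differ) is to branch on the parameter type when assembling the optimization lists:

```python
# hypothetical sketch of IsotropicDistortion.get_optim_params dispatching on
# the length_scale type; names and signatures mirror this discussion, not the library
def get_optim_params(self):
    names, params, bounds = [], [], []
    if isinstance(self.length_scale, HierarchicalNonstationaryHyperparameter):
        # delegate to the hierarchical parameter, which appends one entry per
        # free knot value (see append_lists below)
        self.length_scale.append_lists("length_scale", names, params, bounds)
    else:
        # current scalar behavior via the existing scalar helper
        append_scalar_optim_params_list(
            "length_scale", self.length_scale, names, params, bounds
        )
    return names, params, bounds
```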
Here's the general outline of how to add `get_optim_params` to the hierarchical parameter. The trick is to treat the `knot_values` as a `Dict[str, ScalarHyperparameter]`, similar to how `AnisotropicDistortion` handles its `length_scale`. This dict should probably just have the keys `"0"`, `"1"`, et cetera (we'll see why below). So first, we need to modify the constructor so that it expects the `knot_values` to be this dict, and to throw errors otherwise (see `AnisotropicDistortion`). We'll have to do something very similar to what `AnisotropicDistortion` does to convert this dict to a vector for use in the solve. We will not need to include the `batch_features` part here though, as the `knot_values` are guaranteed to be scalar hyperparameters.
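For concreteness, a sketch of that conversion, assuming each `ScalarHyperparameter` returns its current value when called (numpy again stands in for the `mm` backend):

```python
import numpy as np


def _knot_values_to_vector(knot_values: dict) -> np.ndarray:
    """Collapse a {"0": ScalarHyperparameter, "1": ..., ...} dict into the
    (knot_count, 1) vector consumed by the lower-level GP solve."""
    knot_count = len(knot_values)
    return np.array(
        [knot_values[str(i)]() for i in range(knot_count)]
    ).reshape(knot_count, 1)
```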
Once this is done, we need `HierarchicalNonstationaryHyperparameter.append_lists` to add each free knot parameter to the opt lists. This will look something like
```python
def append_lists(
    self,
    name: str,
    names: List[str],
    params: List[float],
    bounds: List[Tuple[float, float]],
):
    for index, param in self.knot_values.items():
        param.append_lists(name + "_knot" + index, names, params, bounds)
```
This will add hyperparameter entries for each knot value. For a hierarchical `length_scale`, the entries will have names like `length_scale_knot0`, `length_scale_knot1`, et cetera.
We will also need to modify how the `__call__` method works, because we need it to be able to accept new values for the knots. This can be achieved by adding `**kwargs` to the signature, then checking the `kwargs` dict to see if anything ends with `_knot#` and updating the appropriate value. I don't think that this will work if we are optimizing multiple hierarchical hyperparameters at once, so we will need to make that more general in the future. We will also need to recompute `self.solve` if any of the hyperparameters get updated.
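A sketch of such a `__call__`, assuming the attribute names from the `__init__` sketch above and deferring the actual knot update (and the `self.solve` refresh) to the `set_knot` method discussed further below:

```python
import re


# hypothetical modified __call__ for HierarchicalNonstationaryHyperparameter
def __call__(self, batch_features, **kwargs) -> mm.ndarray:
    for key, val in kwargs.items():
        # optimization passes kwargs named like "length_scale_knot0", ...
        match = re.search(r"_knot(\d+)$", key)
        if match is not None:
            # set_knot updates the knot value and recomputes self.solve
            self.set_knot(match.group(1), val)
    lower_Kcross = self.kernel(
        crosswise_distances(batch_features, self.knot_features)
    )
    return lower_Kcross @ self.solve
```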
`KernelFn` has a `_hyperparameters` member, to which all kernel parameters get added. We currently construct a new `MuyGPS` object at the end of optimization with existing code that sets each optimized hyperparameter by name. However, that code assumes that the hyperparameter has a scalar value, and so we will need to check if it is hierarchical, and if so do something like
```python
loc = key.find("_knot")
name = key[:loc]
index = key[loc + 5:]
param = ret.kernel._hyperparameters[name]
param.set_knot(index, val)
```
and we'll need to add the appropriate `set_knot` method to `HierarchicalNonstationaryHyperparameter`.
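A sketch of `set_knot` under the same assumptions as above. The `_set_val` call is a placeholder for whatever setter `ScalarHyperparameter` actually exposes, and `_knot_values_to_vector` is the helper sketched earlier:

```python
# hypothetical set_knot for HierarchicalNonstationaryHyperparameter
def set_knot(self, index: str, val: float) -> None:
    # update the scalar hyperparameter at this knot; _set_val is a stand-in
    # for the real ScalarHyperparameter setter
    self.knot_values[index]._set_val(val)
    # the knot values feed the lower-level solve, so refresh it
    self.solve = mm.linalg.solve(
        self.lower_K, _knot_values_to_vector(self.knot_values)
    )
```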
I'm less immediately certain of what to do for `get_opt_fn` outside of what we have already discussed. We can revisit it once you've worked out the above. @igoumiri
I merged PR #154. It is a step in the right direction, although the `obj_fn` that it produces seems not to be sensitive to the `length_scale_knot#` kwargs (see the tutorial notebook). The next step is to fix this, as the kernel tensors should be sensitive, which should also affect the outputs of `obj_fn`.
I fixed it in https://github.com/LLNL/MuyGPyS/pull/156 but some issues remain.
We need to add Amanda and Imène's hierarchical nonstationary kernel into the library and integrate it into all workflows and tests. It will most likely need to be reworked so that it meshes with the existing framework.