facebook / Ax

Adaptive Experimentation Platform
https://ax.dev
MIT License

structured Gaussian processes, deep kernel learning, and incorporating domain knowledge #1189

Closed. sgbaird closed this 1 year ago

sgbaird commented 2 years ago

I probably lack the understanding and the language required to talk about this effectively, so here are a few follow-up questions.

From my basic understanding, it's functionally similar to performing BO over a VAE latent space, except that the latent-space embeddings aren't fixed ahead of time; the manifold itself is learned based on what a deep kernel learning (?) model decides is "useful" or not. At a higher level, I've been told it's useful for incorporating physical insight/domain knowledge (e.g. physical models) into active learning.
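To check that I'm picturing it right, here's a minimal DKL sketch in GPyTorch (my own toy example, not from the papers above; the network architecture, kernel, and class names are arbitrary). The kernel operates on the output of a small neural net, and the net is trained jointly with the GP hyperparameters, so the embedding adapts to the data rather than staying fixed like a pre-trained VAE latent space:

import torch
import gpytorch

class FeatureExtractor(torch.nn.Sequential):
    """Small neural net that maps inputs to a low-dimensional embedding."""
    def __init__(self, input_dim: int, latent_dim: int = 2):
        super().__init__()
        self.add_module("linear1", torch.nn.Linear(input_dim, 32))
        self.add_module("relu1", torch.nn.ReLU())
        self.add_module("linear2", torch.nn.Linear(32, latent_dim))

class DKLModel(gpytorch.models.ExactGP):
    """GP whose kernel acts on the learned embedding (deep kernel learning)."""
    def __init__(self, train_x, train_y, likelihood, feature_extractor):
        super().__init__(train_x, train_y, likelihood)
        self.feature_extractor = feature_extractor
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.MaternKernel())

    def forward(self, x):
        # The "manifold" is the feature extractor's output; its weights are
        # trained jointly with the GP hyperparameters via the marginal log likelihood.
        z = self.feature_extractor(x)
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(z), self.covar_module(z)
        )

# Usage sketch:
#   likelihood = gpytorch.likelihoods.GaussianLikelihood()
#   model = DKLModel(train_x, train_y, likelihood, FeatureExtractor(train_x.shape[-1]))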

I'm asking based on some discussion with Sergei Kalinin on DKL models they've been applying in microscopy settings and how it applies to other domains. See e.g. https://arxiv.org/abs/2205.15458

Related:

from a Twitter search for "deep kernel learning"

From https://github.com/ziatdinovmax/gpax:

The limitation of the standard GP is that it does not usually allow for the incorporation of prior domain knowledge and can be biased toward a trivial interpolative solution. Recently, we introduced a structured Gaussian Process (sGP), where a classical GP is augmented by a structured probabilistic model of the expected system’s behavior. This approach allows us to balance the flexibility of the non-parametric GP approach with a rigid structure of prior (physical) knowledge encoded into the parametric model. Implementation-wise, we substitute a constant/zero prior mean function in GP with a probabilistic model of the expected system's behavior. ... For example, if we have prior knowledge that our objective function has a discontinuous 'phase transition', and a power law-like behavior before and after this transition, we may express it using a simple piecewise function

from typing import Dict

import jax.numpy as jnp

def piecewise(x: jnp.ndarray, params: Dict[str, float]) -> jnp.ndarray:
    """Power-law behavior before and after the transition"""
    return jnp.piecewise(
        x, [x < params["t"], x >= params["t"]],
        [lambda x: x**params["beta1"], lambda x: x**params["beta2"]])

where jnp corresponds to the jax.numpy module. This function is deterministic. To make it probabilistic, we put priors over its parameters with the help of NumPyro:


import numpyro
from numpyro import distributions
def piecewise_priors():
    # Sample model parameters
    t = numpyro.sample("t", distributions.Uniform(0.5, 2.5))
    beta1 = numpyro.sample("beta1", distributions.Normal(3, 1))
    beta2 = numpyro.sample("beta2", distributions.Normal(3, 1))
    # Return sampled parameters as a dictionary
    return {"t": t, "beta1": beta1, "beta2": beta2}
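If I'm reading the gpax README right, these two pieces then plug into the GP roughly as follows (just a sketch; X, y, and the rng key are placeholders, and the kernel choice is arbitrary). The probabilistic piecewise model replaces the usual constant/zero prior mean function:

import jax
import gpax

rng_key = jax.random.PRNGKey(0)
# Structured GP: piecewise model (with priors over its parameters) as the prior mean
gp_model = gpax.ExactGP(
    1, kernel="Matern",
    mean_fn=piecewise, mean_fn_prior=piecewise_priors,
)
gp_model.fit(rng_key, X, y)  # X, y: training data (placeholders here)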

Feel free to close as this is just a discussion post, and no worries if this doesn't fit well within the scope of Ax/BoTorch. Curious to hear your thoughts, if any!

Balandat commented 2 years ago

It seems like there is a lot of stuff going on in the papers that you mentioned, using and combining a number of different approaches. I am not sure the term "structured GP" is universally accepted - it seems like this is used kind of like a catch-all for GP-ish models that incorporate domain knowledge. The place I usually see this used is in exploiting structure in kernel / covariance matrices in order to speed up inference.
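One concrete example of that usage is structured kernel interpolation (KISS-GP), which approximates the covariance matrix with a grid-interpolated, structured one to get fast matrix-vector multiplies. A rough GPyTorch sketch (grid size and kernel choice are arbitrary):

import gpytorch

class KISSGPModel(gpytorch.models.ExactGP):
    """GP whose covariance uses structured kernel interpolation for fast inference."""
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.GridInterpolationKernel(
            gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel()),
            grid_size=100,
            num_dims=1,
        )

    def forward(self, x):
        # The interpolated kernel induces a structured covariance matrix that
        # supports fast matrix-vector multiplies during training and inference.
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )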

As to the more general interpretation, there are lots of flavors of this. We have done some of this on our end, including things like DKL and latent-space BO (@Ryan-Rhys has done lots of this), semiparametric GP models (@bletham), and transfer-learning type modeling in which a model fit across a bunch of related data is included as an informative prior.
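As one rough sketch of that last flavor (just one of several ways such information can be encoded; the details vary a lot by use case), a model fit on related data can be wrapped as the prior mean of the new GP, so the GP only has to model the residual on top of it:

import torch
import gpytorch

class PretrainedMean(gpytorch.means.Mean):
    """Prior mean given by a (frozen) model fit on related historical data."""
    def __init__(self, pretrained_model: torch.nn.Module):
        super().__init__()
        self.pretrained_model = pretrained_model

    def forward(self, x):
        with torch.no_grad():  # keep the pretrained model frozen
            return self.pretrained_model(x).squeeze(-1)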

At a high level, the more domain knowledge you have and the more specific you can get, the better a model you can construct for the specific use case at hand. From an implementation perspective, most of the basic concepts for this exist in botorch, but it's rather hard to build a generic interface for these kinds of models at the level of Ax. As always, it's a tradeoff between customizability and usability.

sgbaird commented 2 years ago

@Balandat thanks for the great info, and thanks for taking a look at some of the resources I linked.

It seems like there is a lot of stuff going on in the papers that you mentioned, using and combining a number of different approaches. I am not sure the term "structured GP" is universally accepted - it seems like this is used kind of like a catch-all for GP-ish models that incorporate domain knowledge. ...

Good to know about the terminology.

... The place I usually see this used is in exploiting structure in kernel / covariance matrices in order to speed up inference.

Interesting - I might need to take a look.

As to the more general interpretation, there are lots of flavors of this. We have done some of this on our end, including things like DKL and latent-space BO (@Ryan-Rhys has done lots of this), semiparametric GP models (@bletham), ...

@Ryan-Rhys I'm familiar with the BoTorch VAE+BO example. Any resources for DKL and other latent-space BO that you've worked with and might be able to share? @bletham, also interested in the semiparametric GP models. Had trouble finding an example in Ax or BoTorch.

and transfer-learning type modeling in which a model fit across a bunch of related data is included as an informative prior.

Multi-task and multi-fidelity, contextual variables, and custom featurizers definitely come to mind in this respect. https://github.com/facebook/Ax/issues/1038. Also nonlinear constraints https://github.com/facebook/Ax/issues/153. Related discussion on domain knowledge https://github.com/facebook/Ax/issues/828.

At a high level, the more domain knowledge you have and the more specific you can get, the better a model you can construct for the specific use case at hand. From an implementation perspective, most of the basic concepts for this exist in botorch, but it's rather hard to build a generic interface for these kinds of models at the level of Ax. As always, it's a tradeoff between customizability and usability.

Agreed!

bernardbeckerman commented 1 year ago

Closing as discussion.