pytorch / botorch

Bayesian optimization in PyTorch
https://botorch.org/
MIT License
3.11k stars 406 forks

Add maximum variance ("maxvar") acquisition functions for active learning #1366

Closed matthewcarbone closed 6 months ago

matthewcarbone commented 2 years ago

🚀 Feature Request

Currently botorch implements the UpperConfidenceBound analytic and MC acquisition functions (e.g. botorch.acquisition.UpperConfidenceBound). This is useful, but it might be nice to have a pure "active learning" acquisition function instead of just setting a large beta value.

Motivation

It's true this isn't really "true Bayesian optimization", but I think there is a community that will find this useful. It kinda feels like botorch should have this.

Pitch

Describe the solution you'd like

Implementing a "hacked version" of UCB which does this.

Describe alternatives you've considered

You can simply set the control parameter beta=1e10 or something large to drown out the contribution from the mean, but this feels pretty sloppy to me.
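For reference, the workaround amounts to something like this (a minimal sketch; model is assumed to be a fitted single-outcome GP):

from botorch.acquisition import UpperConfidenceBound

# With a huge beta, the sqrt(beta) * sigma term drowns out the mean,
# so UCB effectively reduces to maximizing the posterior variance.
ucb_explore = UpperConfidenceBound(model, beta=1e10)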

Are you willing to open a pull request?

~Yes absolutely! I can do this for both analytic and MC.~ Unfortunately no.

Additional context

N/A

saitcakmak commented 2 years ago

Hi @x94carbone. I am not that familiar with active learning applications, so I'll let someone else with more context address this. I just wanted to point out that we do support qNegIntegratedPosteriorVariance as a pure-exploration acquisition function.
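For reference, a minimal sketch of how it is constructed (the model, the input dimension d=2, and the number of points are illustrative assumptions):

import torch

from botorch.acquisition.active_learning import qNegIntegratedPosteriorVariance

# mc_points: a fixed n x d set of input locations over which the
# fantasized posterior variance is averaged.
mc_points = torch.rand(256, 2, dtype=torch.double)
qnipv = qNegIntegratedPosteriorVariance(model=model, mc_points=mc_points)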

Balandat commented 2 years ago

It's true this isn't really "true Bayesian optimization", but I think there is a community that will find this useful. It kinda feels like botorch should have this.

Agreed, I think the active learning aspect is interesting and important and we're happy to accept suggestions - and even happier to accept PRs :) These would naturally go into the active_learning submodule.

Just so I understand, this would basically just return the posterior variance and then optimizing the acquisition function would mean finding the point with highest posterior variance?

matthewcarbone commented 2 years ago

@saitcakmak Ah I see. Is this in principle different than if we set the UCB acquisition function beta to some large number? I do see it comes from botorch.acquisition.active_learning, which I've played around with some time ago, but I don't recall it working as smoothly out of the box as just "hacking" UCB. Perhaps I should revisit though, I don't want to duplicate your efforts.

@Balandat That is correct!

Edit: actually at the very least I do think this will add something: an analytic version of a "pure exploration" acquisition function.

saitcakmak commented 2 years ago

Is this in principle different than if we set the UCB acquisition function beta to some large number?

I think it is closer in spirit to a look-ahead type acquisition function. Rather than greedily picking the point with the maximum variance, it tries to maximize the reduction in the total (integrated) variance in the search space from evaluating a given point. In that sense, it explicitly considers the effects beyond the candidate itself, which is different than what UCB does.
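In symbols (a sketch; `sigma^2(. | D)` is the GP posterior variance given data `D`):

x_MaxVar = argmax_x sigma^2(x | D)
x_qNIPV  = argmax_x [ -integral over X of sigma^2(xi | D u {x}) dxi ]

so qNIPV credits a candidate for shrinking the variance everywhere in the design space, not just at the candidate itself.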

matthewcarbone commented 2 years ago

I think it is closer in spirit to a look-ahead type acquisition function. Rather than greedily picking the point with the maximum variance, it tries to maximize the reduction in the total (integrated) variance in the search space from evaluating a given point. In that sense, it explicitly considers the effects beyond the candidate itself, which is different than what UCB does.

This does sound different, then. Perhaps MaxVar/qMaxVar is simpler, but it does sound like a new feature.

What I can say is that most people have heard of MaxVar. I think qNegIntegratedPosteriorVariance is a bit more complex, at least relative to what I know some of my colleagues use/need.

What do you think? Would you like me to open a PR for this?

Balandat commented 2 years ago

What do you think? Would you like me to open a PR for this?

I think MaxVar is a bit of a misnomer, since the acquisition function is really just posterior variance at the query point(s). So it should be (q)PosteriorVariance - the "max" is just b/c we actually optimize it (that's why we call things PosteriorMean rather than MaxPosteriorMean - semantics). We could add this to the active_learning module if you think having this would be helpful for folks.

matthewcarbone commented 2 years ago

I think MaxVar is a bit of a misnomer, since the acquisition function is really just posterior variance at the query point(s). So it should be (q)PosteriorVariance - the "max" is just b/c we actually optimize it (that's why we call things PosteriorMean rather than MaxPosteriorMean - semantics). We could add this to the active_learning module if you think having this would be helpful for folks.

Sure fair enough, that's consistent with the other botorch methods. I'll implement and open a PR a bit later. 👍

eytan commented 2 years ago

One could add an acquisition function that just maximizes posterior variance. This is often used as a baseline in active learning papers and generally performs the worst. I would highly recommend not using this acquisition function for any actual active learning tasks. Integrated mean squared error (which is pretty much the same as integrated posterior variance) is commonly used in active learning within the simulation community.

matthewcarbone commented 2 years ago

@eytan sure, it's usually just a baseline, but it's an important one, and sometimes it's good enough.

matthewcarbone commented 2 years ago

All, I don't feel comfortable signing Meta's required form to open the PR (given my position it's not entirely obvious that I can; I'm no lawyer and I don't pretend to be). I apologize, I didn't realize there was a required "extra step" at first.

Anyways, this is more or less what I had in mind:

from botorch.acquisition.analytic import AnalyticAcquisitionFunction
from botorch.acquisition.monte_carlo import MCAcquisitionFunction
from botorch.utils.transforms import (
    concatenate_pending_points,
    t_batch_mode_transform,
)


class PosteriorVariance(AnalyticAcquisitionFunction):
    """Analytic posterior variance at the query point (pure exploration)."""

    @t_batch_mode_transform(expected_q=1)
    def forward(self, X):
        posterior = self.model.posterior(
            X=X,
            posterior_transform=self.posterior_transform,
        )
        mean = posterior.mean
        # Collapse the q=1 dimension to return a `(b1 x ... x bk)`-dim tensor.
        view_shape = (
            mean.shape[:-2] if mean.shape[-2] == 1 else mean.shape[:-1]
        )
        return posterior.variance.view(view_shape)


class qPosteriorVariance(MCAcquisitionFunction):
    """MC counterpart: mean absolute deviation of the posterior samples,
    maximized over the q-batch (mirrors qUpperConfidenceBound with the
    mean term removed)."""

    @concatenate_pending_points
    @t_batch_mode_transform()
    def forward(self, X):
        posterior = self.model.posterior(
            X=X,
            posterior_transform=self.posterior_transform,
        )
        samples = self.objective(self.sampler(posterior), X=X)
        mean = samples.mean(dim=0)
        deviation = (samples - mean).abs()
        # Max over the q-batch, then average over the MC samples.
        return deviation.max(dim=-1)[0].mean(dim=0)

Please feel free to implement something like the above classes if you feel they meet the bar (or to edit them in ways you feel appropriate). They're largely just built on top of the UCB acquisition functions in your code. You should be able to add them to botorch.acquisition.active_learning with no hiccups.
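For anyone picking this up, optimizing the analytic version would look roughly like this (a sketch; gp is a fitted SingleTaskGP and bounds a 2 x d tensor, both assumptions):

from botorch.optim import optimize_acqf

acqf = PosteriorVariance(model=gp)
candidate, acq_value = optimize_acqf(
    acqf, bounds=bounds, q=1, num_restarts=10, raw_samples=512
)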

takafusui commented 1 year ago

Hi, does anybody have sample code for doing active learning with BoTorch so as to improve global model accuracy? I want to focus on the pure-exploration part, and I am also reluctant to set a large beta in the UpperConfidenceBound analytic acquisition function. I came across the qNegIntegratedPosteriorVariance acquisition function, so I would highly welcome something using it as well. Thank you in advance.

matthewcarbone commented 1 year ago

@takafusui incidentally, this

I am also reluctant to set a large beta in the UpperConfidenceBound analytic acquisition function

does actually work just fine. I usually use beta=100 or so for "active learning". For properly scaled inputs/outputs this seems to be sufficient.

takafusui commented 1 year ago

Hi @matthewcarbone, thank you for your comment. What did you mean by 'properly scaled inputs/outputs'? For instance, I usually standardize inputs when fitting a GP model, then scale the outputs back when computing the posterior mean and variance.

matthewcarbone commented 1 year ago

@takafusui so in the case of just setting beta to be large, how large depends on your data. You need beta to be large enough to "outcompete" the contribution from the mean of your probabilistic model. By properly scaled output I mean your output is standard-scaled, or normalized to [-1, 1], or something similar; then you have a ballpark idea of how "arbitrarily large" beta needs to be. This is also good practice for numerical stability (as I think you probably know), irrespective of this problem!
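Concretely, something along these lines (a sketch; train_X and train_Y are assumed to be your data):

from botorch.acquisition import UpperConfidenceBound
from botorch.models import SingleTaskGP
from botorch.models.transforms.outcome import Standardize

# With outcomes standardized to roughly N(0, 1), the posterior mean is O(1),
# so beta = 100 comfortably lets the variance term dominate.
gp = SingleTaskGP(train_X, train_Y, outcome_transform=Standardize(m=1))
ucb = UpperConfidenceBound(gp, beta=100.0)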

eytan commented 1 year ago

If you'd like to do active learning, qNegIntegratedPosteriorVariance should perform better than greedily maximizing the posterior variance, since it directly targets reduction in global variance.

Matthew, have you done any benchmarks to compare your proposed baseline and qNegIntegratedPosteriorVariance?

matthewcarbone commented 1 year ago

@eytan no, but I've benchmarked it against my own custom acquisition function which does pure active learning and they work about the same.

takafusui commented 1 year ago

Hi @eytan, @matthewcarbone and all,

I agree with @eytan. Although I started studying active learning only recently, the literature I have reviewed reports that the integrated mean squared error criterion (IMSE) outperforms the pointwise mean squared error criterion (MSE). At the same time, I can understand @matthewcarbone's experience, so I created a test script comparing the performance of IMSE and MSE. I use qNegIntegratedPosteriorVariance as the IMSE criterion, although I am not sure about my implementation with qNIPV. I define MeanSquareError myself, borrowing heavily from BoTorch's UpperConfidenceBound (UCB) with the posterior-mean term removed. I use the leave-one-out error as the error metric.

import numpy as np
from smt.sampling_methods import LHS, Random
import matplotlib.pyplot as plt

# PyTorch
import torch
# GPyTorch
import gpytorch
# BoTorch
from botorch.models import SingleTaskGP
from botorch.fit import fit_gpytorch_model
from botorch.models.transforms.outcome import Standardize
from botorch.models.transforms.input import InputStandardize
from botorch.acquisition import qNegIntegratedPosteriorVariance
from botorch.utils.sampling import draw_sobol_samples
from botorch.optim import optimize_acqf

# Leave-one-out error
from sklearn.model_selection import LeaveOneOut

# Fix random seeds
np.random.seed(123)
torch.manual_seed(123)

# Use double data type
dtype = torch.double

from botorch.acquisition.analytic import AnalyticAcquisitionFunction
from botorch.models.model import Model
from botorch.acquisition.objective import PosteriorTransform
from botorch.utils.transforms import t_batch_mode_transform
from typing import Optional
from torch import Tensor

class MeanSquareError(AnalyticAcquisitionFunction):
    r"""Single-outcome mean-square error (MSE).

    Analytic mean-square error criterion that focuses on pure exploration.
    The acquisition function returns the posterior variance at the query points.
    Only supports the case of `q=1` (i.e. greedy, non-batch selection of design points).
    The model must be single-outcome.

    `MSE(x) = sigma^2(x)`, where `sigma^2` is the posterior variance.
    """

    def __init__(
        self,
        model: Model,
        posterior_transform: Optional[PosteriorTransform] = None,
        maximize: bool = True,
        **kwargs,
    ) -> None:
        r"""Single-outcome mean-square error.

        Args:
            model: A fitted single-outcome GP model (must be in batch mode if
                candidate sets X will be)
            posterior_transform: A PosteriorTransform. If using a multi-output model,
                a PosteriorTransform that transforms the multi-output posterior into a
                single-output posterior is required.
            maximize: If True, consider the problem a maximization problem.
        """
        super().__init__(model=model, posterior_transform=posterior_transform, **kwargs)
        self.maximize = maximize

    @t_batch_mode_transform(expected_q=1)
    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the MSE criterion on the candidate set X.

        Args:
            X: A `(b1 x ... bk) x 1 x d`-dim batched tensor of `d`-dim design points.

        Returns:
            A `(b1 x ... bk)`-dim tensor of posterior-variance values at the
            given design points `X`.
        """
        posterior = self.model.posterior(
            X=X, posterior_transform=self.posterior_transform
        )
        mean = posterior.mean
        view_shape = mean.shape[:-2] if mean.shape[-2] == 1 else mean.shape[:-1]
        variance = posterior.variance.view(view_shape)
        # Negate when `maximize=False` so that optimizing the acquisition
        # function minimizes the posterior variance.
        return variance if self.maximize else -variance

# --------------------------------------------------------------------------- #
# Two-dimensional Gaussian peak function
# --------------------------------------------------------------------------- #
def peak_func(x):
    """Two-dimensional Gaussian peak function.

    x.shape: (N, 2)
    """
    ai = torch.tensor([5, 5], dtype=dtype)
    ui = torch.tensor([0.8, 0.8], dtype=dtype)

    _peak_func = torch.exp(
        -torch.sum(ai**2 * (x - ui)**2, axis=1, keepdims=True))

    return _peak_func

# Define the test function on [0, 1]^2
xlimits = np.array([[0, 1], [0, 1]])
# Latin-Hypercube sampling
sampler_LHS = LHS(xlimits=xlimits, random_state=123)
# Uniform distribution
sampler_uniform = Random(xlimits=xlimits)
sampler_list = {'LHS': sampler_LHS, 'uniform': sampler_uniform}

# Test dataset
N_test_X = 1000
test_X = torch.tensor(sampler_list['uniform'](N_test_X))
test_y = peak_func(test_X)

# --------------------------------------------------------------------------- #
# Initialize a GP model
# --------------------------------------------------------------------------- #
def initialize_GPmodel(train_X, train_y):
    """Initialize a GP model.

    train_X.shape: [N, d]
    train_y.shape: [N, 1]

    Return the fitted GP model.
    """
    # Get the dimensions of train_X and train_y to standardize inputs
    train_X_dim = train_X.shape[-1]
    train_y_dim = train_y.shape[-1]

    # Train the GP model with the standardized inputs
    gp = SingleTaskGP(
        train_X, train_y,
        input_transform=InputStandardize(d=train_X_dim),
        outcome_transform=Standardize(m=train_y_dim)
    )
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(gp.likelihood, gp)

    return mll, gp

# --------------------------------------------------------------------------- #
# Compute the leave-one-out error
# --------------------------------------------------------------------------- #
def compute_errLOO(train_X, train_y):
    """Compute the leave-one-out error.

    train_X.shape: (N, d)
    train_y.shape: (N, 1)

    return the leave-one-out error
    """
    loo = LeaveOneOut()  # Leave-one-out error estimator instance

    err = torch.empty(loo.get_n_splits(train_X), dtype=dtype)

    for train_idx, test_idx in loo.split(train_X):
        train_loo_X, test_loo_X = train_X[train_idx], train_X[test_idx]
        train_loo_y, test_loo_y = train_y[train_idx], train_y[test_idx]

        # Train the GP model based on train_loo_X and train_loo_y
        mll_loo, gp_loo = initialize_GPmodel(train_loo_X, train_loo_y)

        # Fit the GP model
        fit_gpytorch_model(mll_loo)

        # Make prediction and compute the squared approximation error
        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            pred_test_loo_X = gp_loo.posterior(test_loo_X)
            err[test_idx] = (test_loo_y - pred_test_loo_X.mean)**2

    # Leave-one-out error
    errLOO = 1 / loo.get_n_splits(train_X) * torch.sum(err)

    return errLOO

# -------------------------------------------------------------------------- #
# Active learning with IMSE and MSE
# -------------------------------------------------------------------------- #
N_train_X_to_add = 5
N_train_X_list = np.arange(50, 101, N_train_X_to_add)
N_train_X_init = N_train_X_list[0]  # Initial training size

# --------------------------------------------------------------------------- #
print(r'Active learning with the integrated mean square error criterion (IMSE)')
# --------------------------------------------------------------------------- #
# Store errLOO errors
errLOO_IMSE = np.empty(len(N_train_X_list))
# Number of added training sets
N_train_X_IMSE = np.empty_like(errLOO_IMSE)

# Initial training set sampling from LHS
train_X = torch.tensor(sampler_list['LHS'](N_train_X_init))
train_y = peak_func(train_X)
# Initialize and fit the GP model
mll, gp = initialize_GPmodel(train_X, train_y)
fit_gpytorch_model(mll)
# Bounds on train_X
bounds = torch.tensor(
    [[0.] * train_X.shape[1], [1.] * train_X.shape[1]], dtype=dtype)

# Active learning with IMSE
for imse_iter_idx in range(len(N_train_X_list)):

    if imse_iter_idx == 0:
        print(r"Compute errLOO for a pilot GP model...")
        # Compute LOO error when we add N_train_X_init number of samples
        errLOO_IMSE[imse_iter_idx] = compute_errLOO(train_X, train_y)
        N_train_X_IMSE[imse_iter_idx] = train_X.shape[0]
    else:
        print(r"Actively learning {} new training data in total...".format(
            imse_iter_idx * N_train_X_to_add))

        # We sequentially add N_train_X_to_add new training data
        for iter_idx in range(N_train_X_to_add):
            # mc_points must be quasi-random points in the *input* space
            # (shape n x d) over which the posterior variance is integrated.
            mc_points = draw_sobol_samples(
                bounds=bounds, n=256, q=1, seed=123
            ).squeeze(-2)

            # Batch Integrated Negative Posterior Variance for active learning
            qNIPV = qNegIntegratedPosteriorVariance(model=gp, mc_points=mc_points)

            # Optimize the acquisition function
            train_X_add, acq_value = optimize_acqf(
                qNIPV, bounds=bounds, q=1, num_restarts=3, raw_samples=256
            )

            # Compute the new label
            train_y_add = peak_func(train_X_add)

            # Concatenate the training data and the labels to be added
            train_X = torch.cat([train_X, train_X_add], axis=0)
            train_y = torch.cat([train_y, train_y_add], axis=0)

            # Reinitialize the model so it is ready for the next iteration
            mll, gp = initialize_GPmodel(train_X, train_y)
            # Refit the GP model
            fit_gpytorch_model(mll)

        # Compute LOO error after adding N_train_X_to_add new samples
        errLOO_IMSE[imse_iter_idx] = compute_errLOO(train_X, train_y)
        N_train_X_IMSE[imse_iter_idx] = train_X.shape[0]

# --------------------------------------------------------------------------- #
print(r'Active learning with the mean square error criterion (MSE)')
# We use the MeanSquareError method inherited from UpperConfidenceBound
# --------------------------------------------------------------------------- #
# Store errLOO errors
errLOO_MSE = np.empty(len(N_train_X_list))
# Number of added training sets
N_train_X_MSE = np.empty_like(errLOO_MSE)

# Initial training set sampling from LHS
train_X = torch.tensor(sampler_list['LHS'](N_train_X_init))
train_y = peak_func(train_X)
# Initialize and fit the GP model
mll, gp = initialize_GPmodel(train_X, train_y)
fit_gpytorch_model(mll)
# Bounds on train_X
bounds = torch.tensor(
    [[0.] * train_X.shape[1], [1.] * train_X.shape[1]], dtype=dtype)

# Active learning with MSE
for mse_iter_idx in range(len(N_train_X_list)):

    if mse_iter_idx == 0:
        print(r"Compute errLOO for a pilot GP model...")
        # Compute LOO error when we add N_train_X_init number of samples
        errLOO_MSE[mse_iter_idx] = compute_errLOO(train_X, train_y)
        N_train_X_MSE[mse_iter_idx] = train_X.shape[0]
    else:
        print(r"Actively learning {} new training data in total...".format(
            mse_iter_idx * N_train_X_to_add))

        # We sequentially add N_train_X_to_add new training data
        for iter_idx in range(N_train_X_to_add):
            # Mean square error
            MSE = MeanSquareError(gp)
            # Optimize the acquisition function
            train_X_add, acq_value = optimize_acqf(
                MSE, bounds=bounds, q=1, num_restarts=10, raw_samples=1024
            )

            # Compute the new label
            train_y_add = peak_func(train_X_add)

            # Concatenate the training data and the labels to be added
            train_X = torch.cat([train_X, train_X_add], axis=0)
            train_y = torch.cat([train_y, train_y_add], axis=0)

            # Reinitialize the model so it is ready for the next iteration
            mll, gp = initialize_GPmodel(train_X, train_y)
            # Refit the GP model
            fit_gpytorch_model(mll)

        # Compute LOO error after adding N_train_X_to_add new samples
        errLOO_MSE[mse_iter_idx] = compute_errLOO(train_X, train_y)
        N_train_X_MSE[mse_iter_idx] = train_X.shape[0]

# --------------------------------------------------------------------------- #
# Plot errLOO for IMSE and MSE
# --------------------------------------------------------------------------- #
fig, ax = plt.subplots(figsize=(9, 6))

ax.plot(N_train_X_IMSE, errLOO_IMSE, color='red', marker='o', label='IMSE')
ax.plot(N_train_X_MSE, errLOO_MSE, color='blue', marker='o', label='MSE')
ax.set_xlim([N_train_X_list.min(), N_train_X_list.max()])
ax.set_xlabel(r'Number of design points')
ax.set_ylabel(r'$\epsilon_{\mathrm{LOO}}$')
ax.set_yscale('log')
ax.legend(loc='upper left')
plt.show()
# plt.savefig('errLOO_mse_imse.pdf', bbox_inches='tight')
plt.close()

As you can see, with this simple function, MSE works better than IMSE...

As I said, I am not sure about my implementation of qNegIntegratedPosteriorVariance... that may be why the error metric differs from our expectations...

I highly appreciate any comments from you.

takafusui commented 1 year ago

I have follow-up comments on my previous post. My colleague told me that, in practical applications, a leave-one-out error of less than 10^{-2} can be sufficient. In the previous post, as my test function is close to 0 over part of the domain, the leave-one-out error is already very small. In this case, sequential refinement does not work as we expect.

I tested with another function, defined on [0, 1]^2. I scale the output to [0, 1] when computing the leave-one-out error.

import numpy as np
from smt.sampling_methods import LHS, Random
import matplotlib.pyplot as plt

# PyTorch
import torch
# GPyTorch
import gpytorch
# BoTorch
from botorch.models import SingleTaskGP
from botorch.fit import fit_gpytorch_model
from botorch.models.transforms.outcome import Standardize
from botorch.models.transforms.input import InputStandardize
from botorch.acquisition import qNegIntegratedPosteriorVariance
from botorch.utils.sampling import draw_sobol_samples
from botorch.optim import optimize_acqf

# Leave-one-out error
from sklearn.model_selection import LeaveOneOut
# Pre-processing
from sklearn import preprocessing

# Fix random seeds
np.random.seed(123)
torch.manual_seed(123)

# Use double data type
dtype = torch.double

from botorch.acquisition.analytic import AnalyticAcquisitionFunction
from botorch.models.model import Model
from botorch.acquisition.objective import PosteriorTransform
from botorch.utils.transforms import t_batch_mode_transform
from typing import Optional
from torch import Tensor

class MeanSquareError(AnalyticAcquisitionFunction):
    r"""Single-outcome mean-square error (MSE).

    Analytic mean-square error criterion that focuses on pure exploration.
    The acquisition function returns the posterior variance at the query points.
    Only supports the case of `q=1` (i.e. greedy, non-batch selection of design points).
    The model must be single-outcome.

    `MSE(x) = sigma^2(x)`, where `sigma^2` is the posterior variance.
    """

    def __init__(
        self,
        model: Model,
        posterior_transform: Optional[PosteriorTransform] = None,
        maximize: bool = True,
        **kwargs,
    ) -> None:
        r"""Single-outcome mean-square error.

        Args:
            model: A fitted single-outcome GP model (must be in batch mode if
                candidate sets X will be)
            posterior_transform: A PosteriorTransform. If using a multi-output model,
                a PosteriorTransform that transforms the multi-output posterior into a
                single-output posterior is required.
            maximize: If True, consider the problem a maximization problem.
        """
        super().__init__(model=model, posterior_transform=posterior_transform, **kwargs)
        self.maximize = maximize

    @t_batch_mode_transform(expected_q=1)
    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the MSE criterion on the candidate set X.

        Args:
            X: A `(b1 x ... bk) x 1 x d`-dim batched tensor of `d`-dim design points.

        Returns:
            A `(b1 x ... bk)`-dim tensor of MSE values at the given design points `X`.
        """
        posterior = self.model.posterior(
            X=X, posterior_transform=self.posterior_transform
        )
        mean = posterior.mean
        view_shape = mean.shape[:-2] if mean.shape[-2] == 1 else mean.shape[:-1]
        variance = posterior.variance.view(view_shape)
        # Negate when `maximize=False` so that optimizing the acquisition
        # function minimizes the posterior variance.
        return variance if self.maximize else -variance

# --------------------------------------------------------------------------- #
# Two-dimensional Gaussian peak function
# --------------------------------------------------------------------------- #
def peak_func(x):
    """Two-dimensional Gaussian peak function.

    x.shape: (N, 2)
    """
    ai = torch.tensor([5., 5.], dtype=dtype)
    ui = torch.tensor([0.8, 0.8], dtype=dtype)

    _peak_func = torch.exp(
        -torch.sum(ai**2 * (x - ui)**2, axis=1, keepdims=True))

    return _peak_func

def genz_func(x):
    """Non-smooth test function suggested by Genz (1984).

    x.shape: (N, 2)
    """
    _genz_func = 1 / (torch.abs(0.5 - x[:, 0]**4 - x[:, 1]**4) + 0.1)[:, None]

    return _genz_func

# Define the test function on [0, 1]^2
xlimits = np.array([[0, 1], [0, 1]])
# Latin-Hypercube sampling
sampler_LHS = LHS(xlimits=xlimits, random_state=123)
# Uniform distribution
sampler_uniform = Random(xlimits=xlimits)
sampler_list = {'LHS': sampler_LHS, 'uniform': sampler_uniform}

# Uniformly distributed test dataset
N_test_X = 512
test_X = torch.tensor(sampler_list['uniform'](N_test_X))
test_y = genz_func(test_X)

# # Plot
# fig = plt.figure(figsize=(9, 6))
# ax = fig.add_subplot(projection='3d')
# ax.scatter(test_X[:, 0], test_X[:, 1], test_y)
# ax.set_xlim([0, 1])
# ax.set_ylim([0, 1])
# ax.set_xlabel(r'$x_{1}$')
# ax.set_ylabel(r'$x_{2}$')
# ax.invert_xaxis()
# ax.invert_yaxis()
# plt.savefig('genz_func.pdf', bbox_inches='tight')
# plt.close()

# Botorch optimize_acqf parameters
NUM_RESTARTS = 10
RAW_SAMPLES = 512

# --------------------------------------------------------------------------- #
# Initialize a GP model
# --------------------------------------------------------------------------- #
def initialize_GPmodel(train_X, train_y):
    """Initialize a GP model.

    train_X.shape: [N, d]
    train_y.shape: [N, 1]

    Return a marginal log likelihood and a GP instance.
    """
    # Get the dimensions of train_X and train_y to standardize inputs/outputs
    train_X_dim = train_X.shape[-1]
    train_y_dim = train_y.shape[-1]

    # Train the GP model with the standardized inputs
    gp = SingleTaskGP(
        train_X, train_y,
        input_transform=InputStandardize(d=train_X_dim),
        outcome_transform=Standardize(m=train_y_dim)
    )
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(gp.likelihood, gp)

    return mll, gp

# --------------------------------------------------------------------------- #
# Compute the leave-one-out error
# --------------------------------------------------------------------------- #
def compute_errLOO(train_X, train_y):
    """Compute the leave-one-out error.

    train_X.shape: (N, d)
    train_y.shape: (N, 1)

    return the leave-one-out error
    """
    loo = LeaveOneOut()  # Leave-one-out error estimator instance

    err = torch.empty(loo.get_n_splits(train_X), dtype=dtype)

    # Min-Max scale train_y in [0, 1]
    # The Min-Max scaler preserves the original shape of the distribution of train_y
    min_max_scaler = preprocessing.MinMaxScaler()
    train_y_minmax = torch.tensor(min_max_scaler.fit_transform(train_y))

    for train_idx, test_idx in loo.split(train_X):
        train_loo_X, test_loo_X = train_X[train_idx], train_X[test_idx]
        train_loo_y, test_loo_y \
            = train_y_minmax[train_idx], train_y_minmax[test_idx]

        # Train the GP model based on train_loo_X and train_loo_y
        mll_loo, gp_loo = initialize_GPmodel(train_loo_X, train_loo_y)

        # Fit the GP model
        fit_gpytorch_model(mll_loo)

        # Make prediction and compute the squared approximation error
        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            pred_test_loo_X = gp_loo.posterior(test_loo_X)
            err[test_idx] = (test_loo_y - pred_test_loo_X.mean)**2

    # Leave-one-out error
    errLOO = 1 / loo.get_n_splits(train_X) * torch.sum(err)

    return errLOO

# -------------------------------------------------------------------------- #
# Active learning with IMSE and MSE
# -------------------------------------------------------------------------- #
N_train_X_to_add = 5
N_train_X_list = np.arange(100, 201, N_train_X_to_add)
N_train_X_init = N_train_X_list[0]  # Initial training size

# --------------------------------------------------------------------------- #
print(r'Active learning with the integrated mean square error criterion (IMSE)')
# --------------------------------------------------------------------------- #
# Store errLOO errors
errLOO_IMSE = np.empty(len(N_train_X_list))
# Number of added training sets
N_train_X_IMSE = np.empty_like(errLOO_IMSE)

# Initial training set sampling from LHS
train_X = torch.tensor(sampler_list['LHS'](N_train_X_init))
train_y = genz_func(train_X)
# Initialize a GP model
mll, gp = initialize_GPmodel(train_X, train_y)
# Fit the pilot GP model
fit_gpytorch_model(mll)

# Bounds on train_X
bounds = torch.tensor(
    [[0.] * train_X.shape[1], [1.] * train_X.shape[1]], dtype=dtype)

# Active learning with IMSE
for imse_iter_idx in range(len(N_train_X_list)):

    if imse_iter_idx == 0:
        print(r"Compute errLOO for a pilot GP model...")
        # Compute LOO error when we add N_train_X_init number of samples
        errLOO_IMSE[imse_iter_idx] = compute_errLOO(train_X, train_y)
        N_train_X_IMSE[imse_iter_idx] = train_X.shape[0]
    else:
        print(r"Actively learning {} new training data in total...".format(
            imse_iter_idx * N_train_X_to_add))

        # We sequentially add N_train_X_to_add new training data
        for iter_idx in range(N_train_X_to_add):
            # mc_points must be quasi-random points in the *input* space
            # (shape n x d) over which the posterior variance is integrated.
            mc_points = draw_sobol_samples(
                bounds=bounds, n=256, q=1, seed=123
            ).squeeze(-2)

            # Batch Integrated Negative Posterior Variance for active learning
            qNIPV = qNegIntegratedPosteriorVariance(model=gp, mc_points=mc_points)

            # Optimize the acquisition function
            train_X_add, acq_value = optimize_acqf(
                qNIPV, bounds=bounds, q=1, num_restarts=NUM_RESTARTS,
                raw_samples=RAW_SAMPLES
            )

            # Compute the new label
            train_y_add = genz_func(train_X_add)

            # Concatenate the training data and the labels to be added
            train_X = torch.cat([train_X, train_X_add], axis=0)
            train_y = torch.cat([train_y, train_y_add], axis=0)

            # Reinitialize the model so it is ready for the next iteration
            mll, gp = initialize_GPmodel(train_X, train_y)
            # Refit the GP model
            fit_gpytorch_model(mll)

        # Compute LOO error after adding N_train_X_to_add new samples
        errLOO_IMSE[imse_iter_idx] = compute_errLOO(train_X, train_y)
        N_train_X_IMSE[imse_iter_idx] = train_X.shape[0]

# --------------------------------------------------------------------------- #
print(r'Active learning with the mean square error criterion (MSE)')
# We use the MeanSquareError method inherited from UpperConfidenceBound
# --------------------------------------------------------------------------- #
# Store errLOO errors
errLOO_MSE = np.empty(len(N_train_X_list))
# Number of added training sets
N_train_X_MSE = np.empty_like(errLOO_MSE)

# Initial training set sampling from LHS
train_X = torch.tensor(sampler_list['LHS'](N_train_X_init))
train_y = genz_func(train_X)
# Initialize and fit the GP model
mll, gp = initialize_GPmodel(train_X, train_y)
fit_gpytorch_model(mll)
# Bounds on train_X
bounds = torch.tensor(
    [[0.] * train_X.shape[1], [1.] * train_X.shape[1]], dtype=dtype)

# Active learning with MSE
for mse_iter_idx in range(len(N_train_X_list)):

    if mse_iter_idx == 0:
        print(r"Compute errLOO for a pilot GP model...")
        # Compute LOO error when we add N_train_X_init number of samples
        errLOO_MSE[mse_iter_idx] = compute_errLOO(train_X, train_y)
        N_train_X_MSE[mse_iter_idx] = train_X.shape[0]
    else:
        print(r"Actively learning {} new training data in total...".format(
            mse_iter_idx * N_train_X_to_add))

        # We sequentially add N_train_X_to_add new training data
        for iter_idx in range(N_train_X_to_add):
            # Mean square error
            MSE = MeanSquareError(gp)
            # Optimize the acquisition function
            train_X_add, acq_value = optimize_acqf(
                MSE, bounds=bounds, q=1, num_restarts=NUM_RESTARTS,
                raw_samples=RAW_SAMPLES
            )

            # Compute the new label
            train_y_add = genz_func(train_X_add)

            # Concatenate the training data and the labels to be added
            train_X = torch.cat([train_X, train_X_add], axis=0)
            train_y = torch.cat([train_y, train_y_add], axis=0)

            # Reinitialize the model so it is ready for the next iteration
            mll, gp = initialize_GPmodel(train_X, train_y)
            # Refit the GP model
            fit_gpytorch_model(mll)

        # Compute LOO error after adding N_train_X_to_add new samples
        errLOO_MSE[mse_iter_idx] = compute_errLOO(train_X, train_y)
        N_train_X_MSE[mse_iter_idx] = train_X.shape[0]

# --------------------------------------------------------------------------- #
print(r'Latin-Hypercube sampling')
# --------------------------------------------------------------------------- #
# Store errLOO errors
errLOO_LHS = np.empty(len(N_train_X_list))
# Number of added training sets
N_train_X_LHS = np.empty_like(errLOO_MSE)

for iter_idx, N_train_X in enumerate(N_train_X_list):
    print("Train a GP with {} samples from LHS".format(N_train_X))
    train_X = torch.tensor(sampler_list['LHS'](N_train_X))
    train_y = genz_func(train_X)
    errLOO_LHS[iter_idx] = compute_errLOO(train_X, train_y)
    N_train_X_LHS[iter_idx] = train_X.shape[0]

# --------------------------------------------------------------------------- #
# Plot errLOO for IMSE and MSE
# --------------------------------------------------------------------------- #
fig, ax = plt.subplots(figsize=(9, 6))

ax.plot(N_train_X_IMSE, errLOO_IMSE, color='red', marker='o', label='IMSE')
ax.plot(N_train_X_MSE, errLOO_MSE, color='blue', marker='o', label='MSE')
ax.plot(N_train_X_LHS, errLOO_LHS, color='k', linestyle='--',  marker='o',
        label='LHS')
ax.set_xlim([N_train_X_list.min(), N_train_X_list.max()])
ax.set_xlabel(r'Number of design points')
ax.set_ylabel(r'$\epsilon_{\mathrm{LOO}}$')
ax.set_yscale('log')
ax.legend(loc='upper right')
# plt.show()
plt.savefig('errLOO_mse_imse_LHS.png', bbox_inches='tight')
plt.close()

https://user-images.githubusercontent.com/11461584/213793297-0c78a964-4e5b-46d0-9c14-c859f51917ef.png

The leave-one-out error decreases as the number of design points increases, and the IMSE criterion outperforms MSE.

I also compared against Latin hypercube sampling. LHS sometimes outperforms the IMSE and MSE criteria, but it is very unstable. With IMSE and MSE we observe a stably decreasing error metric, which is another asset of active learning.

eytan commented 1 year ago

My intuition is that it is hard to beat space-filling designs for low-dimensional inputs, and active learning is going to contribute more value once you start hitting > 4 dimensions.

Balandat commented 1 year ago

Looks like you're just running a single loop here? How much variance is there in these results? Have you tried doing this for a number of replicates with different seeds and looking at the distribution of the results? It may be that you're just looking at noise (or the realization of a particular seed) right now.

takafusui commented 1 year ago

Thank you very much for your comments.

@eytan it is a very good point indeed. Do you have in your mind a popular test function with which we can measure a benchmark of each method in the active learning community?

@Balandat yes, I fixed the random seeds and ran only one loop to check how the error metric decreased. So you suggest that I make a loop where we do the same thing with different realizations of the random variables and check the robustness?

Balandat commented 1 year ago

So you suggested that I make a loop where we do the same thing with different realizations of random variables and check the robustness?

Yes

takafusui commented 1 year ago

I would like to know general guidelines for choosing some of the parameters when using BoTorch, especially for relatively high-dimensional (d > 7) input spaces. I have observed that, for instance, when I set a larger num_restarts, I run into memory issues (I am using 16GB RAM).

optimize_acqf has two parameters, namely num_restarts and raw_samples. My understanding is that increasing raw_samples gives better initial conditions and improves the subsequent optimization of the acquisition function, while increasing num_restarts optimizes the acquisition function from num_restarts different initial conditions, giving a better chance of optimizing a non-convex acquisition function. How should I balance num_restarts and raw_samples?

'draw_sobol_samples' requires n, q, and batch_shape. I naively think we should assign a big n as we need to integrate variance for a high-dimensional input space. What is q (q-batch)? What should we do with batch_shape? I usually set batch_shape=1.

qNegIntegratedPosteriorVariance requires mc_points whose shape is batch_shape x N x d. mc_points should be qMC, so I use sample points from draw_sobol_samples. I guess the choice of n and batch_shape in draw_sobol_samples might also be crucial here.

Thank you again for your help. I understand that these choices depend on the GP model at hand, but I would like to know how to choose them so as to avoid memory errors...

saitcakmak commented 1 year ago

In Ax, we use num_restarts=20, raw_samples=1024 regardless of the input dimension. It is not a bad idea to increase these as the input dimension increases, but it comes at a computational cost. To reduce the memory usage, you can pass options={"batch_limit": <some number smaller than num_restarts>, "init_batch_limit": <some number smaller than raw_samples>} to optimize_acqf. This will split the acquisition function evaluation into mini batches, reducing the peak memory usage proportionally. If you're using an EI based acquisition function, passing in options={"sample_around_best": True} might also help improve the optimization performance. This will sample additional raw samples around the best observed point.
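For instance (a sketch; acqf, bounds, and the specific numbers are illustrative assumptions):

candidate, acq_value = optimize_acqf(
    acq_function=acqf,
    bounds=bounds,
    q=1,
    num_restarts=20,
    raw_samples=1024,
    # Evaluate restarts / raw samples in mini-batches to cap peak memory.
    options={"batch_limit": 5, "init_batch_limit": 128},
)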

For larger dimensional inputs, we use models with sparsity inducing priors, which lead to much more accurate models. SaasFullyBayesianSingleTaskGP is available in BoTorch. Note that the fully Bayesian inference used by this model makes it much more expensive to train.
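A rough sketch of swapping it in (train_X and train_Y are assumptions):

from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP

# Sparsity-inducing (SAAS) priors on the inverse lengthscales; fitted via
# NUTS, which is accurate in high dimensions but much slower to train.
saas_gp = SaasFullyBayesianSingleTaskGP(train_X, train_Y)
fit_fully_bayesian_model_nuts(saas_gp)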

'draw_sobol_samples' requires n, q, and batch_shape. I naively think we should assign a big n as we need to integrate variance for a high-dimensional input space. What is q (q-batch)? What should we do with batch_shape? I usually set batch_shape=1.

You can ignore the batch_shape unless you need samples of batch x n x q x d. Under the hood, this will sample batch_shape * n samples of q * d dimension and reshape those. The q-batch is relevant, e.g., when optimizing an acquisition function for parallel evaluations (optimize_acqf with q>1). Otherwise, you should always use n to specify the number of samples you need. For Sobol, n should be a power of 2 for best results.

qNegIntegratedPosteriorVariance requires mc_points whose shape is batch_shape x N x d. mc_points should be qMC, so I use sample points from draw_sobol_samples. I guess the choice of n and batch_shape in draw_sobol_samples might also be crucial here.

I believe it is safe to just provide mc_points of shape n x d there. The same samples will be used for all Xs while evaluating the acquisition function, which is perfectly fine. Larger N would lead to more precise integration of the posterior variance across the search space. Since GP posterior variance is a relatively smooth function, I wouldn't expect pushing N larger to make a big difference.
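i.e., something like (a sketch; bounds is the 2 x d tensor of design bounds and gp a fitted model):

from botorch.utils.sampling import draw_sobol_samples

# n x d quasi-MC integration points; a power-of-2 n is best for Sobol.
mc_points = draw_sobol_samples(bounds=bounds, n=1024, q=1).squeeze(-2)
qnipv = qNegIntegratedPosteriorVariance(model=gp, mc_points=mc_points)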

saitcakmak commented 6 months ago

Closing as inactive & original issue resolved by https://github.com/pytorch/botorch/pull/2060