pytorch / botorch

Bayesian optimization in PyTorch
https://botorch.org/
MIT License

[Bug] OrthogonalAdditiveKernel doesn't work with input transforms because they generate x values outside the unit hypercube #2270

Open esantorella opened 5 months ago

esantorella commented 5 months ago

🐛 Bug

OrthogonalAdditiveKernel raises a ValueError if provided x values outside the unit hypercube, [0, 1]^d. Unfortunately, when combining this kernel with basic BoTorch functionality, it is hard to avoid passing such values. For example, if the search space is [0, 1], a model is trained on points ranging from 0.25 to 0.75, and a Normalize input transform without explicit bounds is used, then the transform learns its bounds from the training data, so 0 and 1 will transform to -0.5 and 1.5 and lie outside the hypercube.
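The arithmetic behind this can be checked by hand. The sketch below assumes that a Normalize transform without explicit bounds does plain min-max scaling with the minimum and maximum of the training inputs; the helper `min_max_normalize` is hypothetical and only stands in for that behavior.

```python
def min_max_normalize(x, lo, hi):
    """Map x from [lo, hi] to [0, 1]; points outside [lo, hi] land outside [0, 1]."""
    return (x - lo) / (hi - lo)


train_lo, train_hi = 0.25, 0.75  # range covered by the training inputs
search_lo, search_hi = 0.0, 1.0  # bounds of the search space

# Scaling with bounds taken from the training data pushes the edges of the
# search space outside the unit interval.
learned = [min_max_normalize(x, train_lo, train_hi) for x in (search_lo, search_hi)]

# Scaling with the search space bounds keeps them inside the unit interval.
fixed = [min_max_normalize(x, search_lo, search_hi) for x in (search_lo, search_hi)]

print(learned)  # [-0.5, 1.5]
print(fixed)    # [0.0, 1.0]
```

Any candidate point between the training range and the edge of the search space is affected in the same way, which is why both a direct posterior call and acquisition optimization can trigger the check.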

To reproduce

Code snippet to reproduce

from botorch.models.kernels.orthogonal_additive_kernel import OrthogonalAdditiveKernel

from botorch.models.gp_regression import SingleTaskGP, get_matern_kernel_with_gamma_prior
from botorch.fit import fit_gpytorch_mll
from gpytorch.mlls import ExactMarginalLogLikelihood
from botorch.models.transforms.outcome import Standardize
from botorch.models.transforms.input import Normalize

import torch
from botorch.acquisition.logei import qLogNoisyExpectedImprovement

train_X = torch.tensor([[0.3], [0.7]], dtype=torch.float64)
train_Y = torch.tensor([[0.3], [0.7]], dtype=torch.float64)

kernel = OrthogonalAdditiveKernel(
    base_kernel=get_matern_kernel_with_gamma_prior(
        ard_num_dims=None,
    ),
    dim=1,
    dtype=torch.double,
)

model = SingleTaskGP(
    train_X=train_X,
    train_Y=train_Y,
    input_transform=Normalize(d=1),
    outcome_transform=Standardize(m=1),
    covar_module=kernel
)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
_ = fit_gpytorch_mll(mll)

model.posterior(train_X)  # works

# errors
model.posterior(torch.tensor([[0.2], [0.8]], dtype=torch.float64))

Stack trace/error message

Traceback (most recent call last):
  File "/Users/lizs/oak_issue.py", line 36, in <module>
    model.posterior(torch.tensor([[0.2], [0.8]]), dtype=torch.float64)
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/models/gpytorch.py", line 388, in posterior
    mvn = self(X)
          ^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/models/exact_gp.py", line 333, in __call__
    ) = self.prediction_strategy.exact_prediction(full_mean, full_covar)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/models/exact_prediction_strategies.py", line 281, in exact_prediction
    test_covar = joint_covar[..., self.num_train :, :].to_dense()
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/utils/memoize.py", line 59, in g
    return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 410, in to_dense
    return self.evaluate_kernel().to_dense()
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/utils/memoize.py", line 59, in g
    return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 25, in wrapped
    output = method(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 355, in evaluate_kernel
    res = self.kernel(
          ^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/kernels/kernel.py", line 530, in __call__
    super(Kernel, self).__call__(x1_, x2_, last_dim_is_batch=last_dim_is_batch, **params)
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/module.py", line 31, in __call__
    outputs = self.forward(*inputs, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/models/kernels/orthogonal_additive_kernel.py", line 163, in forward
    K_ortho = self._orthogonal_base_kernels(x1, x2)  # batch_shape x d x n1 x n2
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/models/kernels/orthogonal_additive_kernel.py", line 202, in _orthogonal_base_kernels
    _check_hypercube(x1, "x1")
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/models/kernels/orthogonal_additive_kernel.py", line 270, in _check_hypercube
    raise ValueError(name + " is not in hypercube [0, 1]^d.")
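ValueError: x1 is not in hypercube [0, 1]^d.

The same error occurs when optimizing an acquisition function over the search space:

from botorch.optim import optimize_acqf

acqf = qLogNoisyExpectedImprovement(
    model,
    X_baseline=train_X,
)
optimize_acqf(
    acqf,
    bounds=torch.tensor([[0.0], [1.0]], dtype=torch.float64),
    q=1,
    num_restarts=16,
    raw_samples=32,
)

Traceback (most recent call last):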
  File "/Users/lizs/oak_issue.py", line 43, in <module>
    optimize_acqf(
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/optim/optimize.py", line 563, in optimize_acqf
    return _optimize_acqf(opt_acqf_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/optim/optimize.py", line 584, in _optimize_acqf
    return _optimize_acqf_batch(opt_inputs=opt_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/optim/optimize.py", line 274, in _optimize_acqf_batch
    batch_initial_conditions = opt_inputs.get_ic_generator()(
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/optim/initializers.py", line 417, in gen_batch_initial_conditions
    Y_rnd_curr = acq_function(
                 ^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/utils/transforms.py", line 305, in decorated
    return method(cls, X, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/utils/transforms.py", line 259, in decorated
    output = method(acqf, X, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/acquisition/monte_carlo.py", line 274, in forward
    non_reduced_acqval = self._non_reduced_forward(X=X)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/acquisition/monte_carlo.py", line 287, in _non_reduced_forward
    samples, obj = self._get_samples_and_objectives(X)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/acquisition/logei.py", line 465, in _get_samples_and_objectives
    posterior = self.model.posterior(
                ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/models/gpytorch.py", line 388, in posterior
    mvn = self(X)
          ^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/models/exact_gp.py", line 333, in __call__
    ) = self.prediction_strategy.exact_prediction(full_mean, full_covar)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/models/exact_prediction_strategies.py", line 281, in exact_prediction
    test_covar = joint_covar[..., self.num_train :, :].to_dense()
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/utils/memoize.py", line 59, in g
    return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 410, in to_dense
    return self.evaluate_kernel().to_dense()
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/utils/memoize.py", line 59, in g
    return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 25, in wrapped
    output = method(self, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 355, in evaluate_kernel
    res = self.kernel(
          ^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/kernels/kernel.py", line 530, in __call__
    super(Kernel, self).__call__(x1_, x2_, last_dim_is_batch=last_dim_is_batch, **params)
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/gpytorch/module.py", line 31, in __call__
    outputs = self.forward(*inputs, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/models/kernels/orthogonal_additive_kernel.py", line 163, in forward
    K_ortho = self._orthogonal_base_kernels(x1, x2)  # batch_shape x d x n1 x n2
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/models/kernels/orthogonal_additive_kernel.py", line 202, in _orthogonal_base_kernels
    _check_hypercube(x1, "x1")
  File "/opt/miniconda3/envs/ax/lib/python3.12/site-packages/botorch/models/kernels/orthogonal_additive_kernel.py", line 270, in _check_hypercube
    raise ValueError(name + " is not in hypercube [0, 1]^d.")
ValueError: x1 is not in hypercube [0, 1]^d.

Expected Behavior

This should not error. Currently, the only way to use the OAK is to manually normalize the search space (rather than the training data) to [0, 1], which is not documented or well-supported.

Balandat commented 5 months ago

@SebastianAment is the requirement that all inputs are contained in the unit cube critical for this kernel?

SebastianAment commented 5 months ago

Thanks for raising this. I added this check to ensure that the search space bounds are passed to Normalize; otherwise, the orthogonality condition can only be guaranteed on the training set. In the example above, passing Normalize(d=1, bounds=bounds), where bounds holds the search space bounds, would work. We can add this to the error message.

> @SebastianAment is the requirement that all inputs are contained in the unit cube critical for this kernel?

In principle we could also open the kernel up to be evaluated outside of the orthogonality domain, but I think it's better to error out in this case, at least by default, as orthogonality is the defining property that users would expect from the kernel.

Balandat commented 2 weeks ago

cc @hvarfner