pytorch / botorch

Bayesian optimization in PyTorch
https://botorch.org/
MIT License

[Bug] `fixed_features` does not support negative indices #2602

Closed. slishak-PX closed this issue 3 weeks ago.

slishak-PX commented 3 weeks ago

🐛 Bug

If `fixed_features` contains a negative index, the initial conditions are not constructed with the correct reduced dimensionality.

To reproduce

Code snippet to reproduce

import torch
from botorch.acquisition import qLogExpectedImprovement
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.models.transforms import Normalize
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood

n_inputs = 4
n_outputs = 1
n_train = 256
n_test = 16
device = torch.device("cpu")

train_x = torch.rand(n_train, n_inputs, dtype=torch.float64, device=device)
train_y = torch.randn(n_train, n_outputs, dtype=torch.float64, device=device)

model = SingleTaskGP(train_x, train_y, input_transform=Normalize(n_inputs))

mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_mll(mll)

acqf = qLogExpectedImprovement(model, best_f=train_y.max())

bounds = torch.vstack([torch.zeros(1, n_inputs), torch.ones(1, n_inputs)])

candidates, value = optimize_acqf(
    acqf,
    bounds,
    q=2,
    num_restarts=4,
    raw_samples=8,
    fixed_features={-1: 0.5},
)

Stack trace/error message

Traceback (most recent call last):
  File "repro.py", line 27, in <module>
    candidates, value = optimize_acqf(
  File ".../lib/python3.10/site-packages/botorch/optim/optimize.py", line 547, in optimize_acqf
    return _optimize_acqf(opt_acqf_inputs)
  File ".../lib/python3.10/site-packages/botorch/optim/optimize.py", line 568, in _optimize_acqf
    return _optimize_acqf_batch(opt_inputs=opt_inputs)
  File ".../lib/python3.10/site-packages/botorch/optim/optimize.py", line 332, in _optimize_acqf_batch
    batch_candidates, batch_acq_values, ws = _optimize_batch_candidates()
  File ".../lib/python3.10/site-packages/botorch/optim/optimize.py", line 316, in _optimize_batch_candidates
    ) = opt_inputs.gen_candidates(
  File ".../lib/python3.10/site-packages/botorch/generation/gen.py", line 159, in gen_candidates_scipy
    clamped_candidates, batch_acquisition = gen_candidates_scipy(
  File ".../lib/python3.10/site-packages/botorch/generation/gen.py", line 251, in gen_candidates_scipy
    res = minimize_with_timeout(
  File ".../lib/python3.10/site-packages/botorch/optim/utils/timeout.py", line 83, in minimize_with_timeout
    return optimize.minimize(
  File ".../lib/python3.10/site-packages/scipy/optimize/_minimize.py", line 731, in minimize
    res = _minimize_lbfgsb(fun, x0, args, jac, bounds,
  File ".../lib/python3.10/site-packages/scipy/optimize/_lbfgsb_py.py", line 347, in _minimize_lbfgsb
    sf = _prepare_scalar_function(fun, x0, jac=jac, args=args, epsilon=eps,
  File ".../lib/python3.10/site-packages/scipy/optimize/_optimize.py", line 288, in _prepare_scalar_function
    sf = ScalarFunction(fun, x0, args, grad, hess,
  File ".../lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py", line 222, in __init__
    self._update_fun()
  File ".../lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py", line 294, in _update_fun
    fx = self._wrapped_fun(self.x)
  File ".../lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py", line 20, in wrapped
    fx = fun(np.copy(x), *args)
  File ".../lib/python3.10/site-packages/scipy/optimize/_optimize.py", line 79, in __call__
    self._compute_if_needed(x, *args)
  File ".../lib/python3.10/site-packages/scipy/optimize/_optimize.py", line 73, in _compute_if_needed
    fg = self.fun(x, *args)
  File ".../lib/python3.10/site-packages/botorch/generation/gen.py", line 208, in f_np_wrapper
    loss = f(X_fix).sum()
  File ".../lib/python3.10/site-packages/botorch/generation/gen.py", line 249, in f
    return -acquisition_function(x)
  File ".../lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File ".../lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File ".../lib/python3.10/site-packages/botorch/acquisition/fixed_feature.py", line 155, in forward
    X_full = self._construct_X_full(X)
  File ".../lib/python3.10/site-packages/botorch/acquisition/fixed_feature.py", line 192, in _construct_X_full
    raise ValueError(
ValueError: Feature dimension d' (4) of input must be d - d_f (3).
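The error comes from a shape check in `FixedFeatureAcquisitionFunction._construct_X_full`. That function expects an input containing only the free (non-fixed) dimensions, d' = d - d_f, and splices the fixed values back in. The sketch below is a simplified pure-Python model of that check, not BoTorch's implementation. `construct_x_full` is a hypothetical name, and it operates on a single point rather than a batched tensor.

```python
def construct_x_full(x_reduced, fixed_features, d):
    """Hypothetical simplification: splice fixed values back into one point.

    x_reduced: values for the free (non-fixed) dimensions, in order.
    fixed_features: maps a dimension index to the value it is fixed at.
    d: total number of input dimensions.
    """
    d_f = len(fixed_features)
    if len(x_reduced) != d - d_f:
        raise ValueError(
            f"Feature dimension d' ({len(x_reduced)}) of input must be "
            f"d - d_f ({d - d_f})."
        )
    free_values = iter(x_reduced)
    # Walk over all d dimensions, using the fixed value where one exists
    # and consuming the next free value otherwise.
    return [
        fixed_features[i] if i in fixed_features else next(free_values)
        for i in range(d)
    ]


# A non-negative key: 3 free values plus 1 fixed value give a full point.
print(construct_x_full([0.1, 0.2, 0.3], {3: 0.5}, d=4))  # [0.1, 0.2, 0.3, 0.5]

# Key -1 still counts toward d_f, so a 4-dimensional input fails the check,
# producing the same message as in the traceback above.
try:
    construct_x_full([0.1, 0.2, 0.3, 0.4], {-1: 0.5}, d=4)
except ValueError as err:
    print(err)  # Feature dimension d' (4) of input must be d - d_f (3).
```

The mismatch arises because `d_f` counts the key `-1`, but nothing upstream removed a dimension for it.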

Expected Behavior

A negative index should behave as it does in ordinary Python indexing (e.g. `-1` refers to the last dimension), or the docs should state clearly that negative indices are not allowed.


Additional context

The line linked below is one reason this fails; I don't yet know whether other parts of the code also assume the indices are non-negative.

https://github.com/pytorch/botorch/blob/66660e341b7dd0780feac4640f3709a8fd024206/botorch/generation/utils.py#L164
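Assume the linked line builds the mask of free dimensions with a membership test of the form `i not in fixed_features` over `range(d)` (check the linked source to confirm). Under that assumption, a negative key never equals any `i`, so no dimension is dropped from the initial conditions even though `d_f` still counts the key. The pure-Python sketch below uses hypothetical helpers `reduced_domain` and `drop_fixed` to show the effect on a single initial condition.

```python
def reduced_domain(d, fixed_features):
    # Assumed form of the check: True for dimensions that stay free,
    # False for dimensions listed in fixed_features.
    return [i not in fixed_features for i in range(d)]


def drop_fixed(x, fixed_features):
    # Keep only the entries of x whose dimension is free.
    mask = reduced_domain(len(x), fixed_features)
    return [xi for xi, keep in zip(x, mask) if keep]


x0 = [0.1, 0.2, 0.3, 0.4]  # one initial condition with d = 4

print(drop_fixed(x0, {3: 0.5}))   # [0.1, 0.2, 0.3]  -> d' = 3, as expected
print(drop_fixed(x0, {-1: 0.5}))  # [0.1, 0.2, 0.3, 0.4]  -> -1 never matches any i
```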

esantorella commented 3 weeks ago

Thanks for reporting. I opened #2603 to update the docstrings to state that negative indices are not allowed, and to raise an exception if they are provided.
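Here is a minimal sketch of this kind of up-front check. It assumes the check simply rejects any negative key before optimization starts. `validate_fixed_features` is a hypothetical name, and the actual change in #2603 may differ in wording and placement.

```python
def validate_fixed_features(fixed_features):
    # Hypothetical check: reject negative indices explicitly instead of
    # letting them fail later with a confusing dimension mismatch.
    if fixed_features is None:
        return
    negative = sorted(i for i in fixed_features if i < 0)
    if negative:
        raise ValueError(
            f"fixed_features must use non-negative indices; got {negative}."
        )


validate_fixed_features({3: 0.5})  # passes silently
try:
    validate_fixed_features({-1: 0.5})
except ValueError as err:
    print(err)  # fixed_features must use non-negative indices; got [-1].
```

Failing early like this surfaces the real cause at the call site rather than deep inside the SciPy optimizer loop.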

Balandat commented 3 weeks ago

I feel like it shouldn't be too hard to allow this by canonicalizing the indices to `index % num_features`? Maybe we can make a backlog task for this?
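The following sketches the suggested canonicalization. It assumes `fixed_features` is a dict keyed by dimension index and `d` is the number of input features. For positive `d`, Python's `%` maps `-1` to `d - 1`. However, `%` also silently wraps out-of-range indices (for example, `5 % 4 == 1`), so the sketch validates the range first. It also rejects keys that collide after canonicalization, such as `-1` and `3` when `d == 4`. `canonicalize_fixed_features` is a hypothetical name.

```python
def canonicalize_fixed_features(fixed_features, d):
    """Map each index in [-d, d) to its non-negative equivalent (hypothetical)."""
    canonical = {}
    for idx, value in fixed_features.items():
        # Without this check, idx % d would silently wrap e.g. 5 to 1.
        if not -d <= idx < d:
            raise IndexError(
                f"fixed_features index {idx} is out of range for d={d}."
            )
        new_idx = idx % d  # e.g. -1 % 4 == 3
        if new_idx in canonical:
            raise ValueError(
                f"Multiple fixed_features keys refer to dimension {new_idx}."
            )
        canonical[new_idx] = value
    return canonical


print(canonicalize_fixed_features({-1: 0.5}, d=4))          # {3: 0.5}
print(canonicalize_fixed_features({0: 0.1, -2: 0.2}, d=4))  # {0: 0.1, 2: 0.2}
```

Canonicalizing once at the entry point would let the downstream code keep assuming non-negative indices.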

esantorella commented 3 weeks ago

To keep things organized, I opened #2605 for that feature request. It would be a good task for a newcomer.