Hi @matthewcarbone - you found one of the secret features that are not well supported at this point. We hope to improve this and add more validation and documentation but haven't had the bandwidth to do that yet.
As to what `ic_generator` is: acquisition function optimization in botorch (in `optimize_acqf`) uses multi-start gradient ascent from multiple initial conditions in order to deal with nonconvexities / local maxima in the response surface of acquisition functions. You can either provide those ICs (via the `batch_initial_conditions` arg) or let botorch choose a default initialization heuristic (in botorch this is based on doing a batch evaluation of the acquisition function at a large number of random points and then performing Boltzmann sampling on the normalized values, see [`gen_batch_initial_conditions`](https://github.com/pytorch/botorch/blob/main/botorch/optim/initializers.py#L50)).
In the presence of nonlinear constraints, sampling from the feasible set becomes a very hard problem to solve in a general fashion, so this is currently unsupported in the default initializer. You have two options: (i) pass a set of initial conditions to perform multi-start ascent from as `batch_initial_conditions`, or (ii) write your own initializer callable with the same signature as `gen_batch_initial_conditions` that implements a strategy for generating such initial conditions and pass it in via the `ic_generator` kwarg to `optimize_acqf`.
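For concreteness, option (i) might look roughly like the sketch below. The names are placeholders: `acq_func`, `bounds`, and `my_constraint` are assumed to be defined elsewhere, and `my_feasible_points` is a `num_restarts x q x d` tensor of feasible points you construct yourself (e.g. via rejection sampling).

```python
from botorch.optim import optimize_acqf

# Option (i), sketched: hand optimize_acqf an explicit set of feasible starting
# points instead of relying on the default initializer.
candidates, acq_value = optimize_acqf(
    acq_function=acq_func,                             # assumed defined elsewhere
    bounds=bounds,                                     # 2 x d tensor of box bounds
    q=1,
    num_restarts=my_feasible_points.shape[0],
    nonlinear_inequality_constraints=[my_constraint],  # list of constraint callables
    batch_initial_conditions=my_feasible_points,       # num_restarts x q x d, feasible
)
```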
cc @dme65
> Hi @matthewcarbone - you found one of the secret features that are not well supported at this point. We hope to improve this and add more validation and documentation but haven't had the bandwidth to do that yet.
@Balandat no worries! I get it.
So just in summary: we can get away with using a "sensible default" for the initial conditions in `optimize_acqf`, but this is left open to the user to provide explicitly when the constraints are non-linear?
Ultimately I just want to try and get away with a sensible default here. It seems like `gen_batch_initial_conditions` takes many of the same arguments as `optimize_acqf`. Can you show me a minimal example of how I would call `optimize_acqf` with a "sensible default" argument for `ic_generator`? Sorry, I know this is annoying, but I probably just need to see one example, then I can explore from there.
Thanks!
Also quick followup: as an alternative, we could in principle consider using the penalty function to implement this nonlinear constraint as a "soft constraint", but that does not seem to be the "right way" to do this. Is this correct? I'm not even sure how I would implement that in principle... Send the acquisition function to `-inf` when the non-linear constraint is not satisfied? Unsure if that will lead to numerical instability issues or not... would love to hear your comments!
> So just in summary: we can get away with using a "sensible default" for the initial conditions in `optimize_acqf`, but this is left open to the user to provide explicitly when the constraints are non-linear?
Correct.
> Can you show me a minimal example of how I would call `optimize_acqf` with a "sensible default" argument for `ic_generator`?
Basically you'll have to implement something like this:

def gen_batch_initial_conditions_nonlinear(
    acq_function: AcquisitionFunction,
    bounds: Tensor,
    q: int,
    num_restarts: int,
    raw_samples: int,
    fixed_features: Optional[Dict[int, float]] = None,
    options: Optional[Dict[str, Union[bool, float, int]]] = None,
    inequality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
    equality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
) -> Tensor:
    # your code for generating a `num_restarts x q x d` tensor of
    # feasible initial conditions (s.t. bounds and nonlinear constraints) goes here.
    ...
and then call `optimize_acqf` as

optimize_acqf(
    acq_function=...,
    ...,
    ic_generator=gen_batch_initial_conditions_nonlinear,
)
You can look at `gen_batch_initial_conditions` for inspiration. One immediate thing you could try to do is rejection sampling - generate points randomly in the feasible polytope (bounds + linear constraints) and then reject points not satisfying the constraints. You'd stick this somewhere around here: https://github.com/pytorch/botorch/blob/76062a65ac86df3c91560f6d243b02609170585c/botorch/optim/initializers.py#L185. The main issue with this is of course that if the volume of the feasible set is small relative to the constraint polytope you may have to sample and reject for a very long time.
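For what it's worth, a bare-bones version of that rejection idea over the box bounds alone (ignoring any linear constraints) could look like the sketch below; `is_feasible` is a hypothetical callable returning a boolean mask over the rows of `X`, and `sample_feasible` is just a name made up for this sketch.

```python
import torch

def sample_feasible(is_feasible, bounds, n, max_batches=1000):
    """Naive rejection sampling: draw uniformly in the box and keep feasible rows."""
    d = bounds.shape[-1]
    kept, n_kept = [], 0
    for _ in range(max_batches):
        X = bounds[0] + (bounds[1] - bounds[0]) * torch.rand(n, d)
        X = X[is_feasible(X)]        # keep only rows the constraint accepts
        kept.append(X)
        n_kept += X.shape[0]
        if n_kept >= n:
            return torch.cat(kept, dim=0)[:n]
    raise RuntimeError("Feasible region appears too small for naive rejection sampling.")
```

The resulting points could then be reshaped to `num_restarts x q x d` and passed as `batch_initial_conditions`, or used inside an `ic_generator`.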
> Also quick followup: as an alternative, we could in principle consider using the penalty function to implement this nonlinear constraint as a "soft constraint", but that does not seem to be the "right way" to do this. Is this correct?
That would also be a possibility, albeit not a great one. You could add a penalty term; `inf` is going to cause issues, but you could use a kind of barrier function approach. This could be hooked up via the `PenalizedAcquisitionFunction` abstraction. The issue with that is you'll have to think hard about how to do this optimization properly, and probably do some kind of homotopy continuation on the barrier function parameters that makes them progressively steeper as you go, in order to get both precise approximations and reasonable convergence and stability. So it's better to have the optimizer deal with the constraint directly.
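Purely to illustrate the barrier idea (this is not a recommendation, and it deliberately does not use the `PenalizedAcquisitionFunction` abstraction), a hand-rolled log-barrier wrapper might look like the sketch below, where `constraint_value` is a hypothetical callable returning values >= 0 on the feasible set and `t` is a barrier parameter you would have to increase over successive rounds:

```python
import torch

def make_barrier_acqf(acq_function, constraint_value, t):
    """Illustrative log-barrier wrapper; larger t means a steeper (harder) barrier."""
    def penalized(X):
        c = constraint_value(X.squeeze(-2))  # assumes q = 1, i.e. X is batch x 1 x d
        # log(c) -> -inf as the boundary is approached; clamping avoids NaNs for c <= 0
        return acq_function(X) + (1.0 / t) * torch.log(c.clamp_min(1e-12))
    return penalized
```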
@Balandat ok thanks for the feedback! I agree with you about the soft constraints, seems worth it to pay the price of rejection sampling.
To build intuition, I'm looking for the simplest possible thing I can do. It would seem that this would be something like (this is Python pseudo code and not expected to work out of the box!):
def gen_batch_initial_conditions_nonlinear(non_linear_constraints, **kwargs):
    total_points = np.array([])
    right_shape = ...  # the expected final shape
    while True:
        points = gen_batch_initial_conditions(**kwargs)
        points = rejection_sampler(non_linear_constraints, points)
        total_points = np.concatenate([total_points, points])
        if total_points.shape == right_shape:  # i.e. we have accumulated enough points
            break
    return total_points
where basically I can just call the initial conditions method, apply the non-linear constraint and just continue to sample until I build enough of a set of points to be consistent with the provided expected shape.
Does this make sense?
@Balandat I've taken a shot at this but I'd like your feedback. Using simple rejection sampling, this seems to work. It'd be helpful to have your eyes on it because I'm not sure if I'm messing with/biasing the sampling procedure somehow by doing all of the reshaping.
Constraint (slightly different than before):
def constraint(x):
    return torch.abs(x[..., 0] - x[..., 1]) >= 0.5
Initializer:
def gen_batch_initial_conditions_nonlinear(
    acq_function,
    bounds,
    q,
    num_restarts,
    raw_samples,
    **kwargs
):
    # your code for generating a `num_restarts x q x d` tensor of
    # feasible initial conditions (s.t. bounds and nonlinear constraints) goes here.
    need = num_restarts * q
    all_points = torch.tensor([])
    ndims = bounds.shape[1]
    while all_points.shape[0] <= need:
        # Using this as the initial guess seemed reasonable...
        res = gen_batch_initial_conditions(
            acq_function=acq_function,
            bounds=bounds,
            q=q,
            num_restarts=num_restarts,
            raw_samples=raw_samples,
            **kwargs
        )
        res = res.reshape(-1, ndims)
        where = torch.where(constraint(res))[0]
        all_points = torch.cat([all_points, res[where, :]], axis=0)
    return all_points[:need, :].reshape(num_restarts, q, ndims)
Result looks reasonable (flattening the first two axes):
Any thoughts? Really appreciate your help. Thanks!
I think code-wise and re-shape-wise this looks reasonable to me. Choosing `gen_batch_initial_conditions` as the initial guess inside the rejection sampling loop, though, could be problematic. Suppose the acquisition function over the full domain has a clear and isolated peak in the region that becomes infeasible under the nonlinear constraint. Then `gen_batch_initial_conditions` is likely to only ever generate infeasible suggestions and you'll loop indefinitely.
I would recommend first doing rejection sampling on the constraint, and then using those points inside of `gen_batch_initial_conditions` instead of the random sampling done there (we can factor out that logic from `gen_batch_initial_conditions` into a helper function if that can be reused elsewhere).
@Balandat ok so if I understand correctly, you're basically saying I should rewrite the interior logic of `gen_batch_initial_conditions` to do the rejection sampling first on the constraint? I guess I'll need to take a close look at how that function works. Thanks!
Correct! Essentially you'll need to swap this part out for your rejection sampling: https://github.com/pytorch/botorch/blob/main/botorch/optim/initializers.py#L142-L166; the rest of the logic in that function is applying some heuristics to pick the most promising of these points. You could keep the `X_rnd` and just apply rejection sampling according to your constraints on that (somewhere around here: https://github.com/pytorch/botorch/blob/main/botorch/optim/initializers.py#LL167C12-L167C12) and leave the rest mostly as is. We should probably just factor out this sample generation into a helper function that you can swap out without having to rewrite the whole thing.
@Balandat ok I think I understand. Here's my plan. First, I will host my version of this function (with copyright attribution obviously, as per the MIT license terms) or submodule a `botorch` fork. I'll add an `elif` statement akin to this https://github.com/pytorch/botorch/blob/cbd5002fdca43411211f98392043bc3c46872153/botorch/optim/initializers.py#L154-L166 specifically to catch when there are non-linear constraints present. Then, within that branch, I'll loop over the contents of the current `else` block, i.e. calling `get_polytope_samples` (perhaps incrementing the seed by 1 each time to get new samples?) until I have enough samples left after applying the non-linear constraints/rejection to the candidates. Does this make sense? I'll post/link it when I'm done.
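Roughly, I'm imagining something like the following sketch (`_feasible_raw_samples` is just a name I made up for this, not a botorch function, and it assumes my `nonlinear_constraint` returns a boolean feasibility mask):

```python
import torch
from botorch.utils.sampling import get_polytope_samples

def _feasible_raw_samples(bounds, n, nonlinear_constraint, inequality_constraints=None,
                          equality_constraints=None, seed=0):
    """Draw polytope samples with a fresh seed each round, rejecting infeasible points,
    until n feasible raw samples have accumulated."""
    kept, n_kept = [], 0
    while n_kept < n:
        X = get_polytope_samples(
            n=n,
            bounds=bounds,
            inequality_constraints=inequality_constraints,
            equality_constraints=equality_constraints,
            seed=seed,
        )
        X = X[nonlinear_constraint(X)]  # keep only rows where the mask is True
        kept.append(X)
        n_kept += X.shape[0]
        seed += 1  # new seed each round so we don't redraw the same samples
    return torch.cat(kept, dim=0)[:n]
```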
> We should probably just factor out this sample generation into a helper function that you can swap out without having to rewrite the whole thing.
This would be great.
I mentioned in another issue I would love to contribute to `botorch`, but I'm wary of "signing" the Meta terms to contribute and how that could possibly conflict with my position.
That makes sense, looking forward to learning what you find!
@Balandat while we discuss the whole CLA stuff over email, please find the modified code which does what I think it should do here. I use the helper function on line 271 here to return a function just like `gen_batch_initial_conditions`, but with the non-linear constraint pre-programmed in. That way, it has an identical signature to the usual function.
Unfortunately, this does not seem to work... I have set up an example notebook for you here (it says pure botorch in the title but you'll need to clone my repo to use some helper functions) so you can see exactly what I mean. In the first example without the non-linear constraints, we have a simple active learning (almost, it's just UCB with $\beta=100$) loop which works just fine.
In the second example, I make only slight modifications to the working loop and try to add in the constraints. However, I am met with this nasty error that I cannot seem to figure out:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Input In [62], in <cell line: 6>()
23 # Ask
24 acq_function = UpperConfidenceBound(model, beta=100.0)
---> 25 new_points, acq_value = optimize_acqf(
26 acq_function,
27 bounds=torch.tensor(bounds).T,
28 q=1,
29 num_restarts=5,
30 raw_samples=20,
31 nonlinear_inequality_constraints=[constraint],
32 ic_generator=gen_batch_initial_conditions_nonlinear,
33 )
35 # Get the current observation. Here, `truth` will have to be implemented in a real experiment!
36 current_obs = truth(new_points).reshape(-1, 1)
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/optim/optimize.py:288, in optimize_acqf(acq_function, bounds, q, num_restarts, raw_samples, options, inequality_constraints, equality_constraints, nonlinear_inequality_constraints, fixed_features, post_processing_func, batch_initial_conditions, return_best_only, sequential, **kwargs)
285 batch_acq_values = torch.cat(batch_acq_values_list)
286 return batch_candidates, batch_acq_values, opt_warnings
--> 288 batch_candidates, batch_acq_values, ws = _optimize_batch_candidates()
290 optimization_warning_raised = any(
291 (issubclass(w.category, OptimizationWarning) for w in ws)
292 )
293 if optimization_warning_raised:
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/optim/optimize.py:276, in optimize_acqf.<locals>._optimize_batch_candidates()
274 with warnings.catch_warnings(record=True) as ws:
275 warnings.simplefilter("always", category=OptimizationWarning)
--> 276 batch_candidates_curr, batch_acq_values_curr = gen_candidates_scipy(
277 initial_conditions=batched_ics_, **scipy_kws
278 )
279 opt_warnings += ws
280 batch_candidates_list.append(batch_candidates_curr)
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/generation/gen.py:203, in gen_candidates_scipy(initial_conditions, acquisition_function, lower_bounds, upper_bounds, inequality_constraints, equality_constraints, nonlinear_inequality_constraints, options, fixed_features)
200 def f(x):
201 return -acquisition_function(x)
--> 203 res = minimize(
204 fun=f_np_wrapper,
205 args=(f,),
206 x0=x0,
207 method=options.get("method", "SLSQP" if constraints else "L-BFGS-B"),
208 jac=True,
209 bounds=bounds,
210 constraints=constraints,
211 callback=options.get("callback", None),
212 options={k: v for k, v in options.items() if k not in ["method", "callback"]},
213 )
215 if "success" not in res.keys() or "status" not in res.keys():
216 with warnings.catch_warnings():
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_minimize.py:701, in minimize(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options)
698 res = _minimize_cobyla(fun, x0, args, constraints, callback=callback,
699 **options)
700 elif meth == 'slsqp':
--> 701 res = _minimize_slsqp(fun, x0, args, jac, bounds,
702 constraints, callback=callback, **options)
703 elif meth == 'trust-constr':
704 res = _minimize_trustregion_constr(fun, x0, args, jac, hess, hessp,
705 bounds, constraints,
706 callback=callback, **options)
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_slsqp_py.py:329, in _minimize_slsqp(func, x0, args, jac, bounds, constraints, maxiter, ftol, iprint, disp, eps, callback, finite_diff_rel_step, **unknown_options)
325 # Set the parameters that SLSQP will need
326 # meq, mieq: number of equality and inequality constraints
327 meq = sum(map(len, [atleast_1d(c['fun'](x, *c['args']))
328 for c in cons['eq']]))
--> 329 mieq = sum(map(len, [atleast_1d(c['fun'](x, *c['args']))
330 for c in cons['ineq']]))
331 # m = The total number of constraints
332 m = meq + mieq
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_slsqp_py.py:329, in <listcomp>(.0)
325 # Set the parameters that SLSQP will need
326 # meq, mieq: number of equality and inequality constraints
327 meq = sum(map(len, [atleast_1d(c['fun'](x, *c['args']))
328 for c in cons['eq']]))
--> 329 mieq = sum(map(len, [atleast_1d(c['fun'](x, *c['args']))
330 for c in cons['ineq']]))
331 # m = The total number of constraints
332 m = meq + mieq
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/optim/parameter_constraints.py:359, in _make_f_and_grad_nonlinear_inequality_constraints.<locals>.f_obj(X)
357 if X_c is None or not np.array_equal(X_c, X):
358 cache["X"] = X.copy()
--> 359 cache["obj"], cache["grad"] = f_obj_and_grad(X)
360 return cache["obj"]
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/optim/parameter_constraints.py:350, in _make_f_and_grad_nonlinear_inequality_constraints.<locals>.f_obj_and_grad(x)
349 def f_obj_and_grad(x):
--> 350 obj, grad = f_np_wrapper(x, f=nlc)
351 return obj, grad
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/generation/gen.py:174, in gen_candidates_scipy.<locals>.f_np_wrapper(x, f)
172 loss = f(X_fix).sum()
173 # compute gradient w.r.t. the inputs (does not accumulate in leaves)
--> 174 gradf = _arrayify(torch.autograd.grad(loss, X)[0].contiguous().view(-1))
175 if np.isnan(gradf).any():
176 msg = (
177 f"{np.isnan(gradf).sum()} elements of the {x.size} element "
178 "gradient array `gradf` are NaN. This often indicates numerical issues."
179 )
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/torch/autograd/__init__.py:276, in grad(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused, is_grads_batched)
274 return _vmap_internals._vmap(vjp, 0, 0, allow_none_pass_through=True)(grad_outputs_)
275 else:
--> 276 return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
277 t_outputs, grad_outputs_, retain_graph, create_graph, t_inputs,
278 allow_unused, accumulate_grad=False)
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
Any advice/help would be appreciated. And as always, even though I am hosting this code in my own "EasyBO" software, since this module is still licensed by Meta, please do feel free to use whatever you'd like in your own updates if you find it useful.
@matthewcarbone I think you misunderstood how the nonlinear constraint function is supposed to be defined. Note that the constraint function is supposed to return a numerical value, where positive values indicate feasibility and negative values indicate infeasibility (this is consistent with the scipy convention). Your definition

def constraint(x):
    return x[..., 1] - torch.abs(x[..., 0]) >= 0.5

returns a binary value. This is also why you see the error, since you can't differentiate through what essentially is a step function. Instead you'll need to define your constraint as

def constraint(x):
    return x[..., 1] - torch.abs(x[..., 0]) - 0.5

so that `constraint(x) >= 0` implies feasibility (as documented here).
In addition, you'll need to make the corresponding change in your initial condition generator - instead of
where = torch.where(nonlinear_constraint(X_rnd))[0]
this line will need to become
where = torch.where(nonlinear_constraint(X_rnd) >= 0)[0]
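As a quick sanity check of the value-based convention (the points here are just my own illustrative examples, not from your notebook):

```python
X = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
constraint(X)
# tensor([ 0.3000, -1.1000])  ->  first point feasible (>= 0), second infeasible
```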
I was able to run this (the SLSQP optimizer is noticeably slower than the default L-BFGS-B), and you'll end up with something like this which looks reasonable:
1. Instead of the `get_batch_initial_conditions_nonlinear_function` helper, you could also just do the following:

   from functools import partial

   _gen_batch_initial_conditions_nonlinear = partial(
       gen_batch_initial_conditions_nonlinear,
       nonlinear_constraint=nonlinear_constraint,
   )
2. Looks like you're using the torch default data type (if not set otherwise this is `torch.float32`). I highly recommend using `torch.double` (aka `torch.float64`) since the linear systems that you need to solve in GP inference are often quite ill conditioned and so precision really matters to ensure the models are accurate and don't run into numerical issues. This is achieved simply by passing X/Y of the appropriate data type into the BoTorch model.
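For instance, a minimal sketch of the double-precision point above (using `SingleTaskGP` just as a representative model and made-up training data):

```python
import torch
from botorch.models import SingleTaskGP

# Construct the training data in double precision; the model then uses float64 throughout.
train_X = torch.rand(20, 2, dtype=torch.double)
train_Y = (train_X ** 2).sum(dim=-1, keepdim=True)  # toy objective, already double
model = SingleTaskGP(train_X, train_Y)
```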
@Balandat this is fantastic help, thank you! I clearly did misunderstand how to use the constraint function, I appreciate you clarifying.
Regarding the `torch` default precision, it's buried in the notebook, but I do set the default precision to `torch.float64`:

torch.set_default_dtype(torch.float64)
I have in the past encountered those conditioning issues you mention.
Let me give this a try and I'll let you know if there's any further issues. Thanks again for your help!
Closing as resolved but feel free to reopen with any further questions or learnings!
🐛 Bug
I am trying to use a non-linear inequality constraint with my Bayesian Optimization procedure. As a test, I'm simply starting with
$$ x_2 - |x_1| \geq 0. $$
Written as a constraint, this is just (I think):
However, when I try to do this using `optimize_acqf`, I get a RuntimeError.
I am not looking to reproduce anything specifically, but there is no documentation about what this `ic_generator` actually is. A Google search of `"ic_generator" botorch` reveals two results, and both are from source code. It is not obvious how to use this feature.
To reproduce
Simply use any call to `optimize_acqf` with the constraint above and without providing the `ic_generator` keyword argument.
Expected Behavior
Unclear; it would be nice to have some details about how to provide a sensible `ic_generator` object so I can use this feature!
System information
- `botorch` version 0.7.2
- `gpytorch` version 1.9.0
- `torch` version 1.12.0
- System info: macOS Ventura 13.0.1

Additional context
Thanks in advance!