pytorch / botorch

Bayesian optimization in PyTorch
https://botorch.org/
MIT License

[Bug] An exception should be raised when the training data has requires_grad=True #2253

Open rexxy-sasori opened 3 months ago

rexxy-sasori commented 3 months ago

šŸ› Bug

The train_X and train_Y passed to SingleTaskGP cause fit_gpytorch_mll to fail if they have requires_grad=True, i.e. their grad_fn is not None. The error goes away when requires_grad=False.

To reproduce

Code snippet to reproduce

import botorch, gpytorch, torch
from botorch.models import FixedNoiseGP, ModelListGP, SingleTaskGP
from gpytorch.mlls.sum_marginal_log_likelihood import SumMarginalLogLikelihood
from botorch import fit_gpytorch_mll

print(botorch.__version__)
print(gpytorch.__version__)
print(torch.__version__)

train_x = torch.randn(30, 6, requires_grad=True)  # leaf tensor with requires_grad=True

train_obj = 3 * train_x + train_x**2  # derived tensors carry a grad_fn
train_con = 4 * train_x - train_x**3

print(train_obj.grad_fn)  # not None
print(train_con.grad_fn)  # not None

model_obj = SingleTaskGP(train_x, train_obj).to(train_x)
model_con = SingleTaskGP(train_x, train_con).to(train_x)

model = ModelListGP(model_obj, model_con)
mll = SumMarginalLogLikelihood(model.likelihood, model)

fit_gpytorch_mll(mll)

Stack trace/error message

RuntimeError                              Traceback (most recent call last)
Cell In[55], line 25
     22 model = ModelListGP(model_obj, model_con)
     23 mll = SumMarginalLogLikelihood(model.likelihood, model)
---> 25 fit_gpytorch_mll(mll)

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/fit.py:105, in fit_gpytorch_mll(mll, closure, optimizer, closure_kwargs, optimizer_kwargs, **kwargs)
    102 if optimizer is not None:  # defer to per-method defaults
    103     kwargs["optimizer"] = optimizer
--> 105 return FitGPyTorchMLL(
    106     mll,
    107     type(mll.likelihood),
    108     type(mll.model),
    109     closure=closure,
    110     closure_kwargs=closure_kwargs,
    111     optimizer_kwargs=optimizer_kwargs,
    112     **kwargs,
    113 )

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/utils/dispatcher.py:93, in Dispatcher.__call__(self, *args, **kwargs)
     91 func = self.__getitem__(types=types)
     92 try:
---> 93     return func(*args, **kwargs)
     94 except MDNotImplementedError:
     95     # Traverses registered methods in order, yields whenever a match is found
     96     funcs = self.dispatch_iter(*types)

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/fit.py:305, in _fit_list(mll, _, __, **kwargs)
    303 mll.train()
    304 for sub_mll in mll.mlls:
--> 305     fit_gpytorch_mll(sub_mll, **kwargs)
    307 return mll.eval() if not any(sub_mll.training for sub_mll in mll.mlls) else mll

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/fit.py:105, in fit_gpytorch_mll(mll, closure, optimizer, closure_kwargs, optimizer_kwargs, **kwargs)
    102 if optimizer is not None:  # defer to per-method defaults
    103     kwargs["optimizer"] = optimizer
--> 105 return FitGPyTorchMLL(
    106     mll,
    107     type(mll.likelihood),
    108     type(mll.model),
    109     closure=closure,
    110     closure_kwargs=closure_kwargs,
    111     optimizer_kwargs=optimizer_kwargs,
    112     **kwargs,
    113 )

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/utils/dispatcher.py:93, in Dispatcher.__call__(self, *args, **kwargs)
     91 func = self.__getitem__(types=types)
     92 try:
---> 93     return func(*args, **kwargs)
     94 except MDNotImplementedError:
     95     # Traverses registered methods in order, yields whenever a match is found
     96     funcs = self.dispatch_iter(*types)

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/fit.py:252, in _fit_fallback(mll, _, __, closure, optimizer, closure_kwargs, optimizer_kwargs, max_attempts, warning_handler, caught_exception_types, **ignore)
    250 with catch_warnings(record=True) as warning_list, debug(True):
    251     simplefilter("always", category=OptimizationWarning)
--> 252     optimizer(mll, closure=closure, **optimizer_kwargs)
    254 # Resolved warnings and determine whether or not to retry
    255 done = True

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/fit.py:92, in fit_gpytorch_mll_scipy(mll, parameters, bounds, closure, closure_kwargs, method, options, callback, timeout_sec)
     89 if closure_kwargs is not None:
     90     closure = partial(closure, **closure_kwargs)
---> 92 result = scipy_minimize(
     93     closure=closure,
     94     parameters=parameters,
     95     bounds=bounds,
     96     method=method,
     97     options=options,
     98     callback=callback,
     99     timeout_sec=timeout_sec,
    100 )
    101 if result.status != OptimizationStatus.SUCCESS:
    102     warn(
    103         f"`scipy_minimize` terminated with status {result.status}, displaying"
    104         f" original message from `scipy.optimize.minimize`: {result.message}",
    105         OptimizationWarning,
    106     )

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/core.py:109, in scipy_minimize(closure, parameters, bounds, callback, x0, method, options, timeout_sec)
    101         result = OptimizationResult(
    102             step=next(call_counter),
    103             fval=float(wrapped_closure(x)[0]),
    104             status=OptimizationStatus.RUNNING,
    105             runtime=monotonic() - start_time,
    106         )
    107         return callback(parameters, result)  # pyre-ignore [29]
--> 109 raw = minimize_with_timeout(
    110     wrapped_closure,
    111     wrapped_closure.state if x0 is None else x0.astype(np_float64, copy=False),
    112     jac=True,
    113     bounds=bounds_np,
    114     method=method,
    115     options=options,
    116     callback=wrapped_callback,
    117     timeout_sec=timeout_sec,
    118 )
    120 # Post-processing and outcome handling
    121 wrapped_closure.state = asarray(raw.x)  # set parameter state to optimal values

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/utils/timeout.py:80, in minimize_with_timeout(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options, timeout_sec)
     77     wrapped_callback = callback
     79 try:
---> 80     return optimize.minimize(
     81         fun=fun,
     82         x0=x0,
     83         args=args,
     84         method=method,
     85         jac=jac,
     86         hess=hess,
     87         hessp=hessp,
     88         bounds=bounds,
     89         constraints=constraints,
     90         tol=tol,
     91         callback=wrapped_callback,
     92         options=options,
     93     )
     94 except OptimizationTimeoutError as e:
     95     msg = f"Optimization timed out after {e.runtime} seconds."

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_minimize.py:710, in minimize(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options)
    707     res = _minimize_newtoncg(fun, x0, args, jac, hess, hessp, callback,
    708                              **options)
    709 elif meth == 'l-bfgs-b':
--> 710     res = _minimize_lbfgsb(fun, x0, args, jac, bounds,
    711                            callback=callback, **options)
    712 elif meth == 'tnc':
    713     res = _minimize_tnc(fun, x0, args, jac, bounds, callback=callback,
    714                         **options)

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_lbfgsb_py.py:365, in _minimize_lbfgsb(fun, x0, args, jac, bounds, disp, maxcor, ftol, gtol, eps, maxfun, maxiter, iprint, callback, maxls, finite_diff_rel_step, **unknown_options)
    359 task_str = task.tobytes()
    360 if task_str.startswith(b'FG'):
    361     # The minimization routine wants f and g at the current x.
    362     # Note that interruptions due to maxfun are postponed
    363     # until the completion of the current minimization iteration.
    364     # Overwrite f and g:
--> 365     f, g = func_and_grad(x)
    366 elif task_str.startswith(b'NEW_X'):
    367     # new iteration
    368     n_iterations += 1

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py:285, in ScalarFunction.fun_and_grad(self, x)
    283 if not np.array_equal(x, self.x):
    284     self._update_x_impl(x)
--> 285 self._update_fun()
    286 self._update_grad()
    287 return self.f, self.g

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py:251, in ScalarFunction._update_fun(self)
    249 def _update_fun(self):
    250     if not self.f_updated:
--> 251         self._update_fun_impl()
    252         self.f_updated = True

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py:155, in ScalarFunction.__init__.<locals>.update_fun()
    154 def update_fun():
--> 155     self.f = fun_wrapped(self.x)

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py:137, in ScalarFunction.__init__.<locals>.fun_wrapped(x)
    133 self.nfev += 1
    134 # Send a copy because the user may overwrite it.
    135 # Overwriting results in undefined behaviour because
    136 # fun(self.x) will change self.x, with the two no longer linked.
--> 137 fx = fun(np.copy(x), *args)
    138 # Make sure the function returns a true scalar
    139 if not np.isscalar(fx):

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_optimize.py:77, in MemoizeJac.__call__(self, x, *args)
     75 def __call__(self, x, *args):
     76     """ returns the function value """
---> 77     self._compute_if_needed(x, *args)
     78     return self._value

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/scipy/optimize/_optimize.py:71, in MemoizeJac._compute_if_needed(self, x, *args)
     69 if not np.all(x == self.x) or self._value is None or self.jac is None:
     70     self.x = np.asarray(x).copy()
---> 71     fg = self.fun(x, *args)
     72     self.jac = fg[1]
     73     self._value = fg[0]

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/closures/core.py:160, in NdarrayOptimizationClosure.__call__(self, state, **kwargs)
    158         index += size
    159 except RuntimeError as e:
--> 160     value, grads = _handle_numerical_errors(e, x=self.state, dtype=np_float64)
    162 return value, grads

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/utils/common.py:52, in _handle_numerical_errors(error, x, dtype)
     50     _dtype = x.dtype if dtype is None else dtype
     51     return np.full((), "nan", dtype=_dtype), np.full_like(x, "nan", dtype=_dtype)
---> 52 raise error

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/closures/core.py:150, in NdarrayOptimizationClosure.__call__(self, state, **kwargs)
    147     self.state = state
    149 try:
--> 150     value_tensor, grad_tensors = self.closure(**kwargs)
    151     value = self.as_array(value_tensor)
    152     grads = self._get_gradient_ndarray(fill_value=self.fill_value)

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/botorch/optim/closures/core.py:66, in ForwardBackwardClosure.__call__(self, **kwargs)
     64 values = self.forward(**kwargs)
     65 value = values if self.reducer is None else self.reducer(values)
---> 66 self.backward(value)
     68 grads = tuple(param.grad for param in self.parameters.values())
     69 if self.callback:

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/torch/_tensor.py:522, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    512 if has_torch_function_unary(self):
    513     return handle_torch_function(
    514         Tensor.backward,
    515         (self,),
   (...)
    520         inputs=inputs,
    521     )
--> 522 torch.autograd.backward(
    523     self, gradient, retain_graph, create_graph, inputs=inputs
    524 )

File /opt/anaconda3/envs/idc/lib/python3.10/site-packages/torch/autograd/__init__.py:266, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    261     retain_graph = create_graph
    263 # The reason we repeat the same comment below is that
    264 # some Python versions print out the first line of a multi-line function
    265 # calls in the traceback and some print out the last line
--> 266 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    267     tensors,
    268     grad_tensors_,
    269     retain_graph,
    270     create_graph,
    271     inputs,
    272     allow_unreachable=True,
    273     accumulate_grad=True,
    274 )

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Expected Behavior

Instead of an informative exception, the following error message is thrown:

"Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward"


Balandat commented 3 months ago

Is there a particular use case you have for training data with requires_grad=True? Or is this an issue that you ran into without deliberately setting this? In general it's not clear what it would mean to fit the model if the training data itself were a parameter (you'd get a perfect fit all the time...). One could compute the gradient of the model parameters w.r.t. the training data at the optimum (MAP maximizer), but that's a different thing.

esantorella commented 3 months ago

Thanks for reporting this issue. I'm not surprised this fails, because BoTorch figures out which tensors are parameters that need to be optimized by looking at which have requires_grad=True. I second Max's question about the use case, since I'm not sure whether it would make sense to support this. Would it work to (perhaps temporarily) detach the input data?
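
For reference, a minimal sketch of the detach workaround suggested here (illustrative only, not code from the thread; it mirrors the names in the reproduction snippet above):

import torch
from botorch import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

train_x = torch.randn(30, 6, requires_grad=True)  # data that carries autograd history
train_obj = 3 * train_x + train_x**2              # grad_fn is not None here

# Detach before handing the data to the GP so that fitting does not try to
# backpropagate through the data-generating computation a second time.
model = SingleTaskGP(train_x.detach(), train_obj.detach())
mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_mll(mll)  # no longer raises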

rexxy-sasori commented 3 months ago

Sorry, perhaps I should give a little bit of context. This is actually an issue I ran into without deliberately setting requires_grad; the example here is just a demonstration so anyone can reproduce the bug. I want to use risk-averse BO for model predictive control, where I first built a PyTorch model that maps my control variable to my objective. However, the fit_gpytorch_mll call always failed until I figured out that the issue went away when I used torch.no_grad().

Balandat commented 3 months ago

However, the fit_gpytorch_mll call always failed until I figured out that the issue went away when I used with torch.no_grad()

Just to make sure there is no confusion here, you are not putting the fit_gpytorch_mll() call into a no_grad() context, right? Just making sure the inputs to the GP model don't require gradients, i.e. do whatever prediction you do on your model in a no_grad() context.
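
To make the distinction concrete, a hedged sketch (the controller model below is a stand-in, not something from the thread): the no_grad() context belongs around the upstream predictions that produce the training data, not around the fitting call.

import torch
from botorch import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

controller = torch.nn.Linear(6, 1)   # placeholder for the upstream PyTorch model
candidates = torch.randn(30, 6)

# Produce the GP training targets inside no_grad() so they carry no grad_fn ...
with torch.no_grad():
    train_y = controller(candidates)

# ... and fit the GP normally, outside any no_grad() context, so the MLL
# optimization can still differentiate w.r.t. the GP hyperparameters.
gp = SingleTaskGP(candidates, train_y)
mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
fit_gpytorch_mll(mll)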

It may make sense on our end to explicitly check whether the training data requires grad when calling fit_gpytorch_mll to emit a more informative error message.
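
Roughly the kind of check being proposed, sketched as a hypothetical helper (attribute names follow GPyTorch's ExactGP/ModelListGP conventions; this is not actual BoTorch code):

def _validate_training_data_requires_grad(mll):
    # Hypothetical helper: fail fast with an informative message instead of
    # the opaque "backward through the graph a second time" autograd error.
    models = getattr(mll.model, "models", [mll.model])
    for model in models:
        tensors = list(model.train_inputs) + [model.train_targets]
        if any(t.requires_grad for t in tensors):
            raise RuntimeError(
                "Training data passed to the model has requires_grad=True. "
                "Detach it (e.g. train_X.detach()) or construct it inside a "
                "torch.no_grad() context before calling fit_gpytorch_mll."
            )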

rexxy-sasori commented 3 months ago

No, I am not putting the fit_gpytorch_mll() call into a no_grad() context.

Yes, I agree with your suggestion to explicitly check whether the training data requires grad when calling fit_gpytorch_mll to emit a more informative error message.