pytorch / botorch

Bayesian optimization in PyTorch
https://botorch.org/
MIT License

[Bug] "RuntimeError: Trying to backward through the graph a second time" error in `fit_gpytorch_mll` #1386

Closed: sambitmishra98 closed this issue 1 year ago

sambitmishra98 commented 2 years ago

I am using a HeteroskedasticSingleTaskGP model and the qNEI acquisition function to optimise some parameters in a Navier-Stokes equations solver. For a given simulation case I can run the solver and obtain both the runtime and the variance in runtime.

The data looks like the following, with the last two columns giving the standard deviation and mean of the cost function for each set of 8 parameters.

cycle_1 | cycle_2 | cycle_3 | cycle_4 | pseudo-dt-fact_1 | pseudo-dt-fact_2 | pseudo-dt-fact_3 | pseudo-dt-fact_4 | cost-std | cost-avg
-- | -- | -- | -- | -- | -- | -- | -- | -- | --
1.0 | 2.0 | 2.0 | 2.0 | 1.7 | 1.7 | 1.7 | 1.7 | 0.26462015818723367 | 2.8067303993995667
10.0 | 8.78 | 9.945 | 5.88 | 2.5 | 2.5 | 2.5 | 2.5 | 1.3549891603999018 | 4.421466454399979
1.0 | 1.002 | 1.0 | 1.0 | 1.5 | 1.5 | 1.5 | 1.5 | 0.051330104700350467 | 3.054119216222211
0.0 | 0.4717 | 0.313 | 0.8555 | 1.868 | 1.732 | 1.599 | 1.63 | 0.15447312978140554 | 3.602360449999954
1.723 | 2.643 | 2.758 | 2.365 | 1.485 | 1.584 | 1.681 | 1.658 | 0.08864621313289052 | 2.64573476366685
1.022 | 3.791 | 2.932 | 2.729 | 1.676 | 1.527 | 1.664 | 1.572 | 0.1258471844610435 | 2.6081113264440825
0.699 | 3.441 | 2.357 | 2.33 | 1.57 | 1.487 | 1.858 | 1.641 | 0.06562802205222522 | 2.6927376279994255
0.664 | 3.682 | 2.477 | 3.076 | 1.585 | 1.533 | 1.609 | 1.772 | 0.058379624106582324 | 2.6839997258887
0.8994 | 3.89 | 2.393 | 3.105 | 1.54 | 1.692 | 1.716 | 1.562 | 0.11649495254246003 | 2.697327930110987
0.469 | 2.469 | 3.734 | 3.03 | 1.597 | 1.566 | 1.712 | 1.624 | 0.05940422227130007 | 2.666242495443738
1.121 | 3.93 | 3.615 | 1.771 | 1.603 | 1.611 | 1.695 | 1.681 | 0.1051393450731216 | 2.5451599623340573

🐛 Bug

The issue I am facing occurs at seemingly random times.

Traceback (most recent call last):
  File "/home/sambit98/pyfr-installation/pyfr-venv/bin/pyfr", line 33, in <module>
    sys.exit(load_entry_point('pyfr', 'console_scripts', 'pyfr')())
  File "/home/sambit98/pyfr-installation/PyFR/pyfr/__main__.py", line 118, in main
    args.process(args)
  File "/home/sambit98/pyfr-installation/PyFR/pyfr/__main__.py", line 270, in process_restart
    _process_common(args, mesh, soln, cfg)
  File "/home/sambit98/pyfr-installation/PyFR/pyfr/__main__.py", line 247, in _process_common
    solver.run()
  File "/home/sambit98/pyfr-installation/PyFR/pyfr/integrators/base.py", line 122, in run
    self.advance_to(t)
  File "/home/sambit98/pyfr-installation/PyFR/pyfr/integrators/dual/phys/controllers.py", line 56, in advance_to
    self._accept_step(self.pseudointegrator._idxcurr)
  File "/home/sambit98/pyfr-installation/PyFR/pyfr/integrators/dual/phys/controllers.py", line 35, in _accept_step
    csh(self)
  File "/home/sambit98/pyfr-installation/PyFR/pyfr/plugins/optimisation.py", line 99, in __call__
    self.cand_curr = self.cc.suggest_candidate(self.hc.opt_hist, self.oc.bounds, self.hc.gear)
  File "/home/sambit98/pyfr-installation/PyFR/pyfr/plugins/optimisation.py", line 431, in suggest_candidate
    case 'explore'  : cand_list = explorer(opt_hist, bounds).exploreqNEI()
  File "/home/sambit98/pyfr-installation/PyFR/optimiser/optimiser/bayesopt/exploration.py", line 20, in __init__
    super().__init__(dataset, bounds)
  File "/home/sambit98/pyfr-installation/PyFR/optimiser/optimiser/optimisables_handler.py", line 45, in __init__
    fit_model(mll)
  File "/home/sambit98/pyfr-installation/pyfr-venv/lib/python3.10/site-packages/botorch/fit.py", line 130, in fit_gpytorch_model
    mll, _ = optimizer(mll, track_iterations=False, **kwargs)
  File "/home/sambit98/pyfr-installation/pyfr-venv/lib/python3.10/site-packages/botorch/optim/fit.py", line 241, in fit_gpytorch_scipy
    res = minimize(
  File "/home/sambit98/pyfr-installation/pyfr-venv/lib/python3.10/site-packages/scipy/optimize/_minimize.py", line 699, in minimize
    res = _minimize_lbfgsb(fun, x0, args, jac, bounds,
  File "/home/sambit98/pyfr-installation/pyfr-venv/lib/python3.10/site-packages/scipy/optimize/_lbfgsb_py.py", line 362, in _minimize_lbfgsb
    f, g = func_and_grad(x)
  File "/home/sambit98/pyfr-installation/pyfr-venv/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py", line 285, in fun_and_grad
    self._update_fun()
  File "/home/sambit98/pyfr-installation/pyfr-venv/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py", line 251, in _update_fun
    self._update_fun_impl()
  File "/home/sambit98/pyfr-installation/pyfr-venv/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py", line 155, in update_fun
    self.f = fun_wrapped(self.x)
  File "/home/sambit98/pyfr-installation/pyfr-venv/lib/python3.10/site-packages/scipy/optimize/_differentiable_functions.py", line 137, in fun_wrapped
    fx = fun(np.copy(x), *args)
  File "/home/sambit98/pyfr-installation/pyfr-venv/lib/python3.10/site-packages/scipy/optimize/_optimize.py", line 76, in __call__
    self._compute_if_needed(x, *args)
  File "/home/sambit98/pyfr-installation/pyfr-venv/lib/python3.10/site-packages/scipy/optimize/_optimize.py", line 70, in _compute_if_needed
    fg = self.fun(x, *args)
  File "/home/sambit98/pyfr-installation/pyfr-venv/lib/python3.10/site-packages/botorch/optim/utils.py", line 218, in _scipy_objective_and_grad
    loss.backward()
  File "/home/sambit98/pyfr-installation/pyfr-venv/lib/python3.10/site-packages/torch/_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/sambit98/pyfr-installation/pyfr-venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
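
For reference, the error itself can be reproduced in isolation in plain PyTorch by calling backward twice on the same graph, which is presumably what ends up happening somewhere inside the fitting routine:

import torch

x = torch.ones(2, requires_grad=True)
loss = (2 * x).sum()
loss.backward()  # frees the graph's saved intermediate tensors
loss.backward()  # raises the same RuntimeError: the saved tensors are already freed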

System information

Probably adding the argument retain_graph=True will solve the issue. I want to add this at as high a level as possible; let me know where I could add that argument. Also, if it is necessary to replicate this issue, let me know where to add torch.random.manual_seed(0). I am not yet able to replicate the issue at will.
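
For illustration only, the change I have in mind would be something like the following one-liner in _scipy_objective_and_grad (botorch/optim/utils.py, per the traceback above); I have not tested whether this is the right place:

# hypothetical change in botorch/optim/utils.py, _scipy_objective_and_grad:
# keep the saved tensors so that a repeated backward pass does not fail
loss.backward(retain_graph=True)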

Balandat commented 2 years ago

Probably adding an argument retain_graph = True will solve the issue.

Hmm, I'm not sure we want to do that; it may just end up masking a bigger issue (or not work at all).

I'd like to first understand why this is happening in the current code (I don't think it should). Would you happen to have a reproducible example for this (it's fine if it only occasionally raises the error rather than deterministically)?

sambitmishra98 commented 2 years ago

The following is the optimisation part of my code. I have not been able to reproduce the same error since, but this uses the exact dataset that was in play when the error occurred.

from botorch.fit              import fit_gpytorch_model as fit_model
import numpy as np

from botorch.acquisition.objective import ScalarizedPosteriorTransform
from botorch.optim.optimize   import optimize_acqf

class optim_handler:
    """
        Handle for converging to a model based on given dataset and bounds. 

    """
    import torch

    from botorch.models           import SingleTaskGP, HeteroskedasticSingleTaskGP
    from botorch.models.transforms import Normalize, Standardize 
    from gpytorch.mlls            import ExactMarginalLogLikelihood

    torch.random.manual_seed(0)

    def __init__(self, dataset:np.ndarray, bounds:np.ndarray):

        torch_kwargs = {'dtype': self.torch.float32, 'device': self.torch.device("cpu")}

        # columns of `dataset`: 8 parameters, then cost average, then cost standard deviation
        self.X     = self.torch.tensor(dataset[:,  :-2]   , **torch_kwargs)
        self.Y     = self.torch.tensor(dataset[:,-2:-1]   , **torch_kwargs)
        self.Y_var = self.torch.tensor(dataset[:,-1:  ]**2, **torch_kwargs)
        self.bnds  = self.torch.tensor(bounds             , **torch_kwargs)

        model_kwargs = {'train_X': self.X, 
                        'train_Y': self.Y,
                        'outcome_transform': self.Standardize(m=1),
                        'input_transform': self.Normalize(d=self.X.shape[-1], bounds = self.bnds)
                        }

        #self.model = self.SingleTaskGP(**model_kwargs)
        self.model = self.HeteroskedasticSingleTaskGP(train_Yvar = self.Y_var, **model_kwargs)

        mll = self.ExactMarginalLogLikelihood(likelihood=self.model.likelihood,
                                              model=self.model)
        # set mll and all submodules to the specified dtype and device
        mll = mll.to(self.X)

        fit_model(mll)

    def show_opt_output(self) -> None:
        """ Print the end result of optimisation
        """
        print("----------------------------------------------")
        print(f"{len(self.X_cand)} candidate(s) generated.") 

        if len(self.X_cand) == 1 :
            print(f"Next suggested candidate is {self.X_cand.tolist()}")
            Y_cand = self.mean_from_model()
            print(f"Next simulation's cost may be around {Y_cand}.")

    def conf_from_model(self):

        posterior = self.model.posterior(self.X_cand)
        lower, upper = posterior.mvn.confidence_region()
        l = self.unstandardise_Y(lower)
        m = self.unstandardise_Y(self.model(self.X_cand)[0].mean)
        u = self.unstandardise_Y(upper)

        return l, m, u

    def mean_from_model(self):
        return self.unstandardise_Y(self.model(self.X_cand)[0].mean)

    def std_from_model(self):
        return self.unstandardise_Y(self.model(self.X_cand)[0].std)

    def calc_from_model(self, X_cand:list):

        X_cand_tensor = self.torch.tensor(X_cand)

        if X_cand_tensor.dim() == 1:
            X_cand_tensor = X_cand_tensor.reshape(1,-1)   

        posterior = self.model.posterior(X_cand_tensor)
        mn = self.unstandardise_Y(posterior.mean)
        print(f'Given: {X_cand_tensor}, posterior mean: {mn}')
        return mn

    def compare_wallt(self, opt_list:list) -> None:
        """Compare the actual simulation wall-time with prediction

        Args:
            wallt_a (list): Actual wall-time
        """

        print('_____________________________________________________________________________________')
        print(f"{'Candidate':<30}   {'Predicted':>11}   {'Actual':>11}   {'Error':>11}")

        err_cands  = [opt_val[:-1] for opt_val in opt_list]
        err_wtimes = [opt_val[-1]  for opt_val in opt_list]

        Err = []

        for i, _ in enumerate(opt_list) :
            X_cand = self.torch.tensor(err_cands[i:i+1])

            posterior = self.model.posterior(X_cand)
            #Y_cand_st = posterior.mean.tolist()[0][0]
            #Y_cand = self.unstandardise_Y(-Y_cand_st)

            Y_cand = self.unstandardise_Y(self.model(X_cand).mean)
            Err.append(abs((err_wtimes[i] - Y_cand)/err_wtimes[i]))
            print(*[f"{ii}" for ii in err_cands[i]], 
                  f"--> {Y_cand:>10.3f}  {err_wtimes[i]:>10.3f}  {Err[i]*100:>10.3f} %")
        print('_____________________________________________________________________________________')

        return Err

    def calc_wallt(self, opt_list:list) -> None:
        """Compare the actual simulation wall-time with prediction

        Args:
            wallt_a (list): Actual wall-time
        """

        print(f"{'Candidate':<30}{'Predicted':>11}{'Actual':>11}{'Error':>11}")

        err_cands  = [opt_val[:-1] for opt_val in opt_list]

        Err = []

        for i, _ in enumerate(opt_list) :
            X_cand = self.torch.tensor(err_cands[i:i+1])
            Y_cand = self.unstandardise_Y(self.model(X_cand).mean)
        return 

class explorer(optim_handler):
    """
        Explores the solution space with Bayesian Optimisation.
        This would be expensive, but worth it in the long run.
    """

    q = 1
    num_fantasies = 128
    num_restarts  = 20
    raw_samples   = 512

    def __init__(self, dataset:np.ndarray, bounds:np.ndarray) -> None:
        super().__init__(dataset, bounds)

    def __optimise_now(self):
        self.X_cand, _ = optimize_acqf(
            acq_function = self.__acquisition_function, 
            bounds       = self.bnds,
            q            = self.q,
            num_restarts = self.num_restarts,
            raw_samples  = self.raw_samples,
        )
        return self.X_cand.tolist()[0]

    def exploreqKG(self):
        from botorch.acquisition import qKnowledgeGradient as qKG
        self.__acquisition_function = qKG(self.model, 
                               num_fantasies = self.num_fantasies,
                               posterior_transform = ScalarizedPosteriorTransform(weights=self.torch.tensor([-1.0])),
                               )
        return self.__optimise_now()

    def exploreqNEI(self):
        from botorch.acquisition import qNoisyExpectedImprovement as qNEI
        #print("Exploration starts with ExpectedImprovement Bayesian Optimisation")    

        self.__acquisition_function = qNEI(self.model, 
                                self.X,
                                posterior_transform = ScalarizedPosteriorTransform(weights=self.torch.tensor([-1.0])),
                               )
        return self.__optimise_now()

if __name__ == "__main__":

    opt_hist = np.array([

[1.0            ,2.0        ,2.0        ,2.0        ,1.7         ,1.7       ,1.7        ,1.7       , 2.879927866298021   ,0.2809817522680068],
[8.54       ,7.734  ,7.688  ,8.98   ,2.5         ,2.367 , 2.5   , 2.5   , 14.40394745742742  ,14.145361129953962],
[1.568      ,1.002  ,1.0        ,1.0        ,1.673   ,1.528 , 1.511 , 1.516  ,3.1109395022229163     ,0.20819094898083226],
[0.3188     ,0.615  ,0.993  ,0.7627 ,1.55    ,1.555 , 1.638 , 1.796 , 3.574670920666702  ,0.06839043006985578],
[0.002207   ,1.738  ,0.654  ,0.3003 ,1.69    ,1.728 , 1.66  , 1.537 , 3.800978265444428  ,0.18756724016407256],
[0.673      ,1.827  ,2.111  ,2.404  ,1.523   ,1.604 , 1.586 , 1.547  ,2.9147677558899483     ,0.1267367419949604],
[0.4775     ,0.82   ,2.05   ,2.266  ,1.708   ,1.671 , 1.478 , 1.674 , 3.113726179888747  ,0.1322018802377944],
[0.54       ,0.8223 ,2.041  ,2.223  ,1.715   ,1.521 , 1.712 , 1.565  ,3.1511588787765907     ,0.08548508842886206],
[0.726      ,2.523  ,2.232  ,1.241  ,1.685   ,1.5205    , 1.547 , 1.669 , 2.743173125445335  ,0.10465200160196104],
[1.588      ,1.214  ,2.984  ,0.7705 ,1.611   ,1.627 , 1.604 , 1.614 , 2.817405173333908  ,0.09363403187314359],
[2.2            ,1.842  ,1.893  ,2.22   ,1.624   ,1.552 , 1.581 , 1.687 , 2.825294692666729  ,0.1080449707791949],
[1.743      ,2.553  ,2.78   ,2.047  ,1.719   ,1.635 , 1.555 , 1.504 , 2.689015759000338  ,0.09333393640729781],
[1.282      ,2.0        ,2.219  ,1.649  ,1.662   ,1.61  , 1.584 , 1.591  ,2.8422798810000436     ,0.12392669332300357],
[1.327      ,2.771  ,4.836  ,2.81   ,1.661   ,1.559 , 1.585 , 1.706 , 2.583281930112005  ,0.09948788147320681],
[1.863      ,1.97   ,4.64   ,1.8955 ,1.637   ,1.434 , 1.402 , 1.513 , 2.635251861109686  ,0.11428428045814855],
[2.545      ,2.244  ,4.85   ,1.054  ,1.865   ,1.493 , 1.522 , 1.694 , 3.004910332111548  ,0.1253705825953301],
[2.2            ,3.71   ,4.652  ,1.329  ,1.573   ,1.434 , 1.651 , 1.512 , 2.793829325777349  ,0.06383065736130515],
[2.066      ,3.92   ,4.58   ,1.331  ,1.563   ,1.629 , 1.383 , 1.659  ,2.7765917263330064     ,0.09696595158940072],
[0.293      ,3.205  ,5.65   ,1.374  ,1.67    ,1.58  , 1.509 , 1.513 , 2.712518582111645  ,0.12223720379684369],
[2.592      ,2.057  ,6.004  ,2.395  ,1.578   ,1.671 , 1.561 , 1.506  ,2.6304051203320946     ,0.09928534498998075],
[1.804      ,2.477  ,4.78   ,1.795  ,1.619   ,1.566 , 1.53  , 1.591  ,2.6605224885558223     ,0.0983155991975453],
[1.368      ,4.684  ,4.95   ,4.582  ,1.656   ,1.52  , 1.42  , 1.438  ,2.5523275455565857     ,0.12325918710545654],
[0.3103     ,5.72   ,3.27   ,3.543  ,1.783   ,1.434 , 1.415 , 1.673 , 2.627355733000311  ,0.10942659261654909],
[1.116      ,5.918  ,4.59   ,4.477  ,1.681   ,1.731 , 1.656 , 1.617  ,2.5916609974451097     ,0.08442744570907369],
[0.0            ,0.3818 ,9.266  ,10.0   ,1.486   ,2.291 , 1.377 , 1.41  , 2.774310617333362  ,0.12844921262269227],
[0.877      ,2.209  ,7.035  ,7.254  ,1.554   ,1.968 , 1.486 , 1.481 , 2.801150487999823  ,0.15607574338050392],
[1.684      ,0.0484 ,10.0   ,4.684  ,1.486   ,2.31  , 1.377 , 1.41   ,2.7725512354461697     ,0.1336467024188023],
[1.4795     ,0.0484 ,10.0   ,7.07   ,1.804   ,2.074 , 1.377 , 1.406,     2.61339530333377    ,0.08563167833530538],
[0.04694        ,3.87   ,10.0   ,6.367  ,1.755   ,2.32  , 1.377 , 1.406  ,2.6285033493343994     ,0.09090215650380991],
[2.92       ,2.652  ,8.305  ,7.906  ,1.703   ,2.389 , 1.344 , 1.382  ,2.6655712517769845     ,0.07955921073587495],
[1.462      ,1.3        ,10.164 ,7.383  ,1.693   ,2.371 , 1.684 , 1.363  ,2.6494278038868893     ,0.11048935607967407]        ,
    ])

    bounds = np.array([[ 1.,  1.,  1.,  1., 1.5, 1.5, 1.5, 1.5],
                       [10., 10., 10., 10., 2.5, 2.5, 2.5, 2.5]])

    cand_as_list = explorer(opt_hist, bounds).exploreqNEI()
    print(cand_as_list)

Balandat commented 2 years ago

I stress tested this (same versions, but on an M1 Mac) by running it 100 times and didn't get a single failure. So this is either a super rare event or doesn't happen on all platforms. If you stress test this, do you see any failures?
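
(Roughly, something like the following loop, with explorer, opt_hist, and bounds as defined in your script:)

# sketch of the stress test: repeatedly re-run the candidate generation
for i in range(100):
    cand = explorer(opt_hist, bounds).exploreqNEI()
    print(i, cand)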

sambitmishra98 commented 2 years ago

I just ran it 100 times as well and didn't get any problems either. Is there any other way to go about this? Is there anything I could add to the code to print diagnostics, so that if it does happen again I can figure out where the issue is?

Balandat commented 2 years ago

I don't know if that can easily be done preemptively. What you could try is to use the autograd anomaly detection mode:

from torch import autograd

with autograd.detect_anomaly():
    # your code goes here
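
In your script, wrapping the model-fitting call would probably be the most informative placement, e.g. (a sketch, using the names from your code):

from torch import autograd

with autograd.detect_anomaly():
    fit_model(mll)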

sambitmishra98 commented 2 years ago

I may have been facing the problem because I was using qNEI acquisition function candidates to initialise the model. When I switched to qKG, the issue did not come up anymore. KG is known to be more explorative than EI, so this should be good in the long run too. Since I am unable to isolate the issue properly and have found a reasonable way around it, I will close this issue. If it comes up again, I will try using autograd.detect_anomaly() as you mentioned and get to the bottom of it.

sambitmishra98 commented 2 years ago

Since this recurred a lot more after closing the issue, I tried to replicate it and was successful. The following gives the error I am talking about every time I run it.

from botorch.fit              import fit_gpytorch_model as fit_model
import numpy as np

from botorch.acquisition.objective import ScalarizedPosteriorTransform
from botorch.optim.optimize   import optimize_acqf

from torch import autograd

class optim_handler:
    """
        Handle for converging to a model based on given dataset and bounds. 

    """
    import torch

    from botorch.models           import SingleTaskGP, HeteroskedasticSingleTaskGP
    from botorch.models.transforms import Normalize, Standardize 
    from gpytorch.mlls            import ExactMarginalLogLikelihood

    torch.random.manual_seed(0)
    torch_kwargs = {'dtype': torch.float32, 'device': torch.device("cpu")}

    def __init__(self, dataset:np.ndarray, bounds:np.ndarray):

        self.X     = self.torch.tensor(dataset[:,  :-2]   , **self.torch_kwargs)
        self.Y     = self.torch.tensor(dataset[:,-2:-1]   , **self.torch_kwargs)
        self.Y_std = self.torch.tensor(dataset[:,-1:  ]   , **self.torch_kwargs)
        self.Y_var = self.Y_std**2
        self.Y_err = 1.96 * self.Y_std/self.Y
        self.bnds  = self.torch.tensor(bounds             , **self.torch_kwargs)

        self.model_kwargs ={'train_X': self.X, 
                            'train_Y': self.Y,
                            'outcome_transform': self.Standardize(m=1),
                            'input_transform': self.Normalize(d=self.X.shape[-1], bounds = self.bnds)
                            }

    def show_opt_output(self) -> None:
        """ Print the end result of optimisation
        """
        print("----------------------------------------------")
        print(f"{len(self.X_cand)} candidate(s) generated.") 

        if len(self.X_cand) == 1 :
            print(f"Next suggested candidate is {self.X_cand.tolist()}")
            Y_cand = self.mean_from_model()
            print(f"Next simulation's cost may be around {Y_cand}.")

    def conf_from_model(self):

        posterior = self.model.posterior(self.X_cand)
        lower, upper = posterior.mvn.confidence_region()
        l = self.unstandardise_Y(lower)
        m = self.unstandardise_Y(self.model(self.X_cand)[0].mean)
        u = self.unstandardise_Y(upper)

        return l, m, u

    def mean_from_model(self):
        return self.unstandardise_Y(self.model(self.X_cand)[0].mean)

    def std_from_model(self):
        return self.unstandardise_Y(self.model(self.X_cand)[0].std)

    def calc_from_model(self, X_cand:list):

        X_cand_tensor = self.torch.tensor(X_cand)

        if X_cand_tensor.dim() == 1:
            X_cand_tensor = X_cand_tensor.reshape(1,-1)   

        posterior = self.model.posterior(X_cand_tensor)
        mn = self.unstandardise_Y(posterior.mean)
        print(f'Given: {X_cand_tensor}, posterior mean: {mn}')
        return mn

    def compare_wallt(self, opt_list:list) -> None:
        """Compare the actual simulation wall-time with prediction

        Args:
            wallt_a (list): Actual wall-time
        """

        print('_____________________________________________________________________________________')
        print(f"{'Candidate':<30}   {'Predicted':>11}   {'Actual':>11}   {'Error':>11}")

        err_cands  = [opt_val[:-1] for opt_val in opt_list]
        err_wtimes = [opt_val[-1]  for opt_val in opt_list]

        Err = []

        for i, _ in enumerate(opt_list) :
            X_cand = self.torch.tensor(err_cands[i:i+1])

            posterior = self.model.posterior(X_cand)
            #Y_cand_st = posterior.mean.tolist()[0][0]
            #Y_cand = self.unstandardise_Y(-Y_cand_st)

            Y_cand = self.unstandardise_Y(self.model(X_cand).mean)
            Err.append(abs((err_wtimes[i] - Y_cand)/err_wtimes[i]))
            print(*[f"{ii}" for ii in err_cands[i]], 
                  f"--> {Y_cand:>10.3f}  {err_wtimes[i]:>10.3f}  {Err[i]*100:>10.3f} %")
        print('_____________________________________________________________________________________')

        return Err

    def calc_wallt(self, opt_list:list) -> None:
        """Compare the actual simulation wall-time with prediction

        Args:
            wallt_a (list): Actual wall-time
        """

        print(f"{'Candidate':<30}{'Predicted':>11}{'Actual':>11}{'Error':>11}")

        err_cands  = [opt_val[:-1] for opt_val in opt_list]

        Err = []

        for i, _ in enumerate(opt_list) :
            X_cand = self.torch.tensor(err_cands[i:i+1])
            Y_cand = self.unstandardise_Y(self.model(X_cand).mean)
        return 

class optimiser(optim_handler):
    """
        Explores the solution space with Bayesian Optimisation.
        This would be expensive, but worth it in the long run.
    """

    def __init__(self, dataset:np.ndarray, bounds:np.ndarray) -> None:
        super().__init__(dataset, bounds)

    def __optimise_now(self, raw_samples = 512, num_restarts = 10):
        self.X_cand, _ = optimize_acqf(
            acq_function = self.__acquisition_function, 
            bounds       = self.bnds,
            q            = 1,
            num_restarts = num_restarts,
            raw_samples  = raw_samples,
        )
        return self.X_cand.tolist()[0]

    def exploreqKG(self):
        from botorch.acquisition import qKnowledgeGradient as qKG

        model1 = self.SingleTaskGP(**self.model_kwargs)
        mll1 = self.ExactMarginalLogLikelihood(likelihood=model1.likelihood, 
                                              model=model1, 
                                              )
        mll1 = mll1.to(self.X)
        fit_model(mll1)

        self.__acquisition_function = qKG(model1, num_fantasies = 64,
                               posterior_transform = ScalarizedPosteriorTransform(weights=self.torch.tensor([-1.0])),)
        return self.__optimise_now(raw_samples = 512, num_restarts = 10)

    def exploreqNEI(self):
        from botorch.acquisition import qNoisyExpectedImprovement as qNEI
        #print("Exploration starts with ExpectedImprovement Bayesian Optimisation")    
        model2 = self.HeteroskedasticSingleTaskGP(train_Yvar = self.Y_var, **self.model_kwargs)
        mll2 = self.ExactMarginalLogLikelihood(likelihood=model2.likelihood, model=model2, )
        mll2 = mll2.to(self.X)
        fit_model(mll2)

        self.__acquisition_function = qNEI(model2, self.X,
            posterior_transform = ScalarizedPosteriorTransform(weights=self.torch.tensor([-1.0])),)
        return self.__optimise_now(raw_samples = 1024, num_restarts = 100)

    def evaluatePM(self):
        """Get the minimum from posterior mean."""

        from botorch.acquisition import PosteriorMean as PM

        model3 = self.HeteroskedasticSingleTaskGP(train_Yvar = self.Y_var, **self.model_kwargs)
        mll3 = self.ExactMarginalLogLikelihood(likelihood=model3.likelihood, model=model3, )
        mll3 = mll3.to(self.X)
        fit_model(mll3)

        self.__acquisition_function = PM(model3, maximize = False)

        return self.__optimise_now(raw_samples = 1024, num_restarts = 100)

    def evaluateLOOCV(self):
        """Get quality of model"""

        from botorch.cross_validation import gen_loo_cv_folds, batch_cross_validation

        print(f"{np.shape(self.X)     = }")
        print(f"{np.shape(self.Y)     = }")
        print(f"{np.shape(self.Y_var) = }")

        cv_folds = gen_loo_cv_folds(train_X    = self.X, 
                                    train_Y    = self.Y, 
                                    train_Yvar = self.Y_var)

        cv_results = batch_cross_validation( 
                    model_cls = self.HeteroskedasticSingleTaskGP, 
                    mll_cls   = self.ExactMarginalLogLikelihood, 
                    cv_folds  = cv_folds,)

        posterior = cv_results.posterior
        mean = posterior.mean
        cv_error = ((cv_folds.test_Y.squeeze() - mean.squeeze()) ** 2).mean()
        print(f"Cross-validation error: {cv_error : 4.2}")

        lower, upper = posterior.mvn.confidence_region()

        from matplotlib import pyplot as plt

        _, axes = plt.subplots(1, 1, figsize=(6, 4))
        plt.plot([0, 4], [0, max(self.Y)], 'k', label="true objective", linewidth=2)

        axes.set_xlabel("Actual")
        axes.set_ylabel("Predicted")

        axes.errorbar(
            x=cv_folds.test_Y.numpy().flatten(), 
            y=mean.numpy().flatten(), 
            xerr=self.Y_err.numpy().flatten(),
            yerr=((upper-lower)/2).numpy().flatten(),
            fmt='*'
        )

        plt.savefig(f"loocv_plots/LOOCV-{np.shape(self.X)[0]}.png")

        try:
            return cv_error[0]
        except:
            return 0

if __name__ == "__main__":

    opt_hist = np.array([       
[1.0         ,5.32   ,8.62   ,1.0    ,3.469931521995022  ,0.15084671996718488],
[1.01    ,3.508  ,1.056  ,4.44   ,3.0313060119951842     ,0.142843038037407],
[1.056   ,1.375  ,1.211  ,5.88   ,3.284004883604939  ,0.08974095246203823],
[1.03    ,5.113  ,2.611  ,4.758  ,2.7276665300043534     ,0.06864569117860327],
[1.14    ,6.883  ,1.689  ,4.164  ,2.653645221333136  ,0.08835833306039219],
[0.9976  ,7.305  ,1.924  ,6.176  ,2.835054016664799  ,0.12341695957709688],
[2.512   ,6.984  ,2.723  ,3.625  ,2.401715438000489  ,0.0501839288853536],
[3.645   ,6.582  ,2.12   ,3.916  ,2.4657301871380435     ,0.13502478920112984],
[3.227   ,7.69   ,2.416  ,1.956  ,2.6161442657238303     ,0.134975290605758],
[3.455   ,8.05   ,3.248  ,4.383  ,2.4362339208601043     ,0.0763049692881096],
[3.627   ,6.414  ,3.662  ,3.818  ,2.3531628668632556     ,0.08959766019903548],
[3.309   ,6.297  ,3.799  ,4.758  ,2.4697823554285736     ,0.09684451240582466],
[5.383   ,6.723  ,3.621  ,3.646  ,2.4180864102454507     ,0.0949956524836544],
[3.098   ,7.285  ,4.28   ,3.402  ,2.417899090261926  ,0.07452566623072517],
[3.645   ,5.742  ,3.56   ,3.17   ,2.3491518319982183     ,0.06494949029839713],
[4.41    ,5.254  ,4.117  ,3.371  ,2.479834037103703  ,0.059859927823423295],
[2.334   ,6.05   ,3.65   ,3.01   ,2.4447961137807903     ,0.0952101333908594],
[4.15    ,6.938  ,3.514  ,2.81   ,2.425653966004029  ,0.07915923263163083],
[4.254   ,5.53   ,2.812  ,3.389  ,2.4762093682031265     ,0.1247154004151244],
[10.0    ,10.0   ,5.09   ,6.2    ,26.550974283993128     ,4.411758691127266],
[0.787   ,10.0   ,10.0   ,8.08   ,2.957984932732705  ,0.43974607928385256],
[0.787   ,1.191  ,10.0   ,9.79   ,2.625034426000308  ,0.07595628151917394],
[2.25    ,1.0    ,10.0   ,5.99   ,2.5538451341814254     ,0.08654861916579396],
[2.203   ,5.703  ,10.0   ,10.0   ,2.5229998036678203     ,0.08632527748001667],
[1.374   ,4.703  ,10.03  ,7.11   ,2.5520782235010606     ,0.08713511252325115],
[4.35    ,10.0   ,10.45  ,6.938  ,2.391843813383065  ,0.07045647178219852],
[4.867   ,10.0   ,4.766  ,7.332  ,11.966253685816561     ,10.439736072377116],
[3.234   ,9.17   ,10.484     ,4.555  ,3.464793102767614  ,0.5222274546485405],
[3.797   ,1.0    ,10.59  ,8.49   ,2.4228520332264734     ,0.05644599637011874],
[1.983   ,1.0    ,6.508  ,8.58   ,2.721301035077956  ,0.09255794329238595],
[3.29    ,7.184  ,11.69  ,7.83   ,2.4954472866207094     ,0.14119999778021255],
[0.2477  ,7.355  ,5.285  ,10.0   ,2.7430776664301186     ,0.07533276864832991],
[3.525   ,5.523  ,8.7    ,6.88   ,3.165797036000835  ,0.3286117704763439],
[0.2477  ,1.0    ,11.69  ,4.13   ,2.5718734331406137     ,0.11303499993632911],
[0.2404  ,0.76   ,7.72   ,5.926  ,2.7799387269988074     ,0.12337728364566593],
[4.25    ,0.76   ,11.414     ,1.0    ,4.863054728004499  ,0.21629626567292387],
[1.0        ,2.0         ,2.0    ,2.0    ,3.1270113454957027     ,0.5413061233313805],
[ 1.        ,  2.        ,  2.        ,  2.        ,  3.080895  ,        0.48222797]])

    bounds = np.array([[ 0.9326,  1.,  1.,  1.],
                       [10., 10., 10.195, 10.]])

    cand_as_list = optimiser(opt_hist, bounds).exploreqNEI()
    print(cand_as_list)

esantorella commented 1 year ago

Thanks for the reproducible example! This generates the error on my machine as well. Here is a more minimal example. This is an odd one, since I can only reproduce the issue using the full opt_hist, with 38 elements; adding or subtracting data makes the issue go away. It persists if we change fit_gpytorch_model, which is deprecated, to fit_gpytorch_mll.
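
(For reference, that swap amounts to replacing the deprecated entry point in the script above, with mll2 as defined there:)

from botorch.fit import fit_gpytorch_mll

fit_gpytorch_mll(mll2)  # rather than the deprecated fit_gpytorch_model(mll2)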

esantorella commented 1 year ago

This is not an ideal solution, but might switching to float64 fix the issue? I don't get this error when I change to double precision, and we recommend using double precision in BoTorch anyway to avoid numerical issues.
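
Concretely, that would be a one-line change to the torch_kwargs in your script (sketch):

# in optim_handler: use double precision throughout
torch_kwargs = {'dtype': torch.float64, 'device': torch.device("cpu")}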

Balandat commented 1 year ago

+1 to using double rather than single precision; this will make a huge difference, in particular for the heteroskedastic GP, which can be quite prone to numerical difficulties.

That said, it would be good to understand why we end up with the retain_graph=True error. Could this have to do with some of the retries for fitting/optimizing that we're doing under the hood, where we (or GPyTorch) may not be handling things in a completely clean fashion? (This is the kind of thing that could potentially be related to some of the memory leaks we've been seeing with JIT compilation in Pyro.)

esantorella commented 1 year ago

This does indeed happen upon retrying optimization, in _fit_fallback from botorch/fit.py. I'm looking into it right now. That path involves some checkpoints and multiple context managers, so I would not be surprised if there is something we aren't cleaning up perfectly.

esantorella commented 1 year ago

I have not managed to fully figure out what is going on here, but here's a (relatively) small repro that attempts to isolate the error by stripping fit_gpytorch_mll down to its most relevant lines.

What happens is:

1. We set up an mll and call fit_gpytorch_mll.
2. fit_gpytorch_mll calls _fit_fallback, which tries scipy_minimize, potentially re-trying if it fails.
3. On the first try, calling the closure fails with a numerical error. This means that we start calling the mll, but the call doesn't succeed. I think the fact that the call is not completed has to do with the issue here.
4. A context manager, module_rollback_ctx, rolls back the state of the mll.
5. scipy_minimize is tried again. It calls the closure, which succeeds the first time and fails the second.

import numpy as np
import torch

from botorch.models import HeteroskedasticSingleTaskGP
from botorch.models.transforms import Normalize, Standardize
from botorch.optim.closures.core import (
    ForwardBackwardClosure,
    NdarrayOptimizationClosure,
)
from botorch.optim.closures.model_closures import get_loss_closure
from botorch.optim.utils import get_parameters
from botorch.utils.context_managers import module_rollback_ctx
from gpytorch.mlls import ExactMarginalLogLikelihood

def get_mll():

    dataset = np.array(
        [
            [1.0, 5.32, 8.62, 1.0, 3.469931521995022, 0.15084671996718488],
            [1.01, 3.508, 1.056, 4.44, 3.0313060119951842, 0.142843038037407],
        ]
    )

    torch_kwargs = {"dtype": torch.float64, "device": torch.device("cpu")}

    X = torch.tensor(dataset[:, :-2], **torch_kwargs)
    Y = torch.tensor(dataset[:, -2:-1], **torch_kwargs)
    Y_std = torch.tensor(dataset[:, -1:], **torch_kwargs)
    Y_var = Y_std**2

    model_kwargs = {
        "train_X": X,
        "train_Y": Y,
        "outcome_transform": Standardize(m=1),
        "input_transform": Normalize(d=X.shape[-1])
    }

    model2 = HeteroskedasticSingleTaskGP(train_Yvar=Y_var, **model_kwargs)
    mll2 = ExactMarginalLogLikelihood(
        likelihood=model2.likelihood,
        model=model2,
    )
    mll2 = mll2.to(X)
    return mll2

def main() -> None:
    """
    Demonstrates how `fit_gpytorch_mll` can raise a RuntimeError by
    stripping it down to the relevant lines.
    """

    mll = get_mll()

    # to repro the error in fewer lines, just call `fit_gpytorch_mll` here
    # the rest of this function pulls out the relevant lines of `fit_gpytorch_mll`

    mll.train()

    closure = ForwardBackwardClosure(
        forward=get_loss_closure(mll),
        parameters=get_parameters(mll, requires_grad=True),
    )

    # context manager rolls back state of `mll` at the end
    with module_rollback_ctx(mll, device=torch.device("cpu")):
        wrapped_closure = NdarrayOptimizationClosure(closure, closure.parameters)
        # mimic the numerical error
        wrapped_closure(np.full_like(wrapped_closure.state, float("nan")))

    # Second pass through the loop. `minimize` calls the closure successfully
    # once, and then it errors the second time.
    closure()
    closure()

if __name__ == "__main__":
    main()
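
When this behaves as described above, the first closure() call after the rollback succeeds and the second raises the RuntimeError from the original traceback.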

esantorella commented 1 year ago

The code I posted above no longer generates the error, but the original example still does, as does this minimal repro.

I was hoping that the fix in #1635 would fix this as well, but this appears to be a separate issue.