pytorch / botorch

Bayesian optimization in PyTorch
https://botorch.org/
MIT License

[Bug] NotPSDError at fit_gpytorch_model step or during optimize_acqf_mixed #1529

Closed. npielawski closed this issue 1 year ago.

npielawski commented 1 year ago

🐛 Bug

I am running a multi-fidelity Bayesian optimization to find a specific quaternion that maximizes a score function using qMultiFidelityLowerBoundMaxValueEntropy. In order to do that I had to warp a Matérn kernel to account for the space of quaternions. I double-checked the equations and even tried the regular Matérn kernel, but I always get the NotPSDError exception. The exception can occur during the fitting of the hyperparameters or during the acquisition function search. I know that my score is deterministic, so I used a FixedNoiseMultiFidelityGP with a noise of 1e-6 (maybe that's too small? I tried 1e-4 and 1e-5 too).

I noticed that sometimes the learned hyperparameters become very small (e.g. the outputscale of the ScaleKernel or the lengthscale of the MaternKernel ends up around 1e-10), and I believe that can be a source of NotPSDError, so I constrained those parameters to be >1e-2/1e-3.
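For reference, the constraints are attached like this (sketch; the kernel below is the stock Matérn rather than my warped one):

from gpytorch.constraints import GreaterThan
from gpytorch.kernels import MaternKernel, ScaleKernel

# Bound the hyperparameters away from zero so the covariance cannot
# collapse toward a singular matrix.
base_kernel = MaternKernel(nu=2.5, lengthscale_constraint=GreaterThan(1e-2))
covar_module = ScaleKernel(base_kernel, outputscale_constraint=GreaterThan(1e-3))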

I tried to remove the LinearTruncatedFidelityKernel so that's not causing the problem either. I don't think the problem is coming from the individual kernels per se since removing/replacing them doesn't solve the issue.

I changed the source code of cholesky.py so that it saves the covariance matrix when an exception is thrown. When the fitting fails, for example, the covariance is a 28x28 float32 matrix that unfortunately has negative eigenvalues. Some properties of that matrix (failed during fitting) are in the attached file (cov_fail_fitting.npz).
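A sketch of how I inspect the dumped matrix (the array key inside the .npz file is just an example name):

import numpy as np

# eigvalsh is the right routine for a symmetric matrix; a negative
# minimum eigenvalue means the matrix is not PSD.
cov = np.load("cov_fail_fitting.npz")["cov"]  # key "cov" is hypothetical
print(np.linalg.eigvalsh(cov).min())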

And a plot of the covariance matrix:

[image]

Another example is a failure during the acquisition function search (cov_fail_acq.npz). The matrix is 3x3:

array([[0.6930812 , 0.00274023, 0.23067696],
       [0.00274023, 0.6930812 , 0.6929812 ],
       [0.23067696, 0.6929812 , 0.6930812 ]], dtype=float32)
[image]
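Running the same eigenvalue check on this 3x3 directly confirms it is not PSD:

import numpy as np

# The 3x3 covariance printed above, reproduced verbatim.
cov = np.array(
    [[0.6930812, 0.00274023, 0.23067696],
     [0.00274023, 0.6930812, 0.6929812],
     [0.23067696, 0.6929812, 0.6930812]],
    dtype=np.float32,
)
print(np.linalg.eigvalsh(cov))  # the smallest eigenvalue comes out negative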

The matrices are available here: covariance_matrices.zip

I am not sure where to look next in order to debug and find what causes the eigenvalues to be negative. I saw in other issues that this can be due to high-dimensional searches, but quaternions are only 4-dimensional, and the cube representation is only 3D (since unit quaternions can be represented with 3 scalars). Any suggestions or ideas are helpful. Thanks!

To reproduce

Code snippet to reproduce

The code base is pretty big, feel free to ask if you need any other piece of code.

import torch
from botorch.models.gp_regression_fidelity import FixedNoiseMultiFidelityGP
from botorch.models.gpytorch import GPyTorchModel
from botorch.models.kernels.linear_truncated_fidelity import (
    LinearTruncatedFidelityKernel,
)
from gpytorch.constraints import GreaterThan
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import ScaleKernel
from gpytorch.means import ConstantMean
from gpytorch.priors import GammaPrior

# PeriodicMaternKernel3D is the project's custom quaternion-warped Matérn
# kernel and is not shown here.


class PeriodicGP3D(FixedNoiseMultiFidelityGP, GPyTorchModel):
    """Creates a periodic Gaussian process with a Matérn kernel."""

    # Model outputs a single value
    _num_outputs = 1

    def __init__(
        self,
        nu: float,
        X_train: torch.Tensor,
        y_train: torch.Tensor,
        noise: float,
        n_fidelity: int = 1,
    ):
        """Prepares the mean and kernels to use.

        Parameters
        ----------
        nu : float
            The nu parameter of the Matern kernel. Should be in {0.5, 1.5, 2.5}.
        X_train : torch.Tensor[N, d]
            The training set input samples.
        y_train : torch.Tensor[N, 1]
            The training set ground-truth samples.
        noise : float
            The noise level to use for the data points.
        n_fidelity : int
            The number of fidelities to use.
        """
        noise_vector = torch.ones_like(y_train) * noise
        super().__init__(X_train, y_train, noise_vector, n_fidelity)
        self.mean_module = ConstantMean()
        self.covar_module = ScaleKernel(
            LinearTruncatedFidelityKernel(
                fidelity_dims=[3],
                dimension=3,
                covar_module_unbiased=PeriodicMaternKernel3D(
                    nu=nu,
                    lengthscale_prior=GammaPrior(3.0, 6.0),
                    dtype=X_train.dtype,
                    device=X_train.device,
                ),
                covar_module_biased=PeriodicMaternKernel3D(
                    nu=nu,
                    lengthscale_prior=GammaPrior(6.0, 2.0),
                    dtype=X_train.dtype,
                    device=X_train.device,
                ),
                power_prior=GammaPrior(3.0, 3.0),
            ),
            outputscale_constraint=GreaterThan(1e-3),
        )
        self.to(X_train)

    def forward(self, x: torch.Tensor) -> MultivariateNormal:
        """..."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)
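For reference, the model is instantiated along these lines (shapes follow the docstring; the last input column is the fidelity dimension, matching fidelity_dims=[3]; the data below is illustrative):

# Illustrative data only: 3 orientation dimensions plus 1 fidelity column.
X_train = torch.rand(20, 4, dtype=torch.double)
y_train = torch.rand(20, 1, dtype=torch.double)
model = PeriodicGP3D(nu=2.5, X_train=X_train, y_train=y_train, noise=1e-6)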

Stack trace/error message

Here it is when the NotPSDError is raised during the acquisition function search:
Traceback (most recent call last):
  File "/apps/Arch/software/Python/3.9.5-GCCcore-10.3.0/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/apps/Arch/software/Python/3.9.5-GCCcore-10.3.0/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/cephyr/users/nicolasp/Alvis/globalign/examples/organoid.py", line 142, in <module>
    main()
  File "/cephyr/users/nicolasp/Alvis/globalign/examples/organoid.py", line 113, in main
    for i, (fidelity, orientation) in enumerate(opt):
  File "/cephyr/users/nicolasp/Alvis/globalign/globalign/optimizers/composite_optimizer.py", line 74, in __next__
    return self.current_child.__next__()
  File "/cephyr/users/nicolasp/Alvis/globalign/globalign/optimizers/optimizer.py", line 119, in __next__
    return self.sample(self.step - 1)
  File "/cephyr/users/nicolasp/Alvis/globalign/globalign/optimizers/gibbon_optimizer_3d.py", line 298, in sample
    angles, fidelities, _, _ = self.sample_acquisition(idx)
  File "/cephyr/users/nicolasp/Alvis/globalign/globalign/optimizers/gibbon_optimizer_3d.py", line 328, in sample_acquisition
    self.fit_model(X_train, y_train)
  File "/cephyr/users/nicolasp/Alvis/globalign/globalign/optimizers/gibbon_optimizer_3d.py", line 277, in fit_model
    self.qAcqf = qMultiFidelityLowerBoundMaxValueEntropy(
  File "/cephyr/users/nicolasp/Alvis/globalign/globalign-env/lib/python3.9/site-packages/botorch/acquisition/max_value_entropy_search.py", line 720, in __init__
    super().__init__(
  File "/cephyr/users/nicolasp/Alvis/globalign/globalign-env/lib/python3.9/site-packages/botorch/acquisition/max_value_entropy_search.py", line 367, in __init__
    self.set_X_pending(X_pending)  # this did not happen in the super constructor
  File "/cephyr/users/nicolasp/Alvis/globalign/globalign-env/lib/python3.9/site-packages/botorch/acquisition/max_value_entropy_search.py", line 392, in set_X_pending
    super().set_X_pending(X_pending)
  File "/cephyr/users/nicolasp/Alvis/globalign/globalign-env/lib/python3.9/site-packages/botorch/acquisition/max_value_entropy_search.py", line 152, in set_X_pending
    self._sample_max_values(num_samples=self.num_mv_samples, X_pending=X_pending)
  File "/cephyr/users/nicolasp/Alvis/globalign/globalign-env/lib/python3.9/site-packages/botorch/acquisition/max_value_entropy_search.py", line 291, in _sample_max_values
    self.posterior_max_values = sample_max_values(
  File "/cephyr/users/nicolasp/Alvis/globalign/globalign-env/lib/python3.9/site-packages/botorch/acquisition/max_value_entropy_search.py", line 911, in _sample_max_value_Gumbel
    posterior = model.posterior(candidate_set, posterior_transform=posterior_transform)
  File "/cephyr/users/nicolasp/Alvis/globalign/globalign-env/lib/python3.9/site-packages/botorch/models/gpytorch.py", line 346, in posterior
    mvn = self(X)
  File "/cephyr/users/nicolasp/Alvis/globalign/globalign-env/lib/python3.9/site-packages/gpytorch/models/exact_gp.py", line 320, in __call__
    predictive_mean, predictive_covar = self.prediction_strategy.exact_prediction(full_mean, full_covar)
  File "/cephyr/users/nicolasp/Alvis/globalign/globalign-env/lib/python3.9/site-packages/gpytorch/models/exact_prediction_strategies.py", line 272, in exact_prediction
    self.exact_predictive_mean(test_mean, test_train_covar),
  File "/cephyr/users/nicolasp/Alvis/globalign/globalign-env/lib/python3.9/site-packages/gpytorch/models/exact_prediction_strategies.py", line 288, in exact_predictive_mean
    res = (test_train_covar @ self.mean_cache.unsqueeze(-1)).squeeze(-1)
  File "/cephyr/users/nicolasp/Alvis/globalign/globalign-env/lib/python3.9/site-packages/gpytorch/utils/memoize.py", line 59, in g
    return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
  File "/cephyr/users/nicolasp/Alvis/globalign/globalign-env/lib/python3.9/site-packages/gpytorch/models/exact_prediction_strategies.py", line 239, in mean_cache
    mean_cache = train_train_covar.evaluate_kernel().solve(train_labels_offset).squeeze(-1)
  File "/cephyr/users/nicolasp/Alvis/globalign/globalign-env/lib/python3.9/site-packages/linear_operator/operators/_linear_operator.py", line 2189, in solve
    return func.apply(self.representation_tree(), False, right_tensor, *self.representation())
  File "/cephyr/users/nicolasp/Alvis/globalign/globalign-env/lib/python3.9/site-packages/linear_operator/functions/_solve.py", line 53, in forward
    solves = _solve(linear_op, right_tensor)
  File "/cephyr/users/nicolasp/Alvis/globalign/globalign-env/lib/python3.9/site-packages/linear_operator/functions/_solve.py", line 17, in _solve
    return linear_op.cholesky()._cholesky_solve(rhs)
  File "/cephyr/users/nicolasp/Alvis/globalign/globalign-env/lib/python3.9/site-packages/linear_operator/operators/_linear_operator.py", line 1227, in cholesky
    chol = self._cholesky(upper=False)
  File "/cephyr/users/nicolasp/Alvis/globalign/globalign-env/lib/python3.9/site-packages/linear_operator/utils/memoize.py", line 59, in g
    return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
  File "/cephyr/users/nicolasp/Alvis/globalign/globalign-env/lib/python3.9/site-packages/linear_operator/operators/_linear_operator.py", line 483, in _cholesky
    cholesky = psd_safe_cholesky(evaluated_mat, upper=upper).contiguous()
  File "/cephyr/users/nicolasp/Alvis/globalign/globalign-env/lib/python3.9/site-packages/linear_operator/utils/cholesky.py", line 65, in psd_safe_cholesky
    L = _psd_safe_cholesky(A, out=out, jitter=jitter, max_tries=max_tries)
  File "/cephyr/users/nicolasp/Alvis/globalign/globalign-env/lib/python3.9/site-packages/linear_operator/utils/cholesky.py", line 47, in _psd_safe_cholesky
    raise NotPSDError(f"Matrix not positive definite after repeatedly adding jitter up to {jitter_new:.1e}.")
linear_operator.utils.errors.NotPSDError: Matrix not positive definite after repeatedly adding jitter up to 1.0e-04.


Balandat commented 1 year ago

I see you are using the float32 data type - this is often problematic when working with GP models b/c the involved covariances (as you painfully discovered) are often close to singular. My first suggestion would be to switch over to the torch.double data type for all your data; that should hopefully resolve a lot of these numerical issues (it's just a lot easier to get into numerically troublesome territory with single precision). Let me know if that doesn't help!
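A minimal sketch of the switch (two equivalent options; the tensors are illustrative stand-ins for your data):

import torch

# Illustrative stand-ins for the experiment's data tensors.
X_train = torch.rand(20, 4)  # float32 by default
y_train = torch.rand(20, 1)

# Option 1: cast the existing tensors explicitly.
X_train, y_train = X_train.double(), y_train.double()

# Option 2: make float64 the default for newly created tensors.
torch.set_default_dtype(torch.float64)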

npielawski commented 1 year ago

Thanks for your reply!

I switched to float64 now, but it still doesn't work. I changed the following tensors:

They are in this format from the beginning of the experiment to the end, no casting in between.

I also checked that the covariance that failed is indeed float64, and it is. The determinant of that covariance matrix was -5.019655123536972e-17, which is very close to zero; I don't know if that plays a role as well, besides the — once again — single negative eigenvalue that prevents the Cholesky decomposition (here -1.43235979e-04, not too close to 0).

Balandat commented 1 year ago

Could you post a full repro of this behavior? I can see that this is happening when sampling the max values during the computation of the acquisition function, but it's hard to tell what kind of settings are being used without the actual code.

Basically what happens in qMultiFidelityLowerBoundMaxValueEntropy is that we draw posterior samples of the latent function at a relatively dense grid to produce approximate samples of the max value. This does not involve the noise level of the likelihood, so if the lengthscales of the model are large and the discrete sampling locations are close to each other (possibly b/c there are many points), then it's not too surprising that this may run into some numerical issues.
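For illustration, a sketch of constructing the acquisition function with a coarser candidate set; fewer points keep the discrete sampling locations further apart (the set size and the unit-cube draw are assumptions):

import torch
from botorch.acquisition.max_value_entropy_search import (
    qMultiFidelityLowerBoundMaxValueEntropy,
)

# model is the fitted GP from above; 256 candidates over the unit cube
# (plus the fidelity column) is purely illustrative.
candidate_set = torch.rand(256, 4, dtype=torch.double)
acqf = qMultiFidelityLowerBoundMaxValueEntropy(model, candidate_set=candidate_set)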

One quick thing you could try is increasing the maximum amount of jitter that gpytorch attempts in psd_safe_cholesky:

with linear_operator.settings.cholesky_jitter(double=1e-6):
    ...

or

with linear_operator.settings.cholesky_max_tries(5):
    ...

The first of the two sets the starting jitter value for double tensors to 1e-6; the second increases the number of tries to 5 (in each try the jitter value is increased by an order of magnitude).
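For example, both settings can be combined and apply to any Cholesky factorization performed inside the block; the rank-deficient matrix below is a toy case where a plain Cholesky fails but jitter rescues it:

import torch
import linear_operator
from linear_operator.utils.cholesky import psd_safe_cholesky

# A rank-1 PSD matrix: exactly singular, so unjittered Cholesky fails.
A = torch.ones(3, 3, dtype=torch.double)

with linear_operator.settings.cholesky_jitter(double=1e-6), \
     linear_operator.settings.cholesky_max_tries(5):
    L = psd_safe_cholesky(A)  # succeeds once jitter is added to the diagonal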

(for more context, these are being invoked here: https://github.com/cornellius-gp/linear_operator/blob/main/linear_operator/utils/cholesky.py#L28-L31)

Doing so will cost some precision, but it may be ok if this is happening only occasionally. If this is still an issue, then it may be worth considering some other approaches (e.g. using RFF/decoupled sampling for generating max value samples) or potentially using a different acquisition function.

npielawski commented 1 year ago

I looked more into it and was able to fix the issue in the end. The mistake was mine. The kernel I used was implemented correctly, but the construction itself was later shown to be flawed in a new publication (it is not PSD if the lengthscale is too big), which also showed the proper way of computing it. Using double precision is necessary, however. Thanks for your help!