huawei-noah / HEBO

Bayesian optimisation & Reinforcement Learning library developed by Huawei Noah's Ark Lab

Issues with Cholesky decompositions on simple benchmark #61

Open dimitri-rusin opened 10 months ago

dimitri-rusin commented 10 months ago

I am getting some issues with Cholesky decompositions. Here's how to reproduce:

  1. Go to: https://colab.research.google.com/drive/1XftMKU7-tWj0cdWjH7XsfiDPBIKXAWZk#scrollTo=OuypIJ7do1qi
  2. Run all cells.
  3. Then we get either this exception:
    NanError: cholesky_cpu: 3716 of 3721 elements of the torch.Size([61, 61]) tensor are NaN.

    or

    NotPSDError: Matrix not positive definite after repeatedly adding jitter up to 1.0e-04.

What matrix is being decomposed here? Can I influence or change this matrix using some hyperparameters? What can I do?

Thank you!

Takui9 commented 9 months ago

Same problem. I got the same errors when using contextual HEBO.

AntGro commented 9 months ago

To improve stability you can add upper constraints on the kernel lengthscales: modify HEBO/hebo/models/gp/gp_utils.py, replacing it with this version:

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.

# This program is free software; you can redistribute it and/or modify it under
# the terms of the MIT license.

# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the MIT License for more details.

import numpy as np
import torch
import torch.nn as nn
from torch import FloatTensor, LongTensor

from gpytorch.kernels import MaternKernel, ScaleKernel, ProductKernel
from gpytorch.priors  import GammaPrior
from gpytorch.constraints.constraints import LessThan

from ..layers import EmbTransform

class DummyFeatureExtractor(nn.Module):
    def __init__(self, num_cont, num_enum, num_uniqs = None, emb_sizes = None):
        super().__init__()
        self.num_cont  = num_cont
        self.num_enum  = num_enum
        self.total_dim = num_cont
        if num_enum > 0:
            assert num_uniqs is not None
            self.emb_trans  = EmbTransform(num_uniqs, emb_sizes = emb_sizes)
            self.total_dim += self.emb_trans.num_out

    def forward(self, x : FloatTensor, xe : LongTensor):
        x_all = x
        if self.num_enum > 0:
            x_all = torch.cat([x, self.emb_trans(xe)], dim = 1)
        return x_all

def default_kern(x, xe, y, total_dim = None, ard_kernel = True, fe = None, max_x = 1000):
    if fe is None:
        has_num  = x  is not None and x.shape[1]  > 0
        has_enum = xe is not None and xe.shape[1] > 0
        kerns    = []
        if has_num:
            ard_num_dims = x.shape[1] if ard_kernel else None
            kernel       = MaternKernel(nu = 1.5, ard_num_dims = ard_num_dims, active_dims = torch.arange(x.shape[1]),
                                        lengthscale_constraint=LessThan(5))
            if ard_kernel:
                lscales = kernel.lengthscale.detach().clone().view(1, -1)
                for i in range(x.shape[1]):
                    idx = np.random.choice(x.shape[0], min(x.shape[0], max_x), replace = False)
                    lscales[0, i] = torch.pdist(x[idx, i].view(-1, 1)).median().clamp(min = 0.02)
                kernel.lengthscale = lscales
            kerns.append(kernel)
        if has_enum:
            kernel = MaternKernel(nu = 1.5, active_dims = torch.arange(x.shape[1], total_dim),
                                        lengthscale_constraint=LessThan(5))
            kerns.append(kernel)
        final_kern = ScaleKernel(ProductKernel(*kerns), outputscale_prior = GammaPrior(0.5, 0.5))
        final_kern.outputscale = y[torch.isfinite(y)].var()
        return final_kern
    else:
        if ard_kernel:
            kernel = ScaleKernel(MaternKernel(nu = 1.5, ard_num_dims = total_dim,
                                        lengthscale_constraint=LessThan(5)))
        else:
            kernel = ScaleKernel(MaternKernel(nu = 1.5))
        kernel.outputscale = y[torch.isfinite(y)].var()
        return kernel

I've tested it on the example provided in the original issue comment and it runs without errors.

After further investigation we'll include this in the repo directly.
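
To answer the question above about which matrix is being decomposed: it is the GP covariance (Gram) matrix built by this kernel on the observed points, plus the likelihood noise and jitter, that gets Cholesky-factorised during fitting and prediction. For readers who prefer not to patch the library file, here is a minimal standalone sketch of the same constraint idea using gpytorch directly (the data and variable names are illustrative, not part of HEBO):

import torch
from gpytorch.kernels import MaternKernel, ScaleKernel
from gpytorch.constraints import LessThan

# Toy inputs standing in for the (normalised) observed design points.
x = torch.rand(20, 3)

# Matern-3/2 kernel with one lengthscale per dimension (ARD); each
# lengthscale is kept below 5 so the covariance matrix does not become
# numerically near-singular, which is what triggers the Cholesky errors.
kern = ScaleKernel(
    MaternKernel(nu=1.5, ard_num_dims=x.shape[1], lengthscale_constraint=LessThan(5.0))
)

# The raw parameters are mapped through the constraint, so the effective
# lengthscales always respect the upper bound.
print(kern.base_kernel.lengthscale)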

AntGro commented 9 months ago

@dimitri-rusin Actually we also need to specify a lower bound on the lengthscales to prevent NaN or Inf values, so I replaced LessThan with Interval.

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.

# This program is free software; you can redistribute it and/or modify it under
# the terms of the MIT license.

# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the MIT License for more details.

import numpy as np
import torch
import torch.nn as nn
from torch import FloatTensor, LongTensor

from gpytorch.kernels import MaternKernel, ScaleKernel, ProductKernel
from gpytorch.priors  import GammaPrior
from gpytorch.constraints.constraints import Interval

from ..layers import EmbTransform

class DummyFeatureExtractor(nn.Module):
    def __init__(self, num_cont, num_enum, num_uniqs = None, emb_sizes = None):
        super().__init__()
        self.num_cont  = num_cont
        self.num_enum  = num_enum
        self.total_dim = num_cont
        if num_enum > 0:
            assert num_uniqs is not None
            self.emb_trans  = EmbTransform(num_uniqs, emb_sizes = emb_sizes)
            self.total_dim += self.emb_trans.num_out

    def forward(self, x : FloatTensor, xe : LongTensor):
        x_all = x
        if self.num_enum > 0:
            x_all = torch.cat([x, self.emb_trans(xe)], dim = 1)
        return x_all

def default_kern(x, xe, y, total_dim = None, ard_kernel = True, fe = None, max_x = 1000):
    if fe is None:
        has_num  = x  is not None and x.shape[1]  > 0
        has_enum = xe is not None and xe.shape[1] > 0
        kerns    = []
        if has_num:
            ard_num_dims = x.shape[1] if ard_kernel else None
            kernel       = MaternKernel(nu = 1.5, ard_num_dims = ard_num_dims, active_dims = torch.arange(x.shape[1]),
                                        lengthscale_constraint=Interval(1e-5, 5))
            if ard_kernel:
                lscales = kernel.lengthscale.detach().clone().view(1, -1)
                for i in range(x.shape[1]):
                    idx = np.random.choice(x.shape[0], min(x.shape[0], max_x), replace = False)
                    lscales[0, i] = torch.pdist(x[idx, i].view(-1, 1)).median().clamp(min = 0.02)
                kernel.lengthscale = lscales
            kerns.append(kernel)
        if has_enum:
            kernel = MaternKernel(nu = 1.5, active_dims = torch.arange(x.shape[1], total_dim),
                                        lengthscale_constraint=Interval(1e-5, 5))
            kerns.append(kernel)
        final_kern = ScaleKernel(ProductKernel(*kerns), outputscale_prior = GammaPrior(0.5, 0.5))
        final_kern.outputscale = y[torch.isfinite(y)].var()
        return final_kern
    else:
        if ard_kernel:
            kernel = ScaleKernel(MaternKernel(nu = 1.5, ard_num_dims = total_dim,
                                        lengthscale_constraint=Interval(1e-5, 5)))
        else:
            kernel = ScaleKernel(MaternKernel(nu = 1.5))
        kernel.outputscale = y[torch.isfinite(y)].var()
        return kernel

kegl commented 8 months ago

Actually it's this, right? default_kern_rd is also needed.

# Copyright (C) 2020. Huawei Technologies Co., Ltd. All rights reserved.

# This program is free software; you can redistribute it and/or modify it under
# the terms of the MIT license.

# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the MIT License for more details.

import numpy as np
import torch
import torch.nn as nn
from torch import FloatTensor, LongTensor

from gpytorch.kernels import MaternKernel, ScaleKernel, ProductKernel, AdditiveKernel
from gpytorch.priors  import GammaPrior
from gpytorch.constraints.constraints import Interval

from ..layers import EmbTransform

class DummyFeatureExtractor(nn.Module):
    def __init__(self, num_cont, num_enum, num_uniqs = None, emb_sizes = None):
        super().__init__()
        self.num_cont  = num_cont
        self.num_enum  = num_enum
        self.total_dim = num_cont
        if num_enum > 0:
            assert num_uniqs is not None
            self.emb_trans  = EmbTransform(num_uniqs, emb_sizes = emb_sizes)
            self.total_dim += self.emb_trans.num_out

    def forward(self, x : FloatTensor, xe : LongTensor):
        x_all = x
        if self.num_enum > 0:
            x_all = torch.cat([x, self.emb_trans(xe)], dim = 1)
        return x_all

def default_kern(x, xe, y, total_dim = None, ard_kernel = True, fe = None, max_x = 1000):
    if fe is None:
        has_num  = x  is not None and x.shape[1]  > 0
        has_enum = xe is not None and xe.shape[1] > 0
        kerns    = []
        if has_num:
            ard_num_dims = x.shape[1] if ard_kernel else None
            kernel       = MaternKernel(nu = 1.5, ard_num_dims = ard_num_dims, active_dims = torch.arange(x.shape[1]),
                                        lengthscale_constraint=Interval(1e-5, 5))
            if ard_kernel:
                lscales = kernel.lengthscale.detach().clone().view(1, -1)
                for i in range(x.shape[1]):
                    idx = np.random.choice(x.shape[0], min(x.shape[0], max_x), replace = False)
                    lscales[0, i] = torch.pdist(x[idx, i].view(-1, 1)).median().clamp(min = 0.02)
                kernel.lengthscale = lscales
            kerns.append(kernel)
        if has_enum:
            kernel = MaternKernel(nu = 1.5, active_dims = torch.arange(x.shape[1], total_dim),
                                        lengthscale_constraint=Interval(1e-5, 5))
            kerns.append(kernel)
        final_kern = ScaleKernel(ProductKernel(*kerns), outputscale_prior = GammaPrior(0.5, 0.5))
        final_kern.outputscale = y[torch.isfinite(y)].var()
        return final_kern
    else:
        if ard_kernel:
            kernel = ScaleKernel(MaternKernel(nu = 1.5, ard_num_dims = total_dim,
                                        lengthscale_constraint=Interval(1e-5, 5)))
        else:
            kernel = ScaleKernel(MaternKernel(nu = 1.5))
        kernel.outputscale = y[torch.isfinite(y)].var()
        return kernel

def default_kern_rd(x, xe, y, total_dim = None, ard_kernel = True, fe = None, max_x = 1000, E=0.2):
    '''
    Get a default kernel with random decompositions. 0 <= E <= 1 specifies random tree connectivity.
    Relies on get_random_graph (not shown here) to draw the random cliques.
    '''
    kernels = []
    random_graph = get_random_graph(total_dim, E)
    for clique in random_graph:
        if fe is None:
            num_dims  = tuple(dim for dim in clique if dim < x.shape[1])
            enum_dims = tuple(dim for dim in clique if x.shape[1] <= dim < total_dim)
            clique_kernels = []
            if len(num_dims) > 0:
                ard_num_dims = len(num_dims) if ard_kernel else None
                num_kernel       = MaternKernel(nu = 1.5, ard_num_dims = ard_num_dims, active_dims = num_dims)
                if ard_kernel:
                    lscales = num_kernel.lengthscale.detach().clone().view(1, -1)
                    if len(num_dims) > 1 :
                        for dim_no, dim_name in enumerate(num_dims):
                            idx = np.random.choice(num_dims, min(len(num_dims), max_x), replace = False)
                            lscales[0, dim_no] = torch.pdist(x[idx, dim_name].view(-1, 1)).median().clamp(min = 0.02)
                    num_kernel.lengthscale = lscales
                clique_kernels.append(num_kernel)
            if len(enum_dims) > 0:
                enum_kernel = MaternKernel(nu = 1.5, active_dims = enum_dims)
                clique_kernels.append(enum_kernel)

            kernel = ScaleKernel(ProductKernel(*clique_kernels), outputscale_prior = GammaPrior(0.5, 0.5))
        else:
            if ard_kernel:
                kernel = ScaleKernel(MaternKernel(nu = 1.5, ard_num_dims = total_dim, active_dims=tuple(clique)))
            else:
                kernel = ScaleKernel(MaternKernel(nu = 1.5, active_dims=tuple(clique)))

        kernels.append(kernel)

    final_kern = ScaleKernel(AdditiveKernel(*kernels), outputscale_prior = GammaPrior(0.5, 0.5))
    final_kern.outputscale = y[torch.isfinite(y)].var()
    return final_kern

kegl commented 7 months ago

I tried this and it was not enough. What worked was catching the Cholesky failure and increasing the jitter until it succeeds.

    def fit(self, Xc : Tensor, Xe : Tensor, y : Tensor):
        Xc, Xe, y = filter_nan(Xc, Xe, y, 'all')
        self.fit_scaler(Xc, Xe, y)
        Xc, Xe, y = self.xtrans(Xc, Xe, y)

        assert(Xc.shape[1] == self.num_cont)
        assert(Xe.shape[1] == self.num_enum)
        assert(y.shape[1]  == self.num_out)

        self.Xc = Xc
        self.Xe = Xe
        self.y  = y

        n_constr = GreaterThan(self.noise_lb)
        n_prior  = LogNormalPrior(np.log(self.noise_guess), 0.5)
        self.lik = GaussianLikelihood(noise_constraint = n_constr, noise_prior = n_prior)
        self.gp  = GPyTorchModel(self.Xc, self.Xe, self.y, self.lik, **self.conf)

        self.gp.likelihood.noise  = max(1e-2, self.noise_lb)

        self.gp.train()
        self.lik.train()

        if self.optimizer.lower() == 'lbfgs':
            opt = torch.optim.LBFGS(self.gp.parameters(), lr = self.lr, max_iter = 5, line_search_fn = 'strong_wolfe')
        elif self.optimizer == 'psgld':
            opt = pSGLD(self.gp.parameters(), lr = self.lr, factor = 1. / y.shape[0], pretrain_step = self.num_epochs // 10)
        else:
            opt = torch.optim.Adam(self.gp.parameters(), lr = self.lr)
        mll = gpytorch.mlls.ExactMarginalLogLikelihood(self.lik, self.gp)
        for epoch in range(self.num_epochs):
            jitter = 10 ** -8
            cont = True
            while cont:
                cont = False
                cholesky_jitter._set_value(
                    double_value=jitter, float_value=100*jitter, half_value=10000*jitter)
                def closure():
                    dist = self.gp(self.Xc, self.Xe)
                    loss = -1 * mll(dist, self.y.squeeze())
                    opt.zero_grad()
                    loss.backward()
                    return loss
                try:
                    opt.step(closure)
                except Exception:  # e.g. NanError / NotPSDError from the Cholesky
                    jitter *= 10
                    cont = True
                    print(f'jitter = {jitter}')
            if self.verbose and ((epoch + 1) % self.print_every == 0 or epoch == 0):
                print('After %d epochs, loss = %g' % (epoch + 1, closure().item()), flush = True)
        self.gp.eval()
        self.lik.eval()

    def predict(self, Xc, Xe):
        Xc, Xe = self.xtrans(Xc, Xe)
        with gpytorch.settings.fast_pred_var(), gpytorch.settings.debug(False):
            jitter = 10 ** -8
            cont = True
            while cont:
                cont = False
                cholesky_jitter._set_value(
                    double_value=jitter, float_value=100*jitter, half_value=10000*jitter)
                try:
                    pred = self.gp(Xc, Xe)
                except Exception:  # e.g. NanError / NotPSDError from the Cholesky
                    jitter *= 10
                    cont = True
                    print(f'jitter = {jitter}')                
            if self.pred_likeli:
                pred = self.lik(pred)
            mu_  = pred.mean.reshape(-1, self.num_out)
            var_ = pred.variance.reshape(-1, self.num_out)
        mu  = self.yscaler.inverse_transform(mu_)
        var = var_ * self.yscaler.std**2
        return mu, var.clamp(min = torch.finfo(var.dtype).eps)
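
For context, what keeps failing in the errors above is the Cholesky factorisation of a covariance matrix that is numerically not positive definite. The code above raises gpytorch's internal cholesky_jitter setting on failure; the retry idea in isolation looks roughly like the following plain-PyTorch sketch (the function name and data are illustrative, not HEBO's or gpytorch's API):

    import torch

    def cholesky_with_jitter(K: torch.Tensor, max_tries: int = 8) -> torch.Tensor:
        """Try to Cholesky-factorise K, adding increasing diagonal jitter on failure."""
        jitter = 1e-8
        eye = torch.eye(K.shape[-1], dtype=K.dtype, device=K.device)
        for _ in range(max_tries):
            try:
                return torch.linalg.cholesky(K + jitter * eye)
            except RuntimeError:  # raised when the matrix is not positive definite
                jitter *= 10
        raise RuntimeError('Cholesky failed even after repeatedly increasing the jitter')

    # A covariance-like matrix that is only positive semi-definite (rank 5),
    # so plain torch.linalg.cholesky(K) would fail without the added jitter.
    x = torch.randn(50, 5)
    K = x @ x.t()
    L = cholesky_with_jitter(K)
    print(L.shape)  # torch.Size([50, 50])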

muazhari commented 3 weeks ago

Are there any updates or fixes? I am still getting the error, even in a simple use case. Reproducible code: https://gist.github.com/muazhari/85b7469902cdb7b3fba49a065b212f40

NanError: cholesky_cpu: 225 of 225 elements of the torch.Size([15, 15]) tensor are NaN.