
[Bug] Incompatibility between delta distribution and unwhitened strategy in gpytorch.variational for fixed inducing points #2410

Open Otruffinet opened 9 months ago

Otruffinet commented 9 months ago

🐛 Bug

Hi,

I stumbled upon an unclear bug while trying to use an UnwhitenedVariationalStrategy with inducing locations fixed at the training points: the program crashed during training, raising a RuntimeError with no message. After a bit of digging, I found that this was due to my use of a DeltaVariationalDistribution instead of the standard CholeskyVariationalDistribution. Not only does switching to a CholeskyVariationalDistribution fix the crash, but the code shows why: in the __call__ method of _variational_strategy.py, line 309 shows that if variational_dist_u is a delta distribution, the forward method of the UnwhitenedVariationalStrategy is called with variational_inducing_covar=None, which directly triggers the exception when the inducing points are fixed and equal to the training points.

I don't know whether this behavior is intended, as I have not investigated whether this combination makes sense mathematically. Either way, I feel the error message should be clearer, since a simple change of options solves the problem; the relevant code path is paraphrased below.
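For context, here is the dispatch in _VariationalStrategy.__call__ together with the check in UnwhitenedVariationalStrategy.forward, paraphrased from the gpytorch sources (the class and argument names are real, but the exact code may differ between versions):

    # _variational_strategy.py, __call__ (paraphrased): a Delta variational
    # distribution has no covariance over u, so forward() receives None
    if isinstance(variational_dist_u, MultivariateNormal):
        return super().__call__(
            x,
            inducing_points,
            inducing_values=variational_dist_u.mean,
            variational_inducing_covar=variational_dist_u.lazy_covariance_matrix,
            **kwargs,
        )
    elif isinstance(variational_dist_u, Delta):
        return super().__call__(
            x,
            inducing_points,
            inducing_values=variational_dist_u.mean,
            variational_inducing_covar=None,  # the problematic argument
            **kwargs,
        )

    # unwhitened_variational_strategy.py, forward (paraphrased): when the
    # inputs coincide with the inducing points, a covariance is required,
    # so the Delta case above hits a bare raise with no message
    if torch.equal(x, inducing_points):
        if variational_inducing_covar is None:
            raise RuntimeError
        return MultivariateNormal(inducing_values, variational_inducing_covar)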

To reproduce

Code snippet to reproduce

    import gpytorch.means
    import torch
    torch.set_default_dtype(torch.float64)
    import gpytorch as gp

    class VariationalMultitaskGPModel(gpytorch.models.ApproximateGP):
        def __init__(
            self,
            train_x,
            n_latents,
            n_tasks,
            distrib,
            mean_type=gpytorch.means.ConstantMean,
            kernel_type=gpytorch.kernels.RBFKernel,
            **kwargs,
        ):
            # Inducing points are fixed at the training inputs
            inducing_points = train_x.unsqueeze(-1)
            var_dist = distrib(inducing_points.size(-2), batch_shape=torch.Size([n_latents]))
            strategy = gpytorch.variational.UnwhitenedVariationalStrategy(
                self, inducing_points, var_dist, learn_inducing_locations=False
            )
            variational_strategy = gpytorch.variational.LMCVariationalStrategy(
                strategy,
                num_tasks=n_tasks,
                num_latents=n_latents,
                latent_dim=-1,
            )
            super().__init__(variational_strategy)
            self.covar_module = kernel_type(batch_shape=torch.Size([n_latents]))
            self.mean_module = mean_type(batch_shape=torch.Size([n_latents]))

        def forward(self, x):
            mean_x = self.mean_module(x)
            covar_x = self.covar_module(x)
            return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

    n_points = 50
    n_tasks = 10
    n_lat = 2
    mu_noise = 1e-2
    min_scale, max_scale = 0.1, 0.5

    X_train = torch.linspace(-1, 1, n_points)
    X_test = 2 * torch.rand(500) - 1
    X = torch.cat([X_train, X_test], dim=0)
    H_true = torch.randn(size=(n_lat, n_tasks))
    ker_list = [gp.kernels.MaternKernel() for i in range(n_lat)]
    lscales = torch.linspace(min_scale, max_scale, n_lat)
    for i in range(n_lat):
        ker_list[i].lengthscale = lscales[i]
    lat_gp_dist = [gp.distributions.MultivariateNormal(torch.zeros_like(X), kernel(X)) for kernel in ker_list]
    gp_vals = torch.stack([dist.sample() for dist in lat_gp_dist])  # + torch.stack(noises)
    Y_sig = gp_vals.T @ H_true * (1 - mu_noise)

    noise_levels = torch.rand(n_tasks) + 0.1
    noise_dist = [gp.distributions.MultivariateNormal(torch.zeros_like(X), noise_levels[i] * torch.eye(len(X)))
                  for i in range(n_tasks)]  # homoskedastic noise
    noise_vals = torch.stack([dist.sample() for dist in noise_dist])
    Y_noise = noise_vals.T * mu_noise
    Y = Y_sig + Y_noise
    Y_train, Y_test = Y[:n_points], Y[n_points:]

    mean_type = gpytorch.means.ConstantMean
    kernel_type = gpytorch.kernels.MaternKernel
    likelihood = gp.likelihoods.MultitaskGaussianLikelihood(num_tasks=n_tasks, rank=0)
    # DeltaVariationalDistribution triggers the crash below;
    # CholeskyVariationalDistribution works fine
    var_dist = gpytorch.variational.DeltaVariationalDistribution
    model = VariationalMultitaskGPModel(X_train, n_tasks=n_tasks, n_latents=n_lat, distrib=var_dist,
                                        mean_type=mean_type, kernel_type=kernel_type)

    model.train()
    likelihood.train()
    mll = gp.mlls.VariationalELBO(likelihood, model, num_data=n_points)
    optimizer = torch.optim.AdamW([{'params': model.parameters()}, {'params': likelihood.parameters()}], lr=1e-2)

    n_iter = 4000
    patience = 500
    loss_thresh = 1e-2
    plateau_id = 0
    for i in range(n_iter):
        optimizer.zero_grad()
        with gp.settings.cholesky_jitter(1e-5):
            output_train = model(X_train)
            loss = -mll(output_train, Y_train)
        if i % 100 == 0:
            print(loss.item())
        loss.backward()
        optimizer.step()
        loss = loss.detach()  # detach so the plateau check below does not keep the graph alive
        if i > 0 and torch.abs(last_loss - loss) < loss_thresh:
            plateau_id += 1
            if plateau_id > patience:
                break
        last_loss = loss

To reproduce (comments)

No evaluation block was provided, as the error already arises during training.

Replacing the DeltaVariationalDistribution by a CholeskyVariationalDistribution solves the problem; a stripped-down sketch of the failure is given below.
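For completeness, here is a minimal sketch (untested, single-output, no LMC wrapper) that should hit the same bare RuntimeError through the code path paraphrased above:

    import torch
    import gpytorch

    train_x = torch.linspace(-1, 1, 50).unsqueeze(-1)

    class MinimalGP(gpytorch.models.ApproximateGP):
        def __init__(self, inducing_points):
            # Delta distribution: a point estimate for u, no covariance
            var_dist = gpytorch.variational.DeltaVariationalDistribution(inducing_points.size(-2))
            strategy = gpytorch.variational.UnwhitenedVariationalStrategy(
                self, inducing_points, var_dist, learn_inducing_locations=False
            )
            super().__init__(strategy)
            self.mean_module = gpytorch.means.ConstantMean()
            self.covar_module = gpytorch.kernels.RBFKernel()

        def forward(self, x):
            return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

    model = MinimalGP(train_x)
    model.train()
    model(train_x)  # inputs == inducing points -> bare RuntimeError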

Stack trace/error message

    Traceback (most recent call last):
      File "/.../bug_delta.py", line 83, in <module>
        output_train = model(X_train)
      File "/.../anaconda3/envs/gp/lib/python3.10/site-packages/gpytorch/models/approximate_gp.py", line 108, in __call__
        return self.variational_strategy(inputs, prior=prior, **kwargs)
      File "/.../anaconda3/envs/gp/lib/python3.10/site-packages/gpytorch/variational/lmc_variational_strategy.py", line 197, in __call__
        latent_dist = self.base_variational_strategy(x, prior=prior, **kwargs)
      File "/.../anaconda3/envs/gp/lib/python3.10/site-packages/gpytorch/variational/_variational_strategy.py", line 349, in __call__
        return super().__call__(
      File "/.../anaconda3/envs/gp/lib/python3.10/site-packages/gpytorch/module.py", line 31, in __call__
        outputs = self.forward(*inputs, **kwargs)
      File "/.../anaconda3/envs/gp/lib/python3.10/site-packages/gpytorch/variational/unwhitened_variational_strategy.py", line 132, in forward
        raise RuntimeError
    RuntimeError

Expected Behavior

Either a mathematical workaround is found (for instance, by assigning some relevant value to variational_inducing_covar), or a clear error message is raised, stating that the failure stems from the incompatibility between the delta distribution and the unwhitened strategy for fixed inducing points; a possible message is sketched below.
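As an illustration of the second option, the bare raise at unwhitened_variational_strategy.py line 132 could carry a message along these lines (the wording is only a suggestion):

    if variational_inducing_covar is None:
        raise RuntimeError(
            "UnwhitenedVariationalStrategy received inputs equal to its inducing "
            "points, but the variational distribution provides no covariance "
            "(variational_inducing_covar is None). This happens with "
            "DeltaVariationalDistribution; use CholeskyVariationalDistribution "
            "(or another distribution with a covariance) instead."
        )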

System information

Please complete the following information: