Multitask Beta Likelihood for SVI on Multi-Output GP #2381

eirikbaekkelund commented 11 months ago

🐛 Bug - Beta Likelihood for SVI Multitask GP

To reproduce

When fitting an approximate GP using SVI, I have a model that works with the multitask Gaussian likelihood. However, when I switch to a Beta likelihood, the optimisation routine does not seem to learn much. The scale parameter should work fine for all the tasks passed, so my intuition is that the batch shape needs to be accounted for, which currently does not seem to happen. The output from the Beta likelihood has the form (n_samples, n_points, n_tasks), e.g. (10, 100, 5). So how does the information that dim 0 is the sample dimension, dim 1 the batch dimension and dim 2 the event dimension get passed to the variational ELBO being optimised? In other words, how does the VI class handle sample shape, batch shape and event shape? And would a new likelihood need to be implemented for a multitask Beta?
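To make the sample/batch/event terminology concrete: in plain `torch.distributions` (not GPyTorch's internals), `log_prob` returns a tensor of shape `sample_shape + batch_shape`, with event dimensions already summed out. The sketch below uses arbitrary placeholder Beta parameters for the (10, 100, 5) example. It shows that a Beta over a `(n_points, n_tasks)` grid treats every entry as an independent batch element, and that wrapping it in `Independent(..., 1)` reinterprets the task dimension as part of the event, so it gets summed.

```python
import torch
from torch.distributions import Beta, Independent

n_samples, n_points, n_tasks = 10, 100, 5

# Placeholder concentrations, one per (point, task); in a GP these would be
# derived from latent function values
alpha = torch.full((n_points, n_tasks), 2.0)
beta = torch.full((n_points, n_tasks), 3.0)

# Every (point, task) entry is treated as its own batch element
elementwise = Beta(alpha, beta)
print(elementwise.batch_shape, elementwise.event_shape)  # torch.Size([100, 5]) torch.Size([])

# Reinterpret the last (task) dimension as an event dimension
per_point = Independent(elementwise, 1)
print(per_point.batch_shape, per_point.event_shape)  # torch.Size([100]) torch.Size([5])

y = elementwise.sample(torch.Size([n_samples]))  # shape (10, 100, 5)
print(elementwise.log_prob(y).shape)  # torch.Size([10, 100, 5]): nothing summed
print(per_point.log_prob(y).shape)    # torch.Size([10, 100]): tasks summed out
```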

Code snippet to reproduce

```python
import numpy as np
import torch
import gpytorch
from gpytorch.variational import MeanFieldVariationalDistribution
from gpytorch.kernels import (MaternKernel, PeriodicKernel, ScaleKernel,
                              AdditiveKernel)

x_train = torch.linspace(0, 1, 100)

y_train = torch.stack([
    0.2*torch.sin(x_train * (2 * np.pi)) + torch.randn(x_train.size()) * 0.05 + 0.5,
    0.2*torch.cos(x_train * (2 * np.pi)) + torch.randn(x_train.size()) * 0.05 + 0.5,
    0.2*torch.sin(x_train * (2 * np.pi)) + 0.2 * torch.cos(x_train * (2 * np.pi)) + torch.randn(x_train.size()) * 0.05 + 0.5,
    0.2*torch.cos(x_train * (2 * np.pi)) + torch.randn(x_train.size()) * 0.05 + 0.5,
], -1)

# Sizes not shown in the original snippet; num_latents = 3 is a hypothetical choice
num_tasks = y_train.size(-1)
num_latents = 3

matern_base = MaternKernel(nu=3/2,
                           lengthscale_prior=gpytorch.priors.GammaPrior(2, 8))
periodic = PeriodicKernel(
    # period_length_prior=gpytorch.priors.GammaPrior(3, 2),
    # period_length_constraint=gpytorch.constraints.Positive(),
)
scaled_periodic = ScaleKernel(
    periodic,
    # outputscale_prior=gpytorch.priors.GammaPrior(5, 1),
    # outputscale_constraint=gpytorch.constraints.Positive(),
)
scaled_matern = ScaleKernel(
    matern_base,
    # outputscale_prior=gpytorch.priors.GammaPrior(5, 2),
    # outputscale_constraint=gpytorch.constraints.Interval(0.1, 1),
)
product_kernel_matern_periodic = ScaleKernel(
    periodic * matern_base,
    # outputscale_prior=gpytorch.priors.GammaPrior(5, 2),
    # outputscale_constraint=gpytorch.constraints.Positive(),
)

quasi_periodic_matern = AdditiveKernel(product_kernel_matern_periodic, scaled_matern)

class MultitaskGPModel(gpytorch.models.ApproximateGP):
    def __init__(self):

        # MeanField constructs a variational distribution for each latent function
        variational_distribution = MeanFieldVariationalDistribution(
            x_train.size(-1), batch_shape=torch.Size([num_latents]), jitter=1e-2
        )

        # LMC constructs a MultitaskMultivariateNormal from the latent GPs;
        # all training inputs are used as fixed inducing points
        variational_strategy = gpytorch.variational.LMCVariationalStrategy(
            gpytorch.variational.VariationalStrategy(
                self, x_train, variational_distribution,
                learn_inducing_locations=False, jitter_val=1e-2,
            ),
            num_tasks=num_tasks,
            num_latents=num_latents,
            latent_dim=-1,
        )
        super().__init__(variational_strategy)

        # batch for different hypers for each latent function
        self.mean_module = gpytorch.means.ZeroMean(batch_shape=torch.Size([num_latents]))
        self.covar_module = quasi_periodic_matern

    def forward(self, x):

        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)

        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

model = MultitaskGPModel()
# works for likelihood1, not for likelihood2
likelihood1 = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=num_tasks)
likelihood2 = gpytorch.likelihoods.BetaLikelihood(
    scale_prior=gpytorch.priors.GammaPrior(30, 2),
    scale_constraint=gpytorch.constraints.Interval(10, 25),
)

model.train()
likelihood2.train()

optimizer = torch.optim.Adam([
    {'params': model.parameters()},
    {'params': likelihood2.parameters()},
], lr=0.1)

# SVI training loop maximising the ELBO E_q(f)[log p(y|f)] - KL[q(u) || p(u)]
# (i.e. minimising its negative), where q is the variational distribution over
# the inducing values u and p is the prior

mll = gpytorch.mlls.VariationalELBO(likelihood2, model, num_data=y_train.size(0))

n_iter = 300
print_freq = n_iter // 10

for i in range(n_iter + 1):
    # Full-batch training: all training points are used in every iteration
    optimizer.zero_grad()
    output = model(x_train)
    loss = -mll(output, y_train).mean()
    loss.backward()
    optimizer.step()

    if i % print_freq == 0:
        print('Iter %d/%d - Loss: %.3f' % (i, n_iter, loss.item()))
```

Optimisation tracking for each likelihood (MultitaskGaussianLikelihood first, BetaLikelihood second):

Iter 0/300 - Loss: 4.696
Iter 30/300 - Loss: -0.110
Iter 60/300 - Loss: -2.532
Iter 90/300 - Loss: -3.476
Iter 120/300 - Loss: -3.643
Iter 150/300 - Loss: -3.816
Iter 180/300 - Loss: -3.458
Iter 210/300 - Loss: -3.936
Iter 240/300 - Loss: -3.965
Iter 270/300 - Loss: -3.429
Iter 300/300 - Loss: -4.030


Iter 0/300 - Loss: 0.252
Iter 30/300 - Loss: 0.010
Iter 60/300 - Loss: -0.024
Iter 90/300 - Loss: -0.026
Iter 120/300 - Loss: -0.026
Iter 150/300 - Loss: -0.026
Iter 180/300 - Loss: -0.026
Iter 210/300 - Loss: -0.026
Iter 240/300 - Loss: -0.026
Iter 270/300 - Loss: -0.026
Iter 300/300 - Loss: -0.026

Expected Behavior

Training these likelihoods on a simple (single-output) approximate GP with the same settings yields good fits, but this does not carry over to the multitask case. I couldn't find any post dealing with the same issue, so I was wondering what causes this behaviour and what the possible fixes are.

Resulting image from fitting with MultitaskGaussianLikelihood: [image not included]

Resulting image from fitting with Beta Likelihood: [image not included]

spectraldani commented 11 months ago

As the implementation of VariationalELBO._log_likelihood_term expects the last dimension of likelihood.expected_log_prob to be the data/batch dimension, it seems that this is due to a difference between how MultitaskGaussianLikelihood and BetaLikelihood implement the expected_log_prob method.

In the source code for MultitaskGaussianLikelihood, expected_log_prob detects that the GP output is multitask and sums over the task dimensions. _OneDimensionalLikelihood.expected_log_prob, by contrast, makes no such adjustment, so its return value has a shape like (n_data, n_tasks). This breaks the expectation in VariationalELBO._log_likelihood_term that the last dimension is the data dimension.
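A back-of-the-envelope check of what this mismatch does to the objective. It assumes that `_log_likelihood_term` returns `expected_log_prob(...).sum(-1)` and that the ELBO then divides by the number of data points (my reading of the GPyTorch source, which may differ between versions). The expected log-likelihood values and the KL are random placeholders; only the scaling matters.

```python
import torch

torch.manual_seed(0)
n_data, n_tasks = 100, 4

# Placeholder per-point, per-task expected log-likelihoods and a scalar KL term
ell = torch.randn(n_data, n_tasks)
kl = torch.tensor(5.0)

# Intended ELBO: sum over points and tasks, then normalise by n_data
correct = ell.sum() / n_data - kl / n_data

# Shapes as returned by a one-dimensional likelihood: sum(-1) removes the
# task dimension, leaving one value per data point
per_point = ell.sum(-1) / n_data - kl / n_data  # shape (n_data,)
observed = per_point.mean()  # the `.mean()` in the training loop

# The data term ends up divided by n_data a second time; the KL term does not
print(torch.allclose(observed, ell.sum() / n_data**2 - kl / n_data))  # True
```

With n_data = 100 the data-fit term would be shrunk 100-fold relative to the KL term, which would be consistent with the nearly flat Beta loss curve reported above; this is an inference, not something confirmed in the thread.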

gpleiss commented 11 months ago

> The output from the beta likelihood is of the form (n_samples, n_points, n_tasks), so we have e.g. an output of (10,100,5).

As @spectraldani mentions, the Beta likelihood is not designed for multitask GPs. The expected shape is of the form (n_samples, <batch_dimension>, n_points). It is a OneDimensionalLikelihood.

> And would a new likelihood for the multitask beta need to be implemented?

It would. I'm not sure we'd be open to it as a PR right now, unless there's a way to nicely generalize all of our one-dimensional likelihoods into multitask likelihoods.

eirikbaekkelund commented 11 months ago

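This does the job:

```python
from typing import Optional

import torch
import gpytorch
from gpytorch.constraints import Interval, Positive
from gpytorch.priors import Prior


class MultitaskBetaLikelihood(gpytorch.likelihoods.BetaLikelihood):
    """A multitask BetaLikelihood that supports multitask GP regression."""

    def __init__(
        self,
        n_tasks: int,
        batch_shape: torch.Size = torch.Size([]),
        scale_prior: Optional[Prior] = None,
        scale_constraint: Optional[Interval] = None,
    ) -> None:
        if scale_constraint is None:
            scale_constraint = Positive()

        # Let BetaLikelihood set up raw_scale and register its constraint
        super().__init__(batch_shape=batch_shape, scale_constraint=scale_constraint)

        # Replace the single scale with one scale per task; shape (..., 1, n_tasks)
        # broadcasts against function samples of shape (..., n_points, n_tasks)
        self.raw_scale = torch.nn.Parameter(torch.ones(*batch_shape, 1, n_tasks))
        if scale_prior is not None:
            self.register_prior("scale_prior", scale_prior, lambda m: m.scale, lambda m, v: m._set_scale(v))

    def expected_log_prob(self, observations, function_dist, *args, **kwargs):
        # For a multitask function_dist the base class returns shape (..., n_points, n_tasks)
        ret = super().expected_log_prob(observations, function_dist, *args, **kwargs)

        num_event_dim = len(function_dist.event_shape)

        if num_event_dim > 1:  # Sum out the task dimension(s), leaving the data dimension last
            ret = ret.sum(list(range(-1, -num_event_dim, -1)))
        return ret
```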
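A plain-PyTorch sketch of the two shape decisions in the class above: a per-task scale of shape `(1, n_tasks)` broadcasting against function samples of shape `(n_samples, n_points, n_tasks)`, and the task-dimension reduction in `expected_log_prob`. The Beta parametrisation used here (alpha = Φ(f)·s + 1, beta = s - alpha + 2) is my reading of `BetaLikelihood.forward`, and a Monte Carlo average stands in for GPyTorch's Gauss-Hermite quadrature. The function samples, observations and scale values are placeholders.

```python
import torch
from torch.distributions import Beta, Normal

torch.manual_seed(0)
n_samples, n_points, n_tasks = 10, 100, 4

# Placeholder latent function samples and observations strictly inside (0, 1)
f = torch.randn(n_samples, n_points, n_tasks)
y = torch.rand(n_points, n_tasks).clamp(0.05, 0.95)

# One positive scale per task, shaped like raw_scale in the class above
scale = torch.tensor([[10.0, 15.0, 20.0, 25.0]])  # shape (1, n_tasks)

# Mean/scale parametrisation assumed to mirror BetaLikelihood.forward
mixture = Normal(0.0, 1.0).cdf(f)
alpha = mixture * scale + 1
beta = scale - alpha + 2
log_p = Beta(alpha, beta).log_prob(y)  # shape (n_samples, n_points, n_tasks)

# Monte Carlo stand-in for the quadrature: average over function samples
ret = log_p.mean(0)  # shape (n_points, n_tasks)

# Same reduction as expected_log_prob above, with event_shape (n_points, n_tasks)
num_event_dim = 2
ret = ret.sum(list(range(-1, -num_event_dim, -1)))  # sums dim -1 only
print(ret.shape)  # torch.Size([100])
```

Note that `range(-1, -num_event_dim, -1)` stops before `-num_event_dim`, so for a two-dimensional event shape only the task dimension is summed and the data dimension is left last for `VariationalELBO` to sum.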