cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch
MIT License

[Question] Implementing multi-output multi-task approximate GP #1743

Open fleskovar opened 3 years ago

fleskovar commented 3 years ago

I am looking into implementing a model that produces multiple correlated outputs for multiple tasks (multi-task multi-output, MTMO). For this type of model, I assume that the input tensor has shape n x (d+1) (d inputs plus an additional task index column), while the output tensor has shape n x o (where o is the number of correlated outputs). Additionally, all outputs are observed simultaneously, but not all tasks are. The training data for this model would look like this:

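import torch
n_training = 50  # number of training points (value assumed; not given in the original)
# Inputs: one continuous feature plus a binary task index in the last column, shape n x (d+1) with d = 1
train_x = torch.cat([torch.rand(n_training, 1), (torch.rand(n_training, 1) > 0.5).float()], dim=1)
# Outputs: two correlated outputs per point, shape n x o with o = 2
train_y = torch.stack([
    torch.sin(5.5 * train_x[:,0])* train_x[:,0]**2 * train_x[:,-1] + torch.cos(5.5 * train_x[:,0])* train_x[:,0]**2  * (1-train_x[:,-1]),
    torch.cos(5.5 * train_x[:,0])* train_x[:,0]*2 * train_x[:,-1] + torch.sin(5.5 * train_x[:,0])* train_x[:,0]*2 * (1-train_x[:,-1])
], dim=-1)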

For the exact GP case, the model looks like this:

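import torch
from torch import Tensor
from botorch.models.gpytorch import GPyTorchModel
from gpytorch.distributions import MultitaskMultivariateNormal
from gpytorch.kernels import IndexKernel, MaternKernel
from gpytorch.lazy import KroneckerProductLazyTensor  # KroneckerProductLinearOperator (linear_operator) in newer releases
from gpytorch.likelihoods import MultitaskGaussianLikelihood
from gpytorch.means import ConstantMean, MultitaskMean
from gpytorch.models import ExactGP


# Like BoTorch's KroneckerMultiTaskGP, mix in GPyTorchModel, which provides _validate_tensor_args
class MultiOutputMultiTaskGP(ExactGP, GPyTorchModel):

    def __init__(
        self,
        train_X: Tensor,
        train_Y: Tensor,
        likelihood: MultitaskGaussianLikelihood = None,
        rank=None,
        lik_rank=None,
    ) -> None:

        num_tasks = train_Y.shape[-1]  # number of correlated outputs o
        batch_shape, ard_num_dims = train_X.shape[:-2], train_X.shape[-1]

        if lik_rank is None:
            lik_rank = rank

        self._validate_tensor_args(X=train_X, Y=train_Y)

        if likelihood is None:
            likelihood = MultitaskGaussianLikelihood(
                num_tasks=num_tasks,
                rank=lik_rank if lik_rank is not None else 0,
            )

        super(MultiOutputMultiTaskGP, self).__init__(train_X, train_Y, likelihood)
        self._rank = rank if rank is not None else num_tasks

        self.mean_module = MultitaskMean(
            ConstantMean(), num_tasks=num_tasks
        )

        # Kernel over the continuous inputs (every column except the task index)
        self.data_kernel = MaternKernel()
        # Correlation between tasks, indexed by the last input column
        self.task_kernel = IndexKernel(num_tasks=len(torch.unique(train_X[..., -1])))
        # Correlation between the o outputs
        self.output_kernel = IndexKernel(num_tasks=num_tasks)

        self.to(train_X)

    def forward(self, x: Tensor) -> MultitaskMultivariateNormal:
        mean_x = self.mean_module(x)
        task_term = self.task_kernel(x[..., -1].long())
        # Hadamard product: k_data(x_i, x_j) * B_task[t_i, t_j]
        data_and_task_x = self.data_kernel(x[..., :-1]).mul(task_term)
        output_x = self.output_kernel.covar_matrix
        # Kronecker product with the output covariance gives an (n*o) x (n*o) covariance
        covar_x = KroneckerProductLazyTensor(data_and_task_x, output_x)
        return MultitaskMultivariateNormal(mean_x, covar_x)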
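To make the covariance built in forward concrete, here is a minimal pure-PyTorch sketch (dense tensors, no GPyTorch) of the same structure: the data kernel is multiplied elementwise by the task-index kernel, and the result is Kronecker-multiplied with the output covariance. The Matérn kernel is swapped for an RBF kernel, and the helper names (rbf, index_kernel_matrix, mtmo_covariance) and the random coregionalization factors are assumptions for illustration only.

```python
import torch

def rbf(x1, x2, lengthscale=0.3):
    # Squared-exponential kernel on the continuous inputs (stand-in for MaternKernel)
    sq_dist = (x1.unsqueeze(-2) - x2.unsqueeze(-3)).pow(2).sum(-1)
    return torch.exp(-0.5 * sq_dist / lengthscale**2)

def index_kernel_matrix(num, rank):
    # IndexKernel-style matrix B = W W^T + diag(v); positive definite since v > 0
    w = torch.randn(num, rank)
    v = torch.rand(num) + 0.1
    return w @ w.T + torch.diag(v)

def mtmo_covariance(x, task_idx, b_task, b_out):
    # Hadamard product: entry (i, j) is k_data(x_i, x_j) * B_task[t_i, t_j]
    k_data_task = rbf(x, x) * b_task[task_idx][:, task_idx]
    # Kronecker with the output covariance: row i * o + a <-> (point i, output a)
    return torch.kron(k_data_task, b_out)

torch.manual_seed(0)
n, o = 5, 2
x = torch.rand(n, 1)
task_idx = torch.tensor([0, 1, 1, 0, 1])
b_task = index_kernel_matrix(2, rank=1)
b_out = index_kernel_matrix(o, rank=1)
K = mtmo_covariance(x, task_idx, b_task, b_out)
print(K.shape)  # torch.Size([10, 10])
```

The resulting row ordering (point-major, output-minor) matches the default interleaved=True layout of MultitaskMultivariateNormal, which is why the data factor comes first in the Kronecker product.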

I want to implement an approximate version of this model by using the LMCVariationalStrategy but I am facing some issues:

Do you know how I should implement this model? Thanks

ianhill60 commented 3 years ago

I have a very similar question: does GPyTorch support batching for correlated multitask regression? Would you use a MultitaskVariationalStrategy to do that? @fleskovar, I stumbled across this TDS article (https://towardsdatascience.com/batched-multi-dimensional-gaussian-process-regression-with-gpytorch-3a6425185109), which may be helpful to you. The author claims there isn't support yet for batching with multitask regression, but the model induction strategy and batching used in the docs example for Variational GPs w/ Multiple Outputs seem to get close to what I'm looking for. Thanks!

gpleiss commented 3 years ago

First of all, @fleskovar and @ianhill60 I am sorry for the very slow reply!

Quoting the original question: "I am not sure if the inducing points should include the task index. If the index has to be included, there does not seem to be a straightforward way to keep its values fixed while also learning the optimal locations of the inducing points during training (learn_inducing_locations=True)."

From a practical software perspective: this will probably require a different variational strategy. However, it seems like there are lots of requests for a similar Hadamard-style multi-task SVGP model, so I'll probably take a look at implementing that soon.

From a technical perspective: you'd probably want one set of inducing points per task.
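One way to handle the concern above about learning inducing locations while keeping the task index fixed is to store only the continuous coordinates as a learnable parameter and append the task column as a non-trainable buffer. The pure-PyTorch module below is a hypothetical helper (TaskInducingInputs is not a GPyTorch class) that also follows the suggestion of one set of inducing points per task; how it would be wired into a GPyTorch variational strategy is left open.

```python
import torch
from torch import nn

class TaskInducingInputs(nn.Module):
    """Hypothetical helper: num_per_task inducing points per task, with a frozen task column."""

    def __init__(self, num_tasks, num_per_task, input_dim):
        super().__init__()
        # Learnable continuous coordinates for every (task, inducing point) pair
        self.locations = nn.Parameter(torch.rand(num_tasks * num_per_task, input_dim))
        # Task index column [0, ..., 0, 1, ..., 1, ...]; a buffer is never touched by the optimizer
        task_col = torch.arange(num_tasks).repeat_interleave(num_per_task)
        self.register_buffer("task_col", task_col.float().unsqueeze(-1))

    def forward(self):
        # Full inducing inputs, shape (num_tasks * num_per_task, input_dim + 1)
        return torch.cat([self.locations, self.task_col], dim=-1)

torch.manual_seed(0)
z = TaskInducingInputs(num_tasks=2, num_per_task=3, input_dim=1)
opt = torch.optim.Adam(z.parameters(), lr=0.1)
before = z().detach().clone()
# Toy loss that would also move the task column if it were trainable
loss = z().pow(2).sum()
loss.backward()
opt.step()
after = z().detach()
```

After the optimizer step, the continuous coordinates have moved while the task column is unchanged.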

fleskovar commented 3 years ago

Hi @gpleiss! No worries at all. After doing some additional reading, I realized that, even though this approach would help me scale my model to bigger datasets, it is not exactly what I meant to do.

I am trying to find the optimal set of points that, given an already trained GP, would yield the best approximation of the original dataset (this is basically the same idea as in Dataset Distillation). I thought that the inducing points would yield this "compressed" representation of the dataset, but it seems that they are not meant to do that.

I have built a toy example where I train an ExactGP in the usual way and then use a second optimizer to find new training points that maximize the marginal log likelihood (i.e., minimize the negative MLL) of the original dataset. The code looks like this:

import gpytorch
import torch
from gpytorch.mlls import ExactMarginalLogLikelihood
import numpy as np
import matplotlib.pyplot as plt

def make_data_task_0_distil():
    # Build training data
    x_task_0 = np.linspace(0, 1, 1000)
    y_task_0 = sum([
        -x_task_0+2,
        0.5* np.sin(10*x_task_0),
        0.05* np.sin(50*x_task_0),
        0.1* np.sin(50*(x_task_0+0.2))

    ]).reshape(-1, 1)

    x_task_0 = torch.tensor(x_task_0.astype(np.float32)).unsqueeze(1)
    y_task_0 = torch.tensor(y_task_0.astype(np.float32)).squeeze(-1)

    return x_task_0, y_task_0

train_X_full, train_Y_full = make_data_task_0_distil()

n_samples = 10  # For initial training
samples = torch.randint(0, train_X_full.shape[0], (n_samples,))
train_X = train_X_full[samples, :]
train_Y = train_Y_full[samples]

class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(train_X, train_Y, likelihood)

model.likelihood.noise_covar.register_constraint("raw_noise", gpytorch.constraints.LessThan(torch.tensor(2e-4)))

# Find optimal model hyperparameters
model.train()
likelihood.train()

optimizer = torch.optim.Adam(model.parameters(), lr=0.1) 
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

training_iter = 100
for i in range(training_iter):
    # Zero gradients from previous iteration
    optimizer.zero_grad()
    # Output from model
    output = model(train_X)
    # Calc loss and backprop gradients
    loss = -mll(output, train_Y)
    loss.backward()
    optimizer.step()

model.likelihood.noise_covar.noise = 1e-4
model.eval()
likelihood.eval()

# Distillation
n_compression_points = 10
n_inputs = 1
samples = torch.randint(0, train_X_full.shape[0], (n_compression_points,))
x_compressed = train_X_full[samples, :].clone().requires_grad_()
y_compressed = train_Y_full[samples].clone().requires_grad_()

optimizer = torch.optim.Adam([x_compressed, y_compressed], lr=0.1)  # Only the compressed points are optimized; GP hyperparameters stay fixed
mll = ExactMarginalLogLikelihood(model.likelihood, model)

training_iter = 100
for i in range(training_iter):    

    optimizer.zero_grad()    
    model.set_train_data(x_compressed, y_compressed, strict=False)

    output = model(train_X_full)  # the mll applies the likelihood itself, as in the training loop above
    loss = -mll(output, train_Y_full)
    loss.backward()

    print('Iter %d/%d - Loss: %.3f' % (
        i + 1, training_iter, loss.item()
    ))

    optimizer.step()

Unfortunately, I am not able to get good results, since y_compressed stays unchanged throughout the optimization process. After looking around a bit, I realized that y_compressed.grad is always None. I cannot find a reason for this, since I can see the tensor being used inside the mean_cache function of DefaultPredictionStrategy. Am I doing anything wrong, or is there a better way to achieve this?

gpleiss commented 3 years ago

Quoting @fleskovar: "After looking around a bit, I realized that y_compressed.grad is always None. I cannot seem to find a reason for this since I can see the tensor being used inside the mean_cache function of DefaultPredictionStrategy. Am I doing anything wrong or is there a better way to achieve this?"

Without actually running your code example, I think the solution is to add the following context manager to your training loop:

with gpytorch.settings.detach_test_caches(False):
    for i in range(training_iter):
        ...  # same loop body as before: set_train_data, loss.backward(), optimizer.step()

(See https://docs.gpytorch.ai/en/stable/settings.html#gpytorch.settings.detach_test_caches)

As a default, we usually detach the posterior caches from autograd, so that you don't run out of memory when making predictions. However, when you actually do want to compute gradients through your posterior, then you need this context manager to ensure that you are properly getting gradients through the entire posterior.
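The effect can be reproduced without GPyTorch in a small pure-PyTorch sketch: the predictive mean is computed from a cache (K + noise * I)^{-1} y, and when that cache is detached, gradients can no longer reach the training targets, while the training inputs still get a gradient through the test/train cross-covariance. This mirrors the y_compressed / x_compressed symptom. The function predictive_mean, its detach_cache flag, the RBF kernel, the zero prior mean, and the squared-error loss (used instead of the marginal log likelihood for brevity) are all illustrative assumptions, not GPyTorch internals.

```python
import torch

def rbf(x1, x2, lengthscale=0.2):
    sq_dist = (x1.unsqueeze(-2) - x2.unsqueeze(-3)).pow(2).sum(-1)
    return torch.exp(-0.5 * sq_dist / lengthscale**2)

def predictive_mean(x_train, y_train, x_test, noise=1e-2, detach_cache=True):
    # mean_cache = (K_XX + noise * I)^{-1} y  (zero prior mean assumed)
    k_xx = rbf(x_train, x_train) + noise * torch.eye(x_train.shape[0])
    mean_cache = torch.linalg.solve(k_xx, y_train.unsqueeze(-1)).squeeze(-1)
    if detach_cache:
        # Cut the autograd link between the cache and (x_train, y_train)
        mean_cache = mean_cache.detach()
    # Predictive mean: K_{*X} @ mean_cache
    return rbf(x_test, x_train) @ mean_cache

torch.manual_seed(0)
x_test = torch.rand(20, 1)
y_test = torch.sin(6 * x_test).squeeze(-1)

grads = {}
for detach in (True, False):
    # Both "compressed" tensors require grad, like x_compressed / y_compressed
    x_train = torch.rand(8, 1, requires_grad=True)
    y_train = torch.randn(8, requires_grad=True)
    pred = predictive_mean(x_train, y_train, x_test, detach_cache=detach)
    loss = (pred - y_test).pow(2).mean()
    loss.backward()
    grads[detach] = (x_train.grad, y_train.grad)

print(grads[True][1] is None, grads[False][1] is not None)  # True True
```

Running the loop under gpytorch.settings.detach_test_caches(False) corresponds to detach_cache=False in this sketch.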

(This has been the source of other errors in the past - @jacobrgardner / @Balandat / @wjmaddox we should see if there's some way that we can raise a warning if someone tries back propagating through a posterior without this context manager.)

fleskovar commented 3 years ago

Thanks a lot @gpleiss , this seems to do the trick. However, I don't think I fully understand why it does.

I was taking a look at ExactGP to see if I could find where train_targets are used. If I understand correctly, the posterior mean and covariance are obtained from DefaultPredictionStrategy here:

# Make the prediction
with settings._use_eval_tolerance():
    predictive_mean, predictive_covar = self.prediction_strategy.exact_prediction(full_mean, full_covar)

Here, full_mean is computed by a forward pass over the train and test inputs (the train targets are not used). The method exact_predictive_mean of DefaultPredictionStrategy uses mean_cache, which appears to be computed from the train targets:

train_labels_offset = (self.train_labels - train_mean).unsqueeze(-1)
mean_cache = train_train_covar.evaluate_kernel().inv_matmul(train_labels_offset).squeeze(-1)

Is this the step that allows backpropagating through the train targets?

Thanks
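Written as equations, the cache built by the two lines quoted above and the prediction that reuses it are sketched below. The notation is assumed here: K_XX is the train covariance, sigma^2 the Gaussian likelihood noise (assumed to be included in the cached train covariance), m_X and m_* the prior means at the train and test inputs, and K_*X the test/train cross-covariance.

```latex
\mathbf{c} = \left(K_{XX} + \sigma^2 I\right)^{-1}\left(\mathbf{y} - \mathbf{m}_X\right),
\qquad
\boldsymbol{\mu}_* = \mathbf{m}_* + K_{*X}\,\mathbf{c},
\qquad
\frac{\partial \boldsymbol{\mu}_*}{\partial \mathbf{y}} = K_{*X}\left(K_{XX} + \sigma^2 I\right)^{-1}.
```

The Jacobian with respect to y is generally nonzero, so gradients can reach the train targets as long as c stays attached to the autograd graph.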

jacobrgardner commented 3 years ago

@fleskovar yes, but a closely following line prevents it from working by default: https://github.com/cornellius-gp/gpytorch/blob/fc2053b0fc00517880fbc11adc7f5802242eec6a/gpytorch/models/exact_prediction_strategies.py#L232

The reason this is done is that otherwise making predictions with the model repeatedly would either need to be done in a torch.no_grad context, or rapidly run out of memory due to accumulating compute graphs.

@gpleiss I don't think it's very simple to add a warning here. The problem is that currently you can backprop w.r.t. the test inputs just fine with the caches detached, and that's a much more common operation (e.g., differentiating a bayesopt acquisition function with respect to the candidate). We wouldn't want to raise the warning every time we call backward for that purpose.
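The point about test inputs can be checked with a pure-PyTorch sketch: with the cache detached, the predictive mean is still differentiable with respect to a candidate point through the test/train cross-covariance, so an acquisition-style gradient is unaffected. The RBF kernel, zero prior mean, noise level, and variable names are illustrative assumptions.

```python
import torch

def rbf(x1, x2, lengthscale=0.2):
    sq_dist = (x1.unsqueeze(-2) - x2.unsqueeze(-3)).pow(2).sum(-1)
    return torch.exp(-0.5 * sq_dist / lengthscale**2)

torch.manual_seed(0)
x_train = torch.rand(8, 1)
y_train = torch.sin(6 * x_train).squeeze(-1)
k_xx = rbf(x_train, x_train) + 1e-2 * torch.eye(8)
# Detached cache, computed once and reused for every prediction
mean_cache = torch.linalg.solve(k_xx, y_train.unsqueeze(-1)).squeeze(-1).detach()

# Candidate point, e.g. the input an acquisition function is optimized over
candidate = torch.tensor([[0.5]], requires_grad=True)
pred = rbf(candidate, x_train) @ mean_cache  # shape (1,)
pred.sum().backward()
print(candidate.grad)
```

The gradient reaches candidate even though mean_cache carries no autograd history.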

Maybe we could raise a warning when the user calls backward and either (a) the last set of test inputs didn't require grad, or (b) the test inputs were equal to the train inputs and those require grad (we already test for equality in __call__).

I think that would catch most cases (or at least more than we do now) -- basically if the test inputs require grad and are different from the train inputs, we assume that the backward was for the purpose of getting derivatives of the test inputs. Otherwise, if the test inputs don't require grad or they do but are actually the train inputs, we assume the backward call was for the hyperparameters and/or train inputs.
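The proposed heuristic can be written as a small predicate. This is a hypothetical sketch only: should_warn_detached_cache is not part of GPyTorch, it assumes the caller knows whether the caches are detached, and it uses torch.equal as the train/test equality check, which may differ from the check actually done in ExactGP.__call__.

```python
import torch

def should_warn_detached_cache(test_inputs, train_inputs, caches_detached=True):
    """Hypothetical check, run when backward() is called through a posterior."""
    if not caches_detached:
        # Gradients already flow through the caches; nothing to warn about
        return False
    # Test inputs that don't require grad: backward is presumably meant for
    # the hyperparameters and/or the train data
    if not test_inputs.requires_grad:
        return True
    # "Test" inputs that are really the train inputs, which require grad
    same_as_train = test_inputs.shape == train_inputs.shape and torch.equal(test_inputs, train_inputs)
    return same_as_train and train_inputs.requires_grad

x_train = torch.rand(5, 1, requires_grad=True)
x_candidate = torch.rand(2, 1, requires_grad=True)
print(should_warn_detached_cache(x_candidate, x_train))  # False
print(should_warn_detached_cache(x_train, x_train))      # True
```

Differentiating with respect to genuinely new test inputs (the acquisition-function case) stays silent, while the two ambiguous cases trigger the warning.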