cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch

[Bug] #2370

Closed IlterOnatKorkmaz closed 1 year ago

IlterOnatKorkmaz commented 1 year ago

🐛 Bug

Hello GPyTorch team,

I am currently working on a project that requires using Gaussian processes to handle multi-task regression. I am using your ApproximateGP model with IndependentMultitaskVariationalStrategy, as shown in your SVGP Multitask GP Regression example. However, I've run into an issue when trying to update the model with new training data.

My use case requires setting new training data for the Gaussian process to condition on, without refitting the model parameters. I understand that ExactGP models have the set_train_data method, which suits this purpose perfectly, but ApproximateGP does not seem to offer similar functionality. To work around this, I attempted to use the get_fantasy_model method, hoping to generate a new model conditioned on both the existing and the new data.
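
To make the request concrete, this is the kind of ExactGP workflow I have in mind (a minimal sketch; SimpleExactGP is just an illustrative single-task model, not something from your examples):

import torch
import gpytorch

class SimpleExactGP(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = SimpleExactGP(torch.rand(20, 1), torch.rand(20), likelihood)

# ... train the hyperparameters as usual ...

# Swap in new conditioning data without re-fitting the hyperparameters.
model.set_train_data(inputs=torch.rand(30, 1), targets=torch.rand(30), strict=False)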

Unfortunately, when I attempted to use get_fantasy_model with my ApproximateGP model (initialized with IndependentMultitaskVariationalStrategy), I encountered an AttributeError indicating that the IndependentMultitaskVariationalStrategy object has no attribute '__name__'. It appears that get_fantasy_model is not fully compatible with IndependentMultitaskVariationalStrategy, or perhaps with ApproximateGP models in general.

Given this situation, I'd appreciate it if you could provide some guidance or potential workarounds. Furthermore, would you consider adding functionality similar to set_train_data for ApproximateGP models, or possibly extending get_fantasy_model's compatibility to cover ApproximateGP models that use IndependentMultitaskVariationalStrategy? I believe these additions would make GPyTorch more flexible in handling a wider range of use cases involving variational and multitask GPs.

Thanks for your time and your continuous work on this great library.

Best regards, İlter Onat Korkmaz

To reproduce

Code snippet to reproduce

import torch
import gpytorch

class IndependentMultitaskGPModel(gpytorch.models.ApproximateGP):
    def __init__(self, num_tasks):
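        # A separate set of 16 inducing points (in 1-D input space) for each task; batch dim = num_tasks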
        inducing_points = torch.rand(num_tasks, 16, 1)
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(-2), batch_shape=torch.Size([num_tasks])
        )
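        # Wrap the batched VariationalStrategy so the batch dimension is treated as independent tasks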
        variational_strategy = gpytorch.variational.IndependentMultitaskVariationalStrategy(
            gpytorch.variational.VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True
            ),
            num_tasks=num_tasks,
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([num_tasks]))
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_tasks])),
            batch_shape=torch.Size([num_tasks])
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

num_tasks = 4
model = IndependentMultitaskGPModel(num_tasks)
x_new = torch.rand(num_tasks, 10, 1)
y_new = torch.rand(num_tasks, 10)
fantasy_model = model.get_fantasy_model(x_new, y_new)  # Raises AttributeError

Stack trace/error message

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
File ~/env/lib/python3.10/site-packages/gpytorch/module.py:437, in Module.__getattr__(self, name)
    436 try:
--> 437     return super().__getattribute__(name)
    438 except AttributeError:

AttributeError: 'IndependentMultitaskVariationalStrategy' object has no attribute '__name__'

During handling of the above exception, another exception occurred:

AttributeError                            Traceback (most recent call last)
Cell In[1], line 32
     30 x_new = torch.rand(num_tasks, 10, 1)
     31 y_new = torch.rand(num_tasks, 10)
---> 32 fantasy_model = model.get_fantasy_model(x_new, y_new)  # Raises AttributeError

File ~/env/lib/python3.10/site-packages/gpytorch/models/approximate_gp.py:103, in ApproximateGP.get_fantasy_model(self, inputs, targets, **kwargs)
     78 def get_fantasy_model(self, inputs, targets, **kwargs):
     79     r"""
     80     Returns a new GP model that incorporates the specified inputs and targets as new training data using
     81     online variational conditioning (OVC).
   (...)
    101 
    102     """
...
   1613         return modules[name]
-> 1614 raise AttributeError("'{}' object has no attribute '{}'".format(
   1615     type(self).__name__, name))

AttributeError: 'IndependentMultitaskVariationalStrategy' object has no attribute '__name__'

Expected Behavior

The get_fantasy_model call returns an ExactGP model, conditioned on both the existing and the new data, as the fantasy model.
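
In other words, if the call succeeded, I would expect to be able to query the returned fantasy model roughly as follows (a sketch of the intended usage, continuing the snippet above; it obviously cannot run while get_fantasy_model raises):

fantasy_model.eval()
with torch.no_grad():
    test_x = torch.rand(num_tasks, 5, 1)
    pred = fantasy_model(test_x)  # posterior conditioned on both the old and the new data
    mean, variance = pred.mean, pred.variance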

System information


Additional context

None

gpleiss commented 1 year ago

There was a small typo in our error handling (I'm putting up a PR now). This is the error that you should have seen:

NotImplementedError: No fantasy model support for IndependentMultitaskVariationalStrategy. Only VariationalStrategy and UnwhitenedVariationalStrategy are currently supported.

My use case requires setting new training data for the Gaussian process to condition on, without refitting the model parameters. I understand that ExactGP models have the set_train_data method, which suits this purpose perfectly, but ApproximateGP does not seem to offer similar functionality. To work around this, I attempted to use the get_fantasy_model method, hoping to generate a new model conditioned on both the existing and the new data.

Currently we don't support this for IndependentMultitaskVariationalStrategy, though in theory it shouldn't be too hard to set up (it would essentially just require calling through to the base variational strategy). We'd accept a PR if you'd be up to implementing it!
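
In rough, untested terms, the call-through could look something like the sketch below. FantasyIndependentMultitaskVariationalStrategy is a hypothetical name, and this assumes the wrapped VariationalStrategy's OVC-based get_fantasy_model handles the task batch dimension of the inputs and targets as-is, which hasn't been verified:

import gpytorch

class FantasyIndependentMultitaskVariationalStrategy(
    gpytorch.variational.IndependentMultitaskVariationalStrategy
):
    # Untested sketch: forward fantasization to the wrapped single-output strategy.
    def get_fantasy_model(self, inputs, targets, **kwargs):
        # base_variational_strategy is the batched VariationalStrategy passed in at construction;
        # each task is treated as a batch dimension.
        return self.base_variational_strategy.get_fantasy_model(inputs, targets, **kwargs)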