pytorch / botorch

Bayesian optimization in PyTorch
https://botorch.org/
MIT License

`ModelListGP.fantasize` does not support `MultiTaskGP`s, but other `ModelListGP` methods do #2398

Open esantorella opened 1 week ago

esantorella commented 1 week ago

Discussed in https://github.com/pytorch/botorch/discussions/2396

Originally posted by **lucky-luke-98** June 26, 2024 [....] My idea was to separate the creation of each correlation group and treat them based on having correlation to other targets (-> `MultiTaskGP`) or not having correlation to other targets (-> `SingleTaskGP`). Afterwards, to create one unified model, I would append them in a `ModelListGP`. [....] following my code, I run into problems with this procedure when performing a Multi-Step-Lookahead (qMultiStepLookahead). Here, the [`fantasize` method](https://botorch.org/api/_modules/botorch/models/model.html#ModelList.fantasize) of the `ModelListGP` assumes that the number of outputs (self.num_outputs) is equal to the number of models in the ModelList. In my case this assumption is not true.
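Here is the mismatch in concrete terms. The first model is a `MultiTaskGP` over targets 0 and 1, so it has two outputs. The second is a `SingleTaskGP` for target 2, with one output. The list therefore reports three outputs but holds only two models, and any per-output loop has to translate each global output index into a (model, local output) pair. The helper below is a hypothetical plain-Python sketch of that bookkeeping. It is not a BoTorch API and does not reproduce how `ModelList.fantasize` is implemented.

```python
from typing import List, Tuple


def output_to_model_map(num_outputs_per_model: List[int]) -> List[Tuple[int, int]]:
    """Hypothetical helper: map each global output index of a model list
    to a (model index, output index within that model) pair."""
    mapping: List[Tuple[int, int]] = []
    for model_idx, n_out in enumerate(num_outputs_per_model):
        for local_idx in range(n_out):
            mapping.append((model_idx, local_idx))
    return mapping


# a MultiTaskGP with 2 tasks followed by a SingleTaskGP with 1 output
per_model = [2, 1]
mapping = output_to_model_map(per_model)
print(len(per_model), len(mapping))  # 2 3
print(mapping)  # [(0, 0), (0, 1), (1, 0)]
```

Under this mapping, outputs 0 and 1 both route to model 0. An assumption of `num_outputs == len(models)` cannot represent that case.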
lucky-luke-98 commented 5 days ago

In the following, I present a reproducible example for this issue:

Versions used:


Reproducible Code

I am using a function (`_get_correlated_groups`) to create the groups of correlated targets. Its output is shown by the assertion.


```python
# imports

import torch
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.multitask import MultiTaskGP
from botorch.models.model_list_gp_regression import ModelListGP

from gpytorch.kernels import ScaleKernel, RBFKernel
from gpytorch.mlls import SumMarginalLogLikelihood

from botorch.fit import fit_gpytorch_mll
from botorch.sampling.normal import SobolQMCNormalSampler

### data creation ###

torch.manual_seed(42)
n = 500
train_X = torch.randn(n, 12)

# targets 0 and 1 are strongly correlated, target 2 is independent
train_Y1 = torch.randn(n, 1)
train_Y2 = 0.8 * train_Y1 + torch.randn(n, 1) * 0.5
train_Y3 = torch.randn(n, 1)
train_Y = torch.cat([train_Y1, train_Y2, train_Y3], dim=1)

# known observation noise variance per target
noise_targets = torch.tensor([5e-3, 1e-2, 4e-4])

# computation of correlation groups:
# `_get_correlated_groups` is not included in the issue, so the grouping it
# returns is hardcoded here to keep the snippet self-contained.
# corr_groups = _get_correlated_groups(train_Y, threshold=0.8)
corr_groups = [[0, 1], [2]]
assert corr_groups == [[0, 1], [2]], f"Unexpected group creation, got {corr_groups}."

### GP creation ###

models = []

# create models for each correlation group
for group in corr_groups:
    if len(group) > 1:
        # correlated targets: one MultiTaskGP in long format, i.e. one row per
        # (input, task) pair with the task index as the last column
        group_indices = torch.tensor(group)
        num_tasks = len(group_indices)
        task_indices = torch.arange(num_tasks).repeat(train_X.size(0), 1).view(-1, 1)

        train_x_group = torch.cat([train_X.repeat(1, num_tasks).view(-1, train_X.size(-1)), task_indices], dim=-1)
        train_y_group = train_Y[:, group_indices].view(-1, 1)

        # per-row noise variance, interleaved in the same task order as train_y_group
        train_yvar = torch.stack([torch.full_like(train_Y[:, i], noise_targets[group_indices[i]]) for i in range(len(group))], 1).view(-1, 1)

        model = MultiTaskGP(
            train_X=train_x_group,
            train_Y=train_y_group,
            task_feature=-1,
            train_Yvar=train_yvar,
        )
        models.append(model)

    else:
        # create SingleTaskGP for targets that have no correlation
        ind = group[0]

        train_y_group = train_Y[..., ind:ind+1]

        noise = noise_targets[ind]
        train_yvar = torch.full_like(train_y_group, noise)

        model = SingleTaskGP(
            train_X=train_X,
            train_Y=train_y_group,
            train_Yvar=train_yvar,
            covar_module=ScaleKernel(RBFKernel()),
        )
        models.append(model)

# combine all models
overall_model = ModelListGP(*models)
mll = SumMarginalLogLikelihood(overall_model.likelihood, overall_model)

# fit the model using mll
try:
    fit_gpytorch_mll(mll, max_attempts=30)
except Exception as e:
    raise Exception(f"Fitting did not work. Exception: {e}") from e

### create fantasize model ###

sampler = SobolQMCNormalSampler(sample_shape=torch.Size([1]))
# raises the IndexError shown in the traceback below
fant_model = overall_model.fantasize(train_X, sampler=sampler)
```
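The `MultiTaskGP` branch above stacks the grouped targets into a long format. There is one row per (input, task) pair, and the task index is the last column. Running the same tensor operations at toy size shows the row order they produce. The values are made up, and the task column is created as float so that `torch.cat` does not mix dtypes.

```python
import torch

# Toy instance of the long-format construction: n = 2 points with d = 3
# features, and a group of num_tasks = 2 targets.
n, d, num_tasks = 2, 3, 2
train_X = torch.arange(n * d, dtype=torch.float).view(n, d)
train_Y = torch.tensor([[10.0, 20.0], [11.0, 21.0]])  # one column per task

# task index column 0, 1, 0, 1, ... (task varies fastest)
task_indices = torch.arange(num_tasks, dtype=train_X.dtype).repeat(n, 1).view(-1, 1)
# each input row is repeated once per task, in the same order
X_long = torch.cat([train_X.repeat(1, num_tasks).view(-1, d), task_indices], dim=-1)
# row-major flattening of (n x num_tasks) interleaves the targets the same way
Y_long = train_Y.view(-1, 1)

print(X_long)
print(Y_long.squeeze(-1))  # tensor([10., 20., 11., 21.])
```

The task index varies fastest. As a result, any per-row vector built in the same order (such as `train_yvar`) can be read back as `n x num_tasks` with `.view(n, num_tasks)`.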

Keep in mind:

In general, I do not call the `fantasize` method directly. I use the `botorch.acquisition.qMultiStepLookahead` class (constructed with the `ModelListGP` created above) and pass it to the `botorch.optim.optimize.optimize_acqf` function. In the process, the `fantasize` method of the provided model is called.

Error:

```
IndexError                                Traceback (most recent call last)
     [85] ### create fantasize model ###
     [87] sampler = SobolQMCNormalSampler(sample_shape=torch.Size([1]))
---> [88] fant_model = overall_model.fantasize(train_X, sampler=sampler)

File c:\Users\venv\lib\site-packages\botorch\models\model.py:685, in ModelList.fantasize(self, X, sampler, observation_noise, evaluation_mask, **kwargs)
    [680]     else:
    [681]         sampler_i = (
    [682]             sampler.samplers[i] if isinstance(sampler, ListSampler) else sampler
    [683]         )
--> [685]     fant_model = self.models[i].fantasize(
    [686]         X=X_i,
    [687]         sampler=sampler_i,
    [688]         observation_noise=observation_noise_i,
    [689]         **kwargs,
    [690]     )
    [691]     fant_models.append(fant_model)
    [692] return self.__class__(*fant_models)
File c:\Users\venv\lib\site-packages\botorch\models\model.py:384, in FantasizeMixin.fantasize(self, X, sampler, observation_noise, **kwargs)
    [379] elif observation_noise is None and isinstance(
    [380]     self.likelihood, FixedNoiseGaussianLikelihood
    [381] ):
    [382]     if self.num_outputs > 1:
    [383]         # make noise ... x n x m
--> [384]         observation_noise = self.likelihood.noise.transpose(-1, -2)
    [385]     else:
    [386]         observation_noise = self.likelihood.noise.unsqueeze(-1)

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got -2)
```

To me it seems that `fantasize` should create a fantasized model for each model in the `ModelList`. However, `self.num_outputs` returns 3 while the list contains only 2 models, and this mismatch creates the error. The traceback also suggests an issue with the `FixedNoiseGaussianLikelihood` from GPyTorch. The error is raised inside the individual model's `fantasize` call, at the line `self.likelihood.noise.transpose(-1, -2)`.
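The last frame of the traceback can be reproduced with plain tensors. This rests on an assumption not confirmed in the issue: that the `MultiTaskGP` built with `train_Yvar` keeps a 1-D noise vector with one entry per long-format training row. The model takes the `num_outputs > 1` branch, presumably because the two-task model reports two outputs. On a 1-D vector, though, `transpose(-1, -2)` has no second dimension to swap and raises the same `IndexError`.

```python
import torch

# Assumption: the MultiTaskGP fitted with train_Yvar keeps a 1-D noise vector
# in its FixedNoiseGaussianLikelihood, one entry per long-format training row.
n, num_tasks = 500, 2
group_noise = torch.tensor([5e-3, 1e-2])  # noise_targets of targets 0 and 1
noise = group_noise.repeat(n)  # shape (n * num_tasks,), task index varies fastest

try:
    # the operation from the last frame of the traceback (model.py:384)
    noise.transpose(-1, -2)
except IndexError as err:
    print("IndexError:", err)

# the same values arranged as (n x num_tasks), one column per task
noise_nm = noise.view(n, num_tasks)
print(noise_nm.shape)  # torch.Size([500, 2])
```

This only illustrates the shape problem. It does not show how BoTorch should treat multi-task noise inside `fantasize`.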

Expected Behaviour

I would expect the fantasized model to be created without an error.


I hope this code snippet helps determine whether the error lies in my code or is a bug.