pytorch / botorch

Bayesian optimization in PyTorch
https://botorch.org/
MIT License

[Bug] `KroneckerMultiTaskGP` incompatible with `batch_cross_validation` or batched models #2553

Closed slishak-PX closed 1 week ago

slishak-PX commented 2 months ago

🐛 Bug

When trying to use batch_cross_validation to fit a KroneckerMultiTaskGP, a batch shape error occurs.

The code snippet below uses a similar example to https://botorch.org/tutorials/batch_mode_cross_validation.

The following snippet also causes the same error, suggesting that the problem is that KroneckerMultiTaskGP does not correctly support batching.

gp = KroneckerMultiTaskGP(cv_folds.train_X, cv_folds.train_Y)
gp.posterior(cv_folds.test_X)

To reproduce

Code snippet to reproduce

import math
import torch
from botorch.cross_validation import batch_cross_validation, gen_loo_cv_folds
from botorch.models import SingleTaskGP, KroneckerMultiTaskGP
from botorch.models.transforms.input import Normalize
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood

device = torch.device("cuda:0")
dtype = torch.float64
torch.manual_seed(3)

sigma = math.sqrt(0.2)
coefs = torch.linspace(1, 4, 4, dtype=dtype, device=device).view(1, -1)
train_X = torch.linspace(0, 1, 20, dtype=dtype, device=device).view(-1, 1)
train_Y_noiseless = torch.sin(train_X * coefs * math.pi)
train_Y = train_Y_noiseless + sigma * torch.randn_like(train_Y_noiseless)

cv_folds = gen_loo_cv_folds(train_X, train_Y)

# If this is False, a SingleTaskGP is used
multi_task = True

cv_results = batch_cross_validation(
    model_cls=KroneckerMultiTaskGP if multi_task else SingleTaskGP,
    mll_cls=ExactMarginalLogLikelihood,
    cv_folds=cv_folds,
    model_init_kwargs={
        "input_transform": Normalize(d=train_X.shape[-1]),
    },
)

Stack trace/error message

File /opt/conda/envs/python3/lib/python3.10/site-packages/gpytorch/kernels/multitask_kernel.py:53, in MultitaskKernel.forward(self, x1, x2, diag, last_dim_is_batch, **params)
     51     covar_i = covar_i.repeat(*x1.shape[:-2], 1, 1)
     52 covar_x = to_linear_operator(self.data_covar_module.forward(x1, x2, **params))
---> 53 res = KroneckerProductLinearOperator(covar_x, covar_i)
     54 return res.diagonal(dim1=-1, dim2=-2) if diag else res

File /opt/conda/envs/python3/lib/python3.10/site-packages/gpytorch/lazy/lazy_tensor.py:46, in deprecated_lazy_tensor.<locals>.__init__(self, *args, **kwargs)
     43     else:
     44         new_kwargs[name] = val
---> 46 return __orig_init__(self, *args, **new_kwargs)

File /opt/conda/envs/python3/lib/python3.10/site-packages/linear_operator/operators/kronecker_product_linear_operator.py:81, in KroneckerProductLinearOperator.__init__(self, *linear_ops)
     79     batch_broadcast_shape = torch.broadcast_shapes(*(linear_op.batch_shape for linear_op in linear_ops))
     80 except RuntimeError:
---> 81     raise RuntimeError(
     82         "Batch shapes of LinearOperators "
     83         f"({', '.join([str(tuple(linear_op.shape)) for linear_op in linear_ops])}) "
     84         "are incompatible for a Kronecker product."
     85     )
     87 if len(batch_broadcast_shape):  # Otherwise all linear_ops are non-batch, and we don't need to expand
     88     # NOTE: we must explicitly call requires_grad on each of these arguments
     89     # for the automatic _bilinear_derivative to work in torch.autograd.Functions
     90     linear_ops = tuple(
     91         linear_op._expand_batch(batch_broadcast_shape).requires_grad_(linear_op.requires_grad)
     92         for linear_op in linear_ops
     93     )

RuntimeError: Batch shapes of LinearOperators ((20, 19, 19), (400, 4, 4)) are incompatible for a Kronecker product.
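
One possible reading of the shapes in this error, offered as an assumption and not checked beyond the lines shown in the trace: the data covariance has batch shape `(20,)` (one 19 x 19 matrix per leave-one-out fold), while the task covariance ends up with batch shape `(400,)`. That would happen if the task covariance already carried the 20-fold batch dimension before `covar_i.repeat(*x1.shape[:-2], 1, 1)` tiled it another 20 times. The sketch below reproduces only the shape arithmetic with plain tensors; the identity matrices and the helper `batch_shapes_compatible` are hypothetical stand-ins, not GPyTorch code.

```python
import torch

# Hypothetical sizes taken from the error message: 20 LOO folds,
# 19 training points per fold, 4 tasks.
n_folds, n_train, n_tasks = 20, 19, 4

# Stand-in for the data covariance: batch shape (20,).
covar_x = torch.eye(n_train).expand(n_folds, n_train, n_train)

# Assumption: the task covariance is already batched over the folds.
covar_i = torch.eye(n_tasks).expand(n_folds, n_tasks, n_tasks)

# Same call pattern as line 51 in the trace. Tensor.repeat tiles the
# existing batch dimension, so 20 becomes 20 * 20 = 400.
covar_i_repeated = covar_i.repeat(*covar_x.shape[:-2], 1, 1)


def batch_shapes_compatible(a: torch.Tensor, b: torch.Tensor) -> bool:
    """Return True if the batch shapes (all but the last two dims) broadcast."""
    try:
        torch.broadcast_shapes(a.shape[:-2], b.shape[:-2])
        return True
    except RuntimeError:
        return False


print(tuple(covar_i_repeated.shape))
print(batch_shapes_compatible(covar_x, covar_i_repeated))
print(batch_shapes_compatible(covar_x, covar_i))
```

Under this assumption, the un-repeated task covariance broadcasts against the data covariance without trouble, and only the repeated one fails. That points at the `repeat` call as a place to look, though it does not confirm where the extra batch dimension comes from.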

sdaulton commented 2 months ago

Hi @slishak-PX,

Thanks for flagging this. The docstring for batch_cross_validation says that MultiTaskGPs are not supported, but we should raise an UnsupportedError too. I'll put up a PR to add that.

> The following snippet also causes the same error, suggesting that the problem is that KroneckerMultiTaskGP does not correctly support batching.

Interesting. Have you looked into this any deeper? If you have an idea of the fix, I'd be happy to review a PR. Also curious if @jandylin has any idea

slishak-PX commented 2 months ago

> Thanks for flagging this. The docstring for batch_cross_validation says that MultiTaskGPs are not supported, but we should raise an UnsupportedError too. I'll put up a PR to add that.

Sorry, I should have read more carefully - I think my attention was drawn by the "WARNING" at the top and then I forgot to read the rest!

> Interesting. Have you looked into this any deeper? If you have an idea of the fix, I'd be happy to review a PR. Also curious if @jandylin has any idea

I haven't investigated more deeply yet, but I've tried evaluating the posterior with inputs that have multiple batch dimensions, and the same error appears.