pytorch / botorch

Bayesian optimization in PyTorch
https://botorch.org/
MIT License

[Bug] Unable to export Multi-Output SingleTaskGP to Torchscript #2591

Open SaiAakash opened 1 day ago

SaiAakash commented 1 day ago

🐛 Bug

Exporting a SingleTaskGP trained on multi-output data to TorchScript fails (this probably applies to other GPs as well).

To reproduce

import numpy as np
import torch
from torch import Tensor
from botorch.models import SingleTaskGP
from botorch.fit import fit_gpytorch_mll
from botorch.models.transforms import Standardize, Normalize
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.settings import trace_mode

X = torch.linspace(0, 1, 100).view(-1, 1)
y1 = torch.sin(2 * np.pi * X) + torch.randn_like(X) * 0.2
y2 = torch.cos(2 * np.pi * X) + torch.randn_like(X) * 0.2
Y = torch.cat([y1, y2], dim=1)

input_transform = Normalize(d=1)
outcome_transform = Standardize(m=2)
gp = SingleTaskGP(
    X,
    Y,
    outcome_transform=outcome_transform,
    input_transform=input_transform,
)

mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
fit_gpytorch_mll(mll)

class MeanVarModelWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        # get the model posterior
        posterior = self.model.posterior(x, observation_noise=True)
        mean = posterior.mean.detach()
        std = posterior.variance.sqrt().detach()
        return mean, std

X_test = torch.rand(10).view(-1, 1)
wrapped_model = MeanVarModelWrapper(gp)

with torch.no_grad(), trace_mode():
    wrapped_model(X_test)  # Compute caches
    traced_model = torch.jit.trace(wrapped_model, X_test)

Stack trace/error message

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[4], line 6
      4 with torch.no_grad(), trace_mode():
      5     wrapped_model(X_test)  # Compute caches
----> 6     traced_model = torch.jit.trace(wrapped_model, X_test)

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/torch/jit/_trace.py:1000, in trace(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_kwarg_inputs, _store_inputs)
    993 from torch._utils_internal import (
    994     check_if_torch_exportable,
    995     log_torch_jit_trace_exportability,
    996     log_torchscript_usage,
    997 )
    999 log_torchscript_usage("trace")
-> 1000 traced_func = _trace_impl(
   1001     func,
   1002     example_inputs,
   1003     optimize,
   1004     check_trace,
   1005     check_inputs,
   1006     check_tolerance,
   1007     strict,
   1008     _force_outplace,
   1009     _module_class,
   1010     _compilation_unit,
   1011     example_kwarg_inputs,
   1012     _store_inputs,
   1013 )
   1015 if check_if_torch_exportable():
   1016     from torch._export.converter import TS2EPConverter

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/torch/jit/_trace.py:695, in _trace_impl(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_kwarg_inputs, _store_inputs)
    693         else:
    694             raise RuntimeError("example_kwarg_inputs should be a dict")
--> 695     return trace_module(
    696         func,
    697         {"forward": example_inputs},
    698         None,
    699         check_trace,
    700         wrap_check_inputs(check_inputs),
    701         check_tolerance,
    702         strict,
    703         _force_outplace,
    704         _module_class,
    705         example_inputs_is_kwarg=isinstance(example_kwarg_inputs, dict),
    706         _store_inputs=_store_inputs,
    707     )
    708 if (
    709     hasattr(func, "__self__")
    710     and isinstance(func.__self__, torch.nn.Module)
    711     and func.__name__ == "forward"
    712 ):
    713     if example_inputs is None:

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/torch/jit/_trace.py:1275, in trace_module(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit, example_inputs_is_kwarg, _store_inputs)
   1273 else:
   1274     example_inputs = make_tuple(example_inputs)
-> 1275     module._c._create_method_from_trace(
   1276         method_name,
   1277         func,
   1278         example_inputs,
   1279         var_lookup_fn,
   1280         strict,
   1281         _force_outplace,
   1282         argument_names,
   1283         _store_inputs,
   1284     )
   1286 check_trace_method = module._c._get_method(method_name)
   1288 # Check the trace against new traces created from user-specified inputs

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/torch/nn/modules/module.py:1543, in Module._slow_forward(self, *input, **kwargs)
   1541         recording_scopes = False
   1542 try:
-> 1543     result = self.forward(*input, **kwargs)
   1544 finally:
   1545     if recording_scopes:

Cell In[3], line 8, in MeanVarModelWrapper.forward(self, x)
      6 def forward(self, x):
      7     # get the model posterior
----> 8     posterior = self.model.posterior(x, observation_noise=True)
      9     mean = posterior.mean.detach()
     10     std = posterior.variance.sqrt().detach()

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/botorch/models/gpytorch.py:459, in BatchedMultiOutputGPyTorchModel.posterior(self, X, output_indices, observation_noise, posterior_transform)
    451         output_indices = output_indices or range(self._num_outputs)
    452         mvns = [
    453             MultivariateNormal(
    454                 mean_x.select(dim=output_dim_idx, index=t),
   (...)
    457             for t in output_indices
    458         ]
--> 459         mvn = MultitaskMultivariateNormal.from_independent_mvns(mvns=mvns)
    460         # mvn = MultitaskMultivariateNormal.from_batch_mvn(mvn)
    462 posterior = GPyTorchPosterior(distribution=mvn)

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/distributions/multitask_multivariate_normal.py:199, in MultitaskMultivariateNormal.from_independent_mvns(cls, mvns)
    193 covar_blocks_lazy = CatLinearOperator(
    194     *[mvn.lazy_covariance_matrix.unsqueeze(0) for mvn in mvns],
    195     dim=0,
    196     output_device=mean.device,
    197 )
    198 covar_lazy = BlockDiagLinearOperator(covar_blocks_lazy, block_dim=0)
--> 199 return cls(mean=mean, covariance_matrix=covar_lazy, interleaved=False)

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/pyro/distributions/distribution.py:26, in DistributionMeta.__call__(cls, *args, **kwargs)
     24     if result is not None:
     25         return result
---> 26 return super().__call__(*args, **kwargs)

File ~/Library/Caches/pypoetry/virtualenvs/twinlab-models-OKGmZGkp-py3.11/lib/python3.11/site-packages/gpytorch/distributions/multitask_multivariate_normal.py:71, in MultitaskMultivariateNormal.__init__(self, mean, covariance_matrix, validate_args, interleaved)
     65         mean = mean.expand(
     66             *batch_shape,
     67             mean.size(-2),
     68             covariance_matrix.size(-2) // mean.size(-2),
     69         )
     70     else:
---> 71         raise RuntimeError(
     72             f"mean shape {mean.shape} is incompatible with covariance shape {covariance_matrix.shape}"
     73         )
     74 else:
     75     mean = mean.expand(*batch_shape, *mean.shape[-2:])

RuntimeError: mean shape torch.Size([10, 2]) is incompatible with covariance shape torch.Size([160, 160])

Expected Behavior

The wrapped model should be traced into a TorchScript module without errors.


Additional context

This is a bug in the posterior method of the BatchedMultiOutputGPyTorchModel class. The list of per-output MultivariateNormal distributions is built in a somewhat convoluted way, and constructing a MultitaskMultivariateNormal from that list (via from_independent_mvns) fails when the lazy_covariance_matrix is evaluated in trace_mode. Instead of this roundabout construction, when in trace_mode we could simply call from_batch_mvn with the task_dim parameter to create the required MultitaskMultivariateNormal posterior directly. Doing so fixed the error shown above.
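For reference, the shape rule behind the error: a multitask normal over n points and t outputs pairs a mean of shape (n, t) with a joint covariance of shape (n*t, n*t). The sketch below mimics, with plain dense torch tensors instead of gpytorch linear operators, the output-by-output block-diagonal layout that the traceback shows from_independent_mvns building (BlockDiagLinearOperator with interleaved=False). The helper names (random_spd, independent_joint, shapes_compatible) and the random covariances are illustrative assumptions, and shapes_compatible is a simplified version of the size rule, not gpytorch's actual check. With 10 test points and 2 outputs the covariance should be 20 x 20, so the 160 x 160 covariance reported above cannot be paired with the [10, 2] mean.

```python
import torch


def random_spd(n: int, generator: torch.Generator) -> torch.Tensor:
    """Return a random symmetric positive-definite n x n matrix."""
    a = torch.randn(n, n, generator=generator)
    return a @ a.T + n * torch.eye(n)


def independent_joint(means, covs):
    """Dense stand-in for combining t independent per-output normals.

    means: list of t tensors of shape (n,)
    covs:  list of t tensors of shape (n, n)
    Returns a mean of shape (n, t) and a block-diagonal covariance of shape
    (n*t, n*t) ordered output by output: all points of output 0, then output 1.
    """
    mean = torch.stack(means, dim=-1)  # (n, t)
    covar = torch.block_diag(*covs)    # (n*t, n*t)
    return mean, covar


def shapes_compatible(mean: torch.Tensor, covar: torch.Tensor) -> bool:
    """Simplified size rule (hypothetical helper): an (n, t) mean needs an
    (n*t, n*t) covariance."""
    n, t = mean.shape
    return tuple(covar.shape) == (n * t, n * t)


g = torch.Generator().manual_seed(0)
n, t = 10, 2
means = [torch.randn(n, generator=g) for _ in range(t)]
covs = [random_spd(n, g) for _ in range(t)]
mean, covar = independent_joint(means, covs)

# Output-by-output flattening of the mean lines up with the diagonal blocks.
flat_mean = mean.T.reshape(-1)

print(mean.shape, covar.shape)  # torch.Size([10, 2]) torch.Size([20, 20])
print(shapes_compatible(mean, covar))  # True
print(shapes_compatible(mean, torch.zeros(160, 160)))  # False
```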

SaiAakash commented 1 day ago

A potential fix is to branch inside the posterior method of BatchedMultiOutputGPyTorchModel depending on whether trace_mode is on.

When trace_mode is on, the multi-output branch could simply be

if self._num_outputs > 1:
    mvn = MultitaskMultivariateNormal.from_batch_mvn(mvn, task_dim=0)

Otherwise, it would keep the existing behavior.

This change only affects how the lazy_covariance_matrix is evaluated in trace_mode, so it fixes this error without breaking the existing tests.

A potential solution has been introduced in #2592.
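One detail worth checking in the proposed branch is the hard-coded task_dim=0. The existing code selects each output along output_dim_idx (see the traceback), which suggests the output dimension is 0 only when X has no extra batch dimensions. The pure-torch sketch below compares that select-and-stack pattern with a single movedim of the output dimension. The mean layouts (num_outputs, n) and (b, num_outputs, n) are assumptions about BoTorch's internal shapes, not verified against its source, and both helper names are hypothetical.

```python
import torch


def stack_selected_outputs(mean_x: torch.Tensor, output_dim_idx: int) -> torch.Tensor:
    """Per-output pattern: select each output along output_dim_idx, then
    stack the slices along a new last dimension (outputs end up last)."""
    num_outputs = mean_x.size(output_dim_idx)
    slices = [mean_x.select(dim=output_dim_idx, index=k) for k in range(num_outputs)]
    return torch.stack(slices, dim=-1)


def move_task_dim_last(mean_x: torch.Tensor, task_dim: int) -> torch.Tensor:
    """Single-step alternative: move the output (task) dimension to the end."""
    return mean_x.movedim(task_dim, -1)


torch.manual_seed(0)

# Assumed layout without input batch dims: (num_outputs, n) = (2, 10).
mean_x = torch.randn(2, 10)
print(torch.equal(stack_selected_outputs(mean_x, 0), move_task_dim_last(mean_x, 0)))  # True

# Assumed layout with one input batch dim of size 3: (b, num_outputs, n).
batched = torch.randn(3, 2, 10)
reference = stack_selected_outputs(batched, 1)
print(reference.shape)  # torch.Size([3, 10, 2])
print(torch.equal(reference, move_task_dim_last(batched, 1)))  # True
print(move_task_dim_last(batched, 0).shape)  # torch.Size([2, 10, 3]): wrong dim moved
```

If those assumed layouts hold, a general version of the fix would presumably pass output_dim_idx rather than 0 as task_dim.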

Balandat commented 21 hours ago

Hmm, interesting. At least in the past there must have been some reason not to just use MultitaskMultivariateNormal.from_batch_mvn() here. I can't recall exactly why without digging into this more; possibly because the BlockInterleavedLinearOperator used there may cause issues (possibly performance-related, possibly incompatibilities between that operator and other operations downstream)?
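To make the interleaving point concrete: per the traceback, from_independent_mvns builds a BlockDiagLinearOperator with interleaved=False (rows ordered output by output), whereas a block-interleaved layout, which is what a BlockInterleavedLinearOperator represents, orders rows point by point. The dense-tensor sketch below (plain torch, no gpytorch or linear_operator objects; all helper names are hypothetical) checks that the two layouts hold the same numbers up to a fixed permutation. It only illustrates the ordering difference and says nothing about the performance or downstream-compatibility concerns raised above.

```python
import torch


def block_diag_task_major(batch_cov: torch.Tensor) -> torch.Tensor:
    """(t, n, n) -> (n*t, n*t) block-diagonal, ordered output by output
    (all points of output 0, then all points of output 1, ...)."""
    return torch.block_diag(*batch_cov.unbind(0))


def block_interleaved(batch_cov: torch.Tensor) -> torch.Tensor:
    """(t, n, n) -> (n*t, n*t), ordered point by point.

    Entry (i*t + a, j*t + b) equals batch_cov[a, i, j] when a == b, else 0.
    """
    t, n, _ = batch_cov.shape
    out = torch.zeros(n, t, n, t, dtype=batch_cov.dtype)  # indexed [i, a, j, b]
    for a in range(t):
        out[:, a, :, a] = batch_cov[a]  # fill same-output entries only
    return out.reshape(n * t, n * t)


def interleave_permutation(n: int, t: int) -> torch.Tensor:
    """perm[i*t + a] = a*n + i: maps each interleaved position to its
    output-by-output position."""
    return torch.arange(n * t).reshape(t, n).T.reshape(-1)


torch.manual_seed(0)
t, n = 2, 4
raw = torch.randn(t, n, n)
batch_cov = raw @ raw.transpose(-1, -2) + n * torch.eye(n)  # t SPD blocks
mean = torch.randn(n, t)  # (points, outputs)

perm = interleave_permutation(n, t)
diag = block_diag_task_major(batch_cov)
inter = block_interleaved(batch_cov)

# Same numbers, different ordering: permuting rows and columns of the
# output-by-output matrix gives the interleaved one, and likewise for the mean.
print(torch.equal(diag[perm][:, perm], inter))  # True
print(torch.equal(mean.T.reshape(-1)[perm], mean.reshape(-1)))  # True
```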

SaiAakash commented 21 hours ago

I have opened a PR with a temporary fix that at least makes it possible to export the model to TorchScript. It calls from_batch_mvn only when the posterior method is invoked in trace_mode. Please let me know if there is a better fix for this issue.