SaiAakash opened this issue 1 day ago
A potential fix is to route the code inside the posterior method of BatchedMultiOutputGPyTorchModel differently depending on whether trace_mode is on.
When trace_mode is on, it could simply be:
if self._num_outputs > 1:
    mvn = MultitaskMultivariateNormal.from_batch_mvn(mvn, task_dim=0)
Otherwise, it would keep doing what it does now.
This change only affects how the lazy_covariance_matrix is evaluated in trace_mode, so it fixes this error without breaking the existing tests.
A potential solution has been introduced in #2592.
Hmm, interesting. At least in the past, there must have been some reason not to just use MultitaskMultivariateNormal.from_batch_mvn() here. I can't recall exactly why without digging into this more; possibly because the BlockInterleavedLinearOperator used there may cause issues (perhaps performance-related, perhaps incompatibilities between that operator and other downstream operations)?
I have put in a PR that provides a temporary fix, which at least enables exporting the model to TorchScript. It uses the from_batch_mvn operation only when the posterior method is called in trace_mode. Please let me know if there is a better fix for this issue.
🐛 Bug
Exporting SingleTaskGP models (this probably applies to other GPs as well) trained on multi-output data to TorchScript fails.
To reproduce
Stack trace/error message
Expected Behavior
The model predictor should have been converted into a TorchScript module.
System information
Additional context
This is a bug in the posterior method of the BatchedMultiOutputGPyTorchModel class. Building a list of MultivariateNormal distributions, one per output, is somewhat convoluted, and the internal construction of a MultitaskMultivariateNormal from these distributions fails when the lazy_covariance_matrix is evaluated in trace_mode. Instead of this juggling, we could simply use the from_batch_mvn functionality and specify the task_dim parameter to directly create the required MultitaskMultivariateNormal posterior distribution when in trace_mode. I was able to fix the error above by doing exactly that.