Closed: eirikbaekkelund closed this issue 11 months ago
As the implementation of VariationalELBO._log_likelihood_term (Code) expects the last dimension of likelihood.expected_log_prob to be the data/batch dimension, this appears to come down to a difference between the expected_log_prob implementations of MultitaskGaussianLikelihood and BetaLikelihood.
In the source code for MultitaskGaussianLikelihood, expected_log_prob detects that the GP output is multi-task and sums over the task dimensions. _OneDimensionalLikelihood.expected_log_prob, by contrast, makes no such adjustment, so its return has shape (n_data, n_tasks), which breaks VariationalELBO._log_likelihood_term's expectation that the last dimension is the data dimension.
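A minimal sketch of the mismatch (the sizes are illustrative and the snippet is mine, not from the issue):

import torch
import gpytorch

n_points, n_tasks = 100, 5
f_dist = gpytorch.distributions.MultitaskMultivariateNormal(
    torch.zeros(n_points, n_tasks), torch.eye(n_points * n_tasks)
)
y = torch.rand(n_points, n_tasks).clamp(1e-3, 1 - 1e-3)  # Beta targets must lie in (0, 1)
elp = gpytorch.likelihoods.BetaLikelihood().expected_log_prob(y, f_dist)
print(elp.shape)  # torch.Size([100, 5]): tasks end up last, so the ELBO's .sum(-1)
                  # sums over tasks rather than over data points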
The output from the beta likelihood has the form (n_samples, n_points, n_tasks), so we get e.g. an output of shape (10, 100, 5).
As @spectraldani mentions, the beta likelihood is not designed for multitask GPs; it is a OneDimensionalLikelihood, and the expected shape is of the form (n_samples, <batch_dimension>, n_points).
And would a new likelihood for the multitask beta need to be implemented?
It would. I'm not sure we'd be open to it as a PR right now, unless there's potentially a way to nicely generalize all of our one-dimensional likelihoods into multitask likelihoods.
This does the job:
from typing import Optional

import torch
import gpytorch
from gpytorch.constraints import Interval, Positive
from gpytorch.priors import Prior


class MultitaskBetaLikelihood(gpytorch.likelihoods.BetaLikelihood):
    """
    A multitask BetaLikelihood that supports multitask GP regression.
    """

    def __init__(
        self,
        n_tasks: int,
        batch_shape: torch.Size = torch.Size([]),
        scale_prior: Optional[Prior] = None,
        scale_constraint: Optional[Interval] = None,
    ) -> None:
        super().__init__(batch_shape=batch_shape)  # the parent takes no positional `scale`

        if scale_constraint is None:
            scale_constraint = Positive()

        # One raw scale per task, replacing the parent's single per-batch parameter
        self.raw_scale = torch.nn.Parameter(torch.ones(*batch_shape, 1, n_tasks))
        if scale_prior is not None:
            self.register_prior("scale_prior", scale_prior, lambda m: m.scale, lambda m, v: m._set_scale(v))
        self.register_constraint("raw_scale", scale_constraint)

    def expected_log_prob(self, observations, function_dist, *args, **kwargs):
        ret = super().expected_log_prob(observations, function_dist, *args, **kwargs)
        num_event_dim = len(function_dist.event_shape)
        if num_event_dim > 1:  # Do appropriate summation for multitask likelihood
            ret = ret.sum(list(range(-1, -num_event_dim, -1)))
        return ret
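As a quick sanity check (illustrative sizes, reusing the imports above; not code from the thread), expected_log_prob now puts the data dimension last:

n_points, n_tasks = 100, 5
f_dist = gpytorch.distributions.MultitaskMultivariateNormal(
    torch.zeros(n_points, n_tasks), torch.eye(n_points * n_tasks)
)
y = torch.rand(n_points, n_tasks).clamp(1e-3, 1 - 1e-3)  # Beta targets must lie in (0, 1)
likelihood = MultitaskBetaLikelihood(n_tasks=n_tasks)
print(likelihood.expected_log_prob(y, f_dist).shape)  # torch.Size([100]): tasks summed out,
                                                      # so VariationalELBO's .sum(-1) sums over data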
🐛 Bug - Beta Likelihood for SVI Multitask GP
To reproduce
When fitting an approximate GP using SVI, I have a model working with the multitask Gaussian likelihood; however, when switching to a beta likelihood, the optimisation routine does not seem to learn much. The scale parameter should work fine for all the tasks passed, so my intuition is that the batch shape needs to be accounted for in a way it currently isn't. The output from the beta likelihood has the form (n_samples, n_points, n_tasks), so we get e.g. an output of shape (10, 100, 5). So how does this information (sample dim = 0, batch dim = 1, event dim = 2) get passed appropriately to the variational ELBO being minimised? I.e., how does the VI class treat the sample shape / batch shape / event shape? And would a new likelihood for the multitask beta need to be implemented?
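One way to inspect how gpytorch assigns these dimensions (an illustrative sketch using the (10, 100, 5) sizes above; note that in torch.distributions terms the data and task dimensions both live in event_shape, not batch_shape):

import torch
import gpytorch

f_dist = gpytorch.distributions.MultitaskMultivariateNormal(
    torch.zeros(100, 5), torch.eye(100 * 5)
)
print(f_dist.sample(torch.Size([10])).shape)  # torch.Size([10, 100, 5]): sample, data, task
print(f_dist.batch_shape)                     # torch.Size([])
print(f_dist.event_shape)                     # torch.Size([100, 5])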
Code snippet to reproduce
[Image: optimisation tracking for each likelihood]
Expected Behavior
Training these on a simple approximate GP with the same settings yields good fits, but that does not carry over to the multitask case. I couldn't find any post dealing with the same issue, so I was wondering what causes this behaviour and what the possible fixes are.
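For comparison, a single-task sketch (illustrative sizes) where the data dimension is already last and the ELBO behaves as expected:

import torch
import gpytorch

f_dist = gpytorch.distributions.MultivariateNormal(torch.zeros(100), torch.eye(100))
y = torch.rand(100).clamp(1e-3, 1 - 1e-3)
elp = gpytorch.likelihoods.BetaLikelihood().expected_log_prob(y, f_dist)
print(elp.shape)  # torch.Size([100]): the last dim is the data dim, as VariationalELBO expects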
System information
Additional context
Resulting image from fitting with MultitaskGaussianLikelihood:
Resulting image from fitting with Beta Likelihood: ![90d73f1e-ae1f-4dd6-b34a-c07e40ddb17a](https://github.com/cornellius-gp/gpytorch/assets/89584092/7ce8e135-ad62-4776-91b7-c6b7e38932df)