cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch
MIT License

[Docs] SVI Hadamard GP Task Dependent Input #2392

Closed eirikbaekkelund closed 11 months ago

eirikbaekkelund commented 11 months ago

Hi, Team! I am trying to implement an approximate Hadamard GP with separate inputs for each task, following the example linked below and issue #2043. However, the ELBO computation raises an error when I do so, and I would like to know how this can be resolved. I have a Multitask Beta Likelihood that works well in the Multitask GP case when the input is shared across outputs, but it fails when the inputs are separate for each task.

Essentially, if I have T tasks, I have inputs x_t in R^(N x D) for t = 1, ..., T and corresponding outputs y_t in R^N for t = 1, ..., T. These are stored as

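import torch

# Collect each task's inputs, task labels, and outputs
dict_input = {'input': [], 'task_indices': [], 'output': []}

for i, (X, y) in enumerate(loader):
    n = X.shape[0]
    dict_input['input'].append(X)
    # Every row belonging to task i is labelled with the integer index i
    dict_input['task_indices'].append(torch.full((n,), i, dtype=torch.long))
    dict_input['output'].append(y)

# Stack all tasks along the data dimension
x = torch.cat(dict_input['input'], dim=0)
task_indices = torch.cat(dict_input['task_indices'])
y = torch.cat(dict_input['output'], dim=0)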

The kernel of the model has a batch_shape equal to the number of tasks, and so do the likelihood and the mean. It would be something like:

# batch_shape must be a torch.Size, not a bare int
covar_module = ScaleKernel(Kernel(batch_shape=torch.Size([num_tasks])))
mean_module = ZeroMean(batch_shape=torch.Size([num_tasks]))
likelihood = MultitaskBetaLikelihood(num_tasks=num_tasks)

where the MultitaskBetaLikelihood is defined as

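from typing import Optional

import torch
import gpytorch
from gpytorch.constraints import Interval, Positive
from gpytorch.priors import Prior

class MultitaskBetaLikelihood(gpytorch.likelihoods.BetaLikelihood):
    """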
    A multitask BetaLikelihood that supports multitask GP regression.
    """
    def __init__(
        self,
        num_tasks: int,
        scale = 15,
        batch_shape: torch.Size = torch.Size([]),
        scale_prior: Optional[Prior] = None,
        scale_constraint: Optional[Interval] = None,
    ) -> None:
        super().__init__(scale=scale)

        if scale_constraint is None:
            scale_constraint = Positive()

        self.raw_scale = torch.nn.Parameter(torch.ones(*batch_shape, 1, num_tasks))
        if scale_prior is not None:
            self.register_prior("scale_prior", scale_prior, lambda m: m.scale, lambda m, v: m._set_scale(v))

        self.register_constraint("raw_scale", scale_constraint)

    def expected_log_prob(self, observations, function_dist, *args, **kwargs):
        ret = super().expected_log_prob(observations, function_dist, *args, **kwargs)

        num_event_dim = len(function_dist.event_shape)

        if num_event_dim > 1:  # Do appropriate summation for multitask likelihood
            ret = ret.sum(list(range(-1, -num_event_dim, -1)))
        return ret
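
To make the summation in expected_log_prob concrete, here is a minimal sketch with made-up sizes (5 points, 3 tasks; the tensor `ret` below is a stand-in, not real likelihood output). It only illustrates which dimensions `list(range(-1, -num_event_dim, -1))` selects when the function distribution has a two-dimensional event shape.

```python
import torch

# Hypothetical per-point, per-task expected log-probabilities: N = 5 points, T = 3 tasks.
ret = torch.ones(5, 3)

# A multitask distribution has event shape (N, T), so num_event_dim == 2.
num_event_dim = 2
sum_dims = list(range(-1, -num_event_dim, -1))  # [-1], i.e. only the task dimension

# Summing over the task dimension leaves one value per data point.
summed = ret.sum(sum_dims)  # shape (5,), every entry equals 3.0
```

When the function distribution has a one-dimensional event shape, which is presumably the case for the stacked Hadamard output described in this issue, `num_event_dim` is 1 and the branch is skipped, so nothing is summed.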

Now, the output of the forward call output = model(x, task_indices=task_indices) has shape sum_{t=1}^T N_t, where N_t is the number of training points for task t. These are usually equal, so the output has shape (NT) when each task has N data points. The shape of y is also (NT).
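
As a conceptual sketch of why the output has one entry per stacked data point (this is an illustration with toy tensors, not GPyTorch's internal code): if a batched model produced one row of latent values per task at every stacked input, a Hadamard-style selection would keep, for each input, only the value from that input's own task. The sizes and the `batched_mean` stand-in below are assumptions for illustration.

```python
import torch

# Toy sizes (assumed): T = 3 tasks, N = 4 points per task, stacked task by task.
num_tasks, n_per_task = 3, 4
n_total = num_tasks * n_per_task
task_indices = torch.arange(num_tasks).repeat_interleave(n_per_task)  # shape (12,)

# Stand-in for a batched latent mean: row t holds task t's value at every stacked input.
batched_mean = torch.arange(num_tasks * n_total, dtype=torch.float32)
batched_mean = batched_mean.reshape(num_tasks, n_total)  # shape (3, 12)

# For input i, keep only the entry from the row of its own task.
hadamard_mean = batched_mean.gather(0, task_indices.unsqueeze(0)).squeeze(0)  # shape (12,)
```

The result has shape (NT), matching the stacked y, rather than a separate column per task.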

The error message reads as follows:

ValueError                                Traceback (most recent call last)
Cell In[151], line 11
      9 output = model(x, task_indices=task_indices)
     10 print(output.shape)
---> 11 loss = -mll(output, y, task_indices=task_indices)
     12 loss.backward()
     13 print('Iter %d/%d - Loss: %.3f' % (i + 1, 10, loss.item()))

File ~/opt/anaconda3/envs/gp/lib/python3.11/site-packages/gpytorch/module.py:30, in Module.__call__(self, *inputs, **kwargs)
     29 def __call__(self, *inputs, **kwargs):
---> 30     outputs = self.forward(*inputs, **kwargs)
     31     if isinstance(outputs, list):
     32         return [_validate_module_outputs(output) for output in outputs]

File ~/opt/anaconda3/envs/gp/lib/python3.11/site-packages/gpytorch/mlls/variational_elbo.py:77, in VariationalELBO.forward(self, variational_dist_f, target, **kwargs)
     63 def forward(self, variational_dist_f, target, **kwargs):
     64     r"""
     65     Computes the Variational ELBO given :math:`q(\mathbf f)` and :math:`\mathbf y`.
     66     Calling this function will call the likelihood's :meth:`~gpytorch.likelihoods.Likelihood.expected_log_prob`
   (...)
     75     :return: Variational ELBO. Output shape corresponds to batch shape of the model/input data.
     76     """
---> 77     return super().forward(variational_dist_f, target, **kwargs)
...
    289                 format(actual_shape, expected_shape))
    290         try:
    291             support = self.support

What would be the approach to fixing this issue? The documentation below and prior issues don't seem to have a clear answer. Any help would be highly appreciated.

This is using the following model:

class IndependentMultitaskGPModel(gpytorch.models.ApproximateGP):
    def __init__(self, x_train, y_train, num_tasks):
        # Use the training inputs as the initial inducing point locations
        # (y_train is accepted for convenience but is not used by the model)

        # We have to mark the CholeskyVariationalDistribution as batch
        # so that we learn a variational distribution for each task
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            x_train.size(-2), batch_shape=torch.Size([num_tasks])
        )

        variational_strategy = gpytorch.variational.IndependentMultitaskVariationalStrategy(
            gpytorch.variational.VariationalStrategy(
                self, x_train, variational_distribution, learn_inducing_locations=True
            ),
            num_tasks=num_tasks,
        )

        super().__init__(variational_strategy)

        # The mean and covariance modules should be marked as batch
        # so we learn a different set of hyperparameters
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([num_tasks]))
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_tasks])),
            batch_shape=torch.Size([num_tasks])
        )

    def forward(self, x):
        # The forward function should be written as if we were dealing with each output
        # dimension in batch
        mean_x = self.mean_module(x)

        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

And training by doing:

model.train()
likelihood.train()

mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=y.size(0))
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

for i in range(10):
    optimizer.zero_grad()
    output = model(x, task_indices=task_indices)
    print(output.shape)
    loss = -mll(output, y, task_indices=task_indices)
    loss.backward()
    print('Iter %d/%d - Loss: %.3f' % (i + 1, 10, loss.item()))
    optimizer.step()

This example doesn't seem to work for that instance unless I am doing something wrong... https://docs.gpytorch.ai/en/latest/examples/04_Variational_and_Approximate_GPs/SVGP_Multitask_GP_Regression.html#Output-modes

eirikbaekkelund commented 11 months ago

For anyone encountering the same issue, a quick fix is to handle task_indices, which VariationalELBO passes on as a keyword argument through the super() call in its forward function, by using the following likelihood. Training should then be straightforward.

import torch
from gpytorch.distributions import base_distributions

class HadamardBetaLikelihood(MultitaskBetaLikelihood):
    def forward(self, function_samples, *args, **kwargs):
        assert 'task_indices' in kwargs.keys(), 'task_indices must be passed as a keyword argument'
        mixture = torch.distributions.Normal(0, 1).cdf(function_samples)

        task_indices = kwargs['task_indices']

        if self.scale.shape[-1] > 1:
            alpha_mask = torch.zeros_like(mixture)
            beta_mask = torch.zeros_like(mixture)
            # can be vectorised but doing this for readability
            for idx in torch.unique(task_indices):
                alpha_mask[:, task_indices == idx] = self.scale[:, idx] * mixture[:, task_indices == idx]
                beta_mask[:, task_indices == idx] = self.scale[:, idx] - alpha_mask[:, task_indices == idx]

            self.alpha = alpha_mask
            self.beta = beta_mask
        else:
            self.alpha = self.scale * mixture
            self.beta = self.scale - self.alpha

        self.alpha = torch.clamp(self.alpha, 1e-10, 1e10)
        self.beta = torch.clamp(self.beta, 1e-10, 1e10)

        return base_distributions.Beta(concentration1=self.alpha, concentration0=self.beta)

    def expected_log_prob(self, observations, function_dist, *args, **kwargs):
        log_prob_lambda = lambda function_samples: self.forward(function_samples, *args, **kwargs).log_prob(observations)
        ret = self.quadrature(log_prob_lambda, function_dist)

        num_event_dim = len(function_dist.event_shape)

        if num_event_dim > 1:  # Do appropriate summation for multitask likelihood
            ret = ret.sum(list(range(-1, -num_event_dim, -1)))
        return ret
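
The comment in forward notes that the per-task loop can be vectorised. A minimal sketch of one way to do that, assuming the scale has shape (1, num_tasks) as in MultitaskBetaLikelihood above and the samples have shape (num_quad, N): index the scale with task_indices to get one scale per data point. The tensors `scale` and `mixture` below are random stand-ins for self.scale and the probit-transformed function samples.

```python
import torch

torch.manual_seed(0)
num_tasks, num_quad, n = 3, 5, 8
scale = torch.rand(1, num_tasks) + 1.0        # stand-in for self.scale, shape (1, T)
mixture = torch.rand(num_quad, n)             # stand-in for Normal(0, 1).cdf(function_samples)
task_indices = torch.randint(0, num_tasks, (n,))

# Loop version, mirroring the likelihood above
alpha_loop = torch.zeros_like(mixture)
beta_loop = torch.zeros_like(mixture)
for idx in torch.unique(task_indices):
    mask = task_indices == idx
    alpha_loop[:, mask] = scale[:, idx] * mixture[:, mask]
    beta_loop[:, mask] = scale[:, idx] - alpha_loop[:, mask]

# Vectorised version: pick each point's task scale, then broadcast over samples
scale_per_point = scale[:, task_indices]      # shape (1, n)
alpha_vec = scale_per_point * mixture
beta_vec = scale_per_point - alpha_vec
```

Both versions give the same alpha and beta; the vectorised one avoids the Python loop and the boolean masks.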

eirikbaekkelund commented 11 months ago

Issue resolved :)