cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch
MIT License

Independent GP Inputs #1899

Open finnBsch opened 2 years ago

finnBsch commented 2 years ago

Hi, I'm currently using a variational independent multitask GP model, as proposed here: https://docs.gpytorch.ai/en/v1.5.1/examples/04_Variational_and_Approximate_GPs/SVGP_Multitask_GP_Regression.html.

I have 2 questions about this. First: I have 10 inputs x, each of size N. I want to feed these into the GP and sample from the resulting distribution. However, I want to sample from the distribution of each input on its own, NOT from the joint distribution. To speed things up, I stacked these inputs into a 10 x N matrix. If I'm not mistaken, this gives me the joint distribution over all 10 inputs rather than 10 independent distributions. How can I get 10 independent distributions from 10 stacked inputs?
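To make the distinction concrete, here is a minimal sketch in plain PyTorch (not GPyTorch), using a hypothetical unit-lengthscale RBF kernel as a stand-in for the GP's covariance. Evaluating the 10 inputs together as one set of points gives a single 10 x 10 joint covariance. For a Gaussian, the marginal distribution of each input is read off the diagonal, and it is the same distribution you would get by evaluating that input alone, so independent draws can be made from a Normal built from the mean and the diagonal variances. With a GPyTorch MultivariateNormal, I believe dist.mean and dist.variance would play the same role, but the rbf helper and all sizes here are illustrative assumptions.

```python
import torch

torch.manual_seed(0)

def rbf(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical unit-lengthscale RBF kernel on (..., n, d) inputs."""
    diff = a.unsqueeze(-2) - b.unsqueeze(-3)        # (..., n, m, d)
    return torch.exp(-0.5 * diff.pow(2).sum(-1))    # (..., n, m)

num_inputs, N = 10, 3                                # 10 inputs, each of size N
x = torch.rand(num_inputs, N, dtype=torch.float64)   # stacked as a 10 x N matrix
mean = torch.zeros(num_inputs, dtype=torch.float64)

# Treated as 10 points of one set: a single 10 x 10 joint covariance.
joint_cov = rbf(x, x)

# Joint sample: the 10 values are correlated through the off-diagonal terms
# (a small jitter keeps the Cholesky factorization stable).
jitter = 1e-6 * torch.eye(num_inputs, dtype=torch.float64)
joint = torch.distributions.MultivariateNormal(mean, joint_cov + jitter)
joint_sample = joint.sample()

# Independent samples: keep only the marginal variances (the diagonal).
marginal = torch.distributions.Normal(mean, joint_cov.diagonal().sqrt())
independent_sample = marginal.sample()               # one independent draw per input
```

Both samples have one value per input; the difference is only whether the off-diagonal covariances are used.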

Second, for training: I assume the same thing happens when I pass in the training data in batches. Does that change the outcome? Every input I have is completely independent of the others, so I never really want to deal with the joint distribution.
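On the training question, a minimal sketch assuming a Gaussian likelihood with noise variance noise (an assumption; the thread does not say which likelihood is used). For such a likelihood, the expected log-likelihood term of a variational objective is a sum of per-point terms, each depending only on that point's marginal mean and variance under q(f), so correlations between points that share a minibatch do not enter that term. The sketch checks this by comparing the closed form, computed from the diagonal of a correlated covariance only, with a Monte Carlo estimate over joint samples. The helper expected_log_lik is hypothetical; as far as I know GPyTorch's GaussianLikelihood.expected_log_prob evaluates the same closed form, but treat that as an assumption.

```python
import math
import torch

torch.manual_seed(0)

def expected_log_lik(y, mu, var, noise):
    """Hypothetical helper: closed-form E_{f ~ N(mu, var)}[log N(y | f, noise)] per point."""
    return -0.5 * math.log(2 * math.pi * noise) - ((y - mu) ** 2 + var) / (2 * noise)

n, noise = 5, 1.0
y = torch.randn(n, dtype=torch.float64)
mu = torch.randn(n, dtype=torch.float64)
A = torch.randn(n, n, dtype=torch.float64)
S = A @ A.T / n + 0.1 * torch.eye(n, dtype=torch.float64)  # correlated q(f) covariance

# Closed form: uses only the marginal means and variances (the diagonal of S).
closed_form = expected_log_lik(y, mu, S.diagonal(), noise).sum()

# Monte Carlo over *joint* (correlated) samples from q(f), for comparison.
q = torch.distributions.MultivariateNormal(mu, covariance_matrix=S)
f = q.sample((200_000,))                                    # (200000, n)
log_lik = torch.distributions.Normal(f, math.sqrt(noise)).log_prob(y).sum(-1)
monte_carlo = log_lik.mean()
```

The two estimates agree up to Monte Carlo error even though S has nonzero off-diagonal entries.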

Balandat commented 2 years ago

You should be able to batch-evaluate this by stacking the inputs along a batch dimension. Say d is the input dimension; then IIUC what you did was pass a 10N x d tensor, but to batch-evaluate you will want to pass it as 10 x N x d instead.
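A minimal sketch of the reshape suggested here, in plain PyTorch with a hypothetical RBF kernel standing in for the model's covariance module (the sizes are illustrative). Viewing a 10N x d tensor as 10 x N x d turns one 10N x 10N covariance into a batch of ten N x N covariances, and each of those equals the corresponding diagonal block of the joint matrix, i.e. what evaluating each group separately would give.

```python
import torch

torch.manual_seed(0)

def rbf(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical unit-lengthscale RBF kernel on (..., n, d) inputs."""
    diff = a.unsqueeze(-2) - b.unsqueeze(-3)        # (..., n, m, d)
    return torch.exp(-0.5 * diff.pow(2).sum(-1))    # (..., n, m)

B, N, d = 10, 4, 2
x_flat = torch.rand(B * N, d)       # 10N x d: treated as one set of 10N points
x_batch = x_flat.view(B, N, d)      # 10 x N x d: 10 independent batches of N points

K_joint = rbf(x_flat, x_flat)       # (40, 40): includes cross-batch covariances
K_batch = rbf(x_batch, x_batch)     # (10, 4, 4): one covariance per batch

# Each batched covariance is the matching diagonal block of the joint one;
# the off-diagonal blocks (cross-batch terms) are simply never computed.
blocks = torch.stack([K_joint[i * N:(i + 1) * N, i * N:(i + 1) * N] for i in range(B)])
```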

finnBsch commented 2 years ago

Hm, maybe my explanation was a bit misleading. My input dimension is N, and I want to feed 10 independent samples (so a batch size of 1?) into the GP. Passing 10 x 1 x N doesn't work; it raises RuntimeError: Shapes are not broadcastable for mul operation. What did I do wrong? Thanks in advance.

Balandat commented 2 years ago

If N is the input dimension, then yes, 10 x 1 x N is the right shape to pass to the model for batch evaluation. Though I'm not sure the model defined in the docs that you linked properly supports this kind of batch evaluation. Could you post a fully reproducible example so we can investigate?

finnBsch commented 2 years ago

I'll post a fully reproducible example later; I don't have access to a machine right now. Here's the model I'm currently using, in case that helps:

import torch
import gpytorch


class IndependentMultitaskGPModelApproximate(gpytorch.models.ApproximateGP):
    def __init__(self, inducing_points_num, input_dim, num_tasks):
        # Let's use a different set of inducing points for each task
        inducing_points = torch.rand(num_tasks, inducing_points_num, input_dim)

        # We have to mark the CholeskyVariationalDistribution as batch
        # so that we learn a variational distribution for each task
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(-2), batch_shape=torch.Size([num_tasks])
        )

        variational_strategy = gpytorch.variational.IndependentMultitaskVariationalStrategy(
            gpytorch.variational.VariationalStrategy(
                self, inducing_points, variational_distribution, learn_inducing_locations=True
            ),
            num_tasks=num_tasks,
        )

        super().__init__(variational_strategy)

        # The mean and covariance modules should be marked as batch
        # so we learn a different set of hyperparameters
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([num_tasks]))
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.MaternKernel(nu=1.5, batch_shape=torch.Size([num_tasks])),
            batch_shape=torch.Size([num_tasks])
            # Alternative base kernel with one lengthscale per input dimension:
            # gpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_tasks]), ard_num_dims=input_dim),
        )

    def forward(self, x):
        # The forward function should be written as if we were dealing with each output
        # dimension in batch
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
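A hedged guess at the broadcasting error reported above, sketched with shapes only in plain PyTorch (no model needed). In this model the inducing points have shape num_tasks x M x input_dim, so their batch shape is (num_tasks,). An input of shape 10 x 1 x N has batch shape (10,), which cannot be broadcast against (num_tasks,) unless num_tasks is 1 or 10. Inserting a singleton dimension for the task batch, giving 10 x 1 x 1 x N, lets the batch shapes (10, 1) and (num_tasks,) broadcast to (10, num_tasks). This is my assumption about where the error comes from, not something confirmed in this thread, and num_tasks and num_inducing below are hypothetical values. Note that if num_tasks happened to be 10, the 10 x 1 x N input would likely broadcast silently and pair input i with task i only.

```python
import torch

num_tasks, num_inducing, N = 4, 16, 3      # hypothetical sizes; N is the input dimension
inducing_points = torch.rand(num_tasks, num_inducing, N)   # same layout as in __init__ above
inducing_batch = inducing_points.shape[:-2]                # torch.Size([num_tasks])

x = torch.rand(10, 1, N)                   # 10 independent inputs, one point each
try:
    torch.broadcast_shapes(x.shape[:-2], inducing_batch)
    broadcast_ok = True
except (RuntimeError, ValueError):
    broadcast_ok = False                   # (10,) vs (4,) cannot be broadcast

x_task = x.unsqueeze(-3)                   # 10 x 1 x 1 x N: singleton task-batch dimension
combined = torch.broadcast_shapes(x_task.shape[:-2], inducing_batch)  # (10, num_tasks)
```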