pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

Multi-output multi-lengthscale GPs [feature request] [discussion] #2889

Open davindicode opened 3 years ago

davindicode commented 3 years ago

Hey everyone,

I have coded up a multi-output multi-lengthscale GP (i.e. a separate kernel hyperparameter per input dimension and per output dimension). It is pretty efficient and scales well when combined with inducing points. I took the code from the GP contrib module and modified it for research purposes (neural data analysis for big data), but if there is enough interest here it should be pretty easy to put this into Pyro itself. The question I have is how to package it: replace the current GP, keep both the simple GP and the multi-dimensional GP as options, or provide a separate subpackage alongside contrib.gp?

I am aware of the GPyTorch interface. Initially I tried to use that, but it is not straightforward to use all PPL features through it (the loss objective is limited in this interface, GPLVM is tricky, etc.).

fritzo commented 3 years ago

cc @fehiepsi

fehiepsi commented 3 years ago

Hi @davindicode, if there are a lot of changes, it would be nice if you could start with a design doc so we can figure out how many changes we need for this feature. If you think there would be many changes, then having a separate module would be a better solution. Otherwise, if we just need to introduce one extra argument to the kernels/models to handle batch dimensions, then integrating with contrib.gp would be a better solution. I still don't know what API you have in mind, so it is hard to tell. Personally, I think that introducing a new notion of batch_shape for the kernels would make things easier to integrate. Please let me know what you have in mind.

davindicode commented 3 years ago

Hi, thanks for your suggestions! The code modifications are not that big, but the usage will change slightly. What I have coded up requires the input to be higher dimensional. I wrote it for neuroscience applications, but what happens is that now we essentially have a problem where we want the GP to map from d-dimensional X (tensor of shape (batch_shape x d)) to N-dimensional output Y (tensor of shape (batch_shape x N)). We essentially put a separate lengthscale parameter on each dimension of X, and for each output dimension. Using broadcasting, the user could specify the parameters to be tied across some dimensions by setting those dimensions of the lengthscale to size 1. Internally, we also keep track of a sample dimension, which should become sample_shape for general use.
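
To make the shapes concrete, here is a minimal sketch of the intended broadcasting (the sizes and tensor names are just illustrative, not taken from my actual implementation):

import torch

# hypothetical sizes: N_out output dimensions, d input dimensions, T points
N_out, d, T = 3, 2, 100

# one lengthscale per output dimension and per input dimension
lengthscale = torch.ones(N_out, d)            # shape (N_out, d)
# tie the lengthscale across all output dimensions by using a size-1 axis
lengthscale_tied = torch.ones(1, d)           # shape (1, d)

X = torch.randn(T, d)                         # T points with d input dimensions

# scale the inputs per output dimension: (N_out, 1, d) broadcasts against (T, d)
X_scaled = X / lengthscale[:, None, :]        # shape (N_out, T, d)
X_tied = X / lengthscale_tied[:, None, :]     # shape (1, T, d), parameters shared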

Some potential issues:

See the code below that I wrote for the Kernel module (note that time is essentially the batch shape and neurons are the output dimensions):


import torch
import torch.nn.functional as F


class Kernel(torch.nn.Module):
    r"""
    Base class for multi-lengthscale kernels used in Gaussian Processes.
    Inspired by pyro GP kernels.
    """
    def __init__(self, input_dim, track_dims=None, f="exp"):
        super().__init__()

        if track_dims is None:
            track_dims = list(range(input_dim))
        elif input_dim != len(track_dims):
            raise ValueError("Input size and the length of active dimensionals should be equal.")
        self.input_dim = input_dim
        self.track_dims = track_dims

        if f == 'exp':
            self.f = lambda x : torch.exp(x)
            self.f_inv = lambda x: torch.log(x)
        elif f == 'softplus':
            self.f = lambda x : F.softplus(x)
            self.f_inv = lambda x: torch.where(x > 30, x, torch.log(torch.exp(x) - 1))
        elif f == 'relu':
            self.f = lambda x : torch.clamp(x, min=0)
            self.f_inv = lambda x: x
        else:
            raise NotImplementedError("Link function is not supported.")

    def forward(self, X, Z=None, diag=False):
        r"""
        Calculates the covariance matrix of inputs on active dimensions.
        """
        raise NotImplementedError

    def _slice_input(self, X):
        r"""
        Slices :math:`X` according to ``self.track_dims``.
        :param torch.tensor X: input with shape (neurons, samples, timesteps, dimensions)
        :returns: a 2D slice of :math:`X`
        :rtype: torch.tensor
        """
        if X.dim() == 4:
            return X[..., self.track_dims]
        else:
            raise ValueError("Input X must be of shape (N x K x T x D).")

    def _XZ(self, X, Z=None):
        """
        Slice the input of :math:`X` and :math:`Z` into correct shapes.
        """
        if Z is None:
            Z = X

        X = self._slice_input(X)
        Z = self._slice_input(Z)
        if X.size(-1) != Z.size(-1):
            raise ValueError("Inputs must have the same number of features.")

        return X, Z
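
As a concrete illustration of how a subclass of this base class could use per-neuron, per-dimension lengthscales, here is a sketch of a squared-exponential kernel (illustrative only; the class name, parameter handling, and unit variance are assumptions, not code from my implementation):

class SquaredExponential(Kernel):
    r"""
    Sketch of an RBF kernel with a separate lengthscale per output (neuron)
    and per input dimension. The lengthscale has shape (neurons, 1, 1, dims)
    so that it broadcasts against inputs of shape (neurons, samples, timesteps, dims).
    """
    def __init__(self, input_dim, lengthscale, track_dims=None, f="exp"):
        super().__init__(input_dim, track_dims, f)
        # store the unconstrained parameter; self.f maps it back to positive values
        self._lengthscale = torch.nn.Parameter(self.f_inv(lengthscale))

    def forward(self, X, Z=None, diag=False):
        X, Z = self._XZ(X, Z)
        lengthscale = self.f(self._lengthscale)        # (neurons, 1, 1, dims)
        X = X / lengthscale                            # (neurons, samples, timesteps, dims)
        Z = Z / lengthscale
        if diag:
            # unit variance on the diagonal for this sketch
            return torch.ones(X.shape[:-1])
        # pairwise squared distances over the feature dimension
        d2 = ((X[..., :, None, :] - Z[..., None, :, :]) ** 2).sum(-1)
        return torch.exp(-0.5 * d2)                    # (neurons, samples, T, T)

# usage sketch: 5 neurons, 1 MC sample, 100 timesteps, 2 input dimensions
kern = SquaredExponential(input_dim=2, lengthscale=torch.ones(5, 1, 1, 2))
K = kern(torch.randn(5, 1, 100, 2))                    # K has shape (5, 1, 100, 100)
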
fehiepsi commented 3 years ago

I think we'll want to keep the current API usage but add more functionality. I'm not sure I follow your intended usage of the GP module, or the potential issues. For now, how about we restrict the scope of the discussion to the kernel and its math (and skip the interaction between the input/output of the GP)?

Assume that X has the shape (b, N, d), where b is the batch dimension, N is the number of data points, and d is the number of feature dimensions. By default, the parameters lengthscale and variance are scalars. What you want is to allow lengthscale to have shape (c, b, d) (or (b, d), (1, b, d), (c, 1, d)) so that kernel(X, Z) will have shape (c, b, N, N)? Currently, this kind of broadcasting is achieved through the forward method of each kernel.
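
For concreteness, a quick shape check of the broadcasting described above (purely illustrative; the unsqueezing pattern is just one possible way to get these shapes, not necessarily how contrib.gp would do it):

import torch

c, b, N, d = 4, 3, 10, 2
X = torch.randn(b, N, d)
lengthscale = torch.ones(c, b, d)

# add a length-1 axis for the data-point dimension so that the lengthscale
# broadcasts against X: (c, b, 1, d) vs (b, N, d) -> (c, b, N, d)
Xs = X / lengthscale[..., None, :]

# pairwise squared distances over the feature dimension -> (c, b, N, N)
K = torch.exp(-0.5 * ((Xs[..., :, None, :] - Xs[..., None, :, :]) ** 2).sum(-1))
print(K.shape)  # torch.Size([4, 3, 10, 10])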

davindicode commented 3 years ago

Yes, I agree with what you wrote there. So in practice, I can have a go at modifying the kernel module. However, this also involves modifying the conditional Gaussian function in another file (since we now have higher-dimensional kernel matrices). Also, X has shape (b, N, d) but with b being the MC sample dimension, right? (Since this is SVGP.) In my current implementation, b plays the role of MC samples, and the lengthscales are of shape (c, d). From what I remember when last using Pyro GP, it did not support more than one MC sample in the GPs, which may not be ideal for GPLVM applications with non-Gaussian likelihoods. I am following Hensman et al. 2013/2015 and subsampling the data points, so N may vary across batches. Am I confusing the role of the batch dimension here? I am following the PyTorch distributions (sample, batch, event) shape structure.
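
For reference, the (sample, batch, event) convention I am referring to is the torch.distributions one (a minimal illustration, not Pyro GP code):

import torch
from torch.distributions import MultivariateNormal

# batch_shape (3,), event_shape (2,)
mvn = MultivariateNormal(torch.zeros(3, 2), torch.eye(2))
x = mvn.sample((5,))   # sample_shape (5,)
print(mvn.batch_shape, mvn.event_shape, x.shape)
# torch.Size([3]) torch.Size([2]) torch.Size([5, 3, 2])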

So broadcasting is done in the __init__() of each kernel in my implementation, allowing the user to pass in one of the formats you suggested when initializing the kernel.

Let me know if this makes sense

fehiepsi commented 3 years ago

this also involves modifying the conditional Gaussian function in another file

I think so. Currently, the conditional utility does not support batch GP.

X has shape (b, N, d) but with b being the MC sample dimension

I'm not so sure about the MC sample dimension. What is the output shape of the kernel that you have in mind if lengthscale.shape == (c, d) and X has shape (b, N, d)? (For broadcasting, I think you'll need X.shape == (b, 1, N, d) to get an output with shape (b, c, N, N), but maybe I misunderstand your use case.)

broadcasting is done in the init() of each kernel in my implementation

Can we do it in the forward method instead? We don't know the shape of the input until we call the forward method.
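
As a sketch of what moving the broadcasting into the forward method could look like for the case above (lengthscale of shape (c, d), X of shape (b, N, d)), illustrative only, with self.lengthscale and self.variance as assumed attributes rather than a proposed final API:

# inside a hypothetical kernel subclass:
def forward(self, X, Z=None, diag=False):
    # X: (b, N, d), self.lengthscale: (c, d), self.variance: scalar
    if Z is None:
        Z = X
    # add a length-1 axis so the c output dimensions broadcast against b:
    # (b, 1, N, d) / (c, 1, d) -> (b, c, N, d)
    Xs = X[..., None, :, :] / self.lengthscale[..., None, :]
    Zs = Z[..., None, :, :] / self.lengthscale[..., None, :]
    if diag:
        return self.variance * torch.ones(Xs.shape[:-1])     # (b, c, N)
    # pairwise squared distances over the feature dimension
    d2 = ((Xs[..., :, None, :] - Zs[..., None, :, :]) ** 2).sum(-1)
    return self.variance * torch.exp(-0.5 * d2)              # (b, c, N, N)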