cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch
MIT License

[Docs] Is there an example script to train a GP layer on top of an LSTM? #686

Open shgaurav1 opened 5 years ago

shgaurav1 commented 5 years ago

Hi,

Is there a GPyTorch implementation of this paper? If there is a similar script in the examples, could you please point me to it?

Thanks

gpleiss commented 5 years ago

We don't have an example for this currently. I imagine the easiest way to implement it would be to take what is currently in the CIFAR10 example, but replace the feature extractor with an LSTM.
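
For concreteness, a minimal sketch of that swap might look like the following (the module and hyperparameter names are hypothetical, not taken from the CIFAR10 example):

import torch

class LSTMFeatureExtractor(torch.nn.Module):
    """Maps a sequence [batch, seq_len, input_dim] to features [batch, feature_dim]."""
    def __init__(self, input_dim, hidden_dim, feature_dim):
        super().__init__()
        self.lstm = torch.nn.LSTM(input_dim, hidden_dim, batch_first=True)
        self.linear = torch.nn.Linear(hidden_dim, feature_dim)

    def forward(self, x):
        output, _ = self.lstm(x)           # output: [batch, seq_len, hidden_dim]
        return self.linear(output[:, -1])  # take the last time step as the feature vector

The GP layer would then consume these features in place of the conv-net features, with the inducing points living in the feature_dim-dimensional feature space.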

shgaurav1 commented 5 years ago

@gpleiss I tried doing that. Here is my code:

import gpytorch
from gpytorch.models import AbstractVariationalGP
from gpytorch.variational import CholeskyVariationalDistribution, WhitenedVariationalStrategy

class GPRegressionLayer(AbstractVariationalGP):
    def __init__(self, inducing_points):
        # One variational parameter set per inducing point
        variational_distribution = CholeskyVariationalDistribution(inducing_points.size(0))
        variational_strategy = WhitenedVariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super(GPRegressionLayer, self).__init__(variational_strategy)
        self.mean_module = gpytorch.means.MultitaskMean(
            gpytorch.means.ConstantMean(), num_tasks=16
        )
        self.covar_module = gpytorch.kernels.MultitaskKernel(
            gpytorch.kernels.RBFKernel(), num_tasks=16, rank=1
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

My task involves learning a mapping from a 16-dimensional input sequence space to a 16-dimensional output space with a GP-LSTM, so I tried using a variational GP to train a GP layer on top of the LSTM. However, the above code raises a dimensionality error. I debugged a bit and came to the conclusion that either the multitask kernel or the multitask mean is causing the issue. Can you please suggest a fix?

Thanks

jacobrgardner commented 5 years ago

Do you truly have multitask data (e.g., are your labels n x t)? If so, you'll need to return a MultitaskMultivariateNormal rather than a standard MultivariateNormal. If you don't have multitask data, don't use the multitask mean/covar.

What error are you getting? Can you paste a stack trace? It's at least possible that there is a problem on our end with the multitask kernel in variational models.
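
For reference, with a multitask mean/kernel the forward method would look roughly like this (a sketch, not tested against your model):

def forward(self, x):
    mean_x = self.mean_module(x)    # MultitaskMean: an n x t mean
    covar_x = self.covar_module(x)  # MultitaskKernel: an (n*t) x (n*t) Kronecker covariance
    # MultitaskMultivariateNormal pairs the n x t mean with the Kronecker-structured
    # covariance; a plain MultivariateNormal cannot broadcast the two together.
    return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)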

shgaurav1 commented 5 years ago

@jacobrgardner My mini-batch inputs have shape [80, 16], and the targets have the same shape [80, 16]. Hence I thought a multitask mean/covar was appropriate here. Please let me know if I missed something important.

Thanks

Traceback (most recent call last):
  File "gp_lstm_version_4.py", line 433, in <module>
    mse = train_GP_Frame_predictor(x)
  File "gp_lstm_version_4.py", line 269, in train_GP_Frame_predictor
    h_pred = gp_layer(frame_predictor(h))
  File "/cfarhomes/gauravsh/myenv/lib/python3.6/site-packages/gpytorch/models/abstract_variational_gp.py", line 22, in __call__
    return self.variational_strategy(inputs)
  File "/cfarhomes/gauravsh/myenv/lib/python3.6/site-packages/gpytorch/variational/variational_strategy.py", line 190, in __call__
    self.initialize_variational_dist()
  File "/cfarhomes/gauravsh/myenv/lib/python3.6/site-packages/gpytorch/variational/whitened_variational_strategy.py", line 63, in initialize_variational_dist
    .type_as(prior_dist.covariance_matrix),
  File "/cfarhomes/gauravsh/myenv/lib/python3.6/site-packages/torch/distributions/multivariate_normal.py", line 115, in __init__
    self.covariance_matrix, loc_ = torch.broadcast_tensors(covariance_matrix, loc_)
  File "/cfarhomes/gauravsh/myenv/lib/python3.6/site-packages/torch/functional.py", line 52, in broadcast_tensors
    return torch._C._VariableFunctions.broadcast_tensors(tensors)
RuntimeError: The size of tensor a (640) must match the size of tensor b (16) at non-singleton dimension 1

PS: the inducing points have shape [40, 16]. (Presumably the 640 in the error is 40 × 16: the multitask kernel produces a Kronecker-expanded 640 × 640 covariance over the inducing points, which a plain MultivariateNormal can't broadcast against the [40, 16] multitask mean.)

shgaurav1 commented 5 years ago

> Do you truly have multitask data (e.g., are your labels n x t)? If so, you'll need to return a MultitaskMultivariateNormal rather than a standard MultivariateNormal. If you don't have multitask data, don't use the multitask mean/covar.

@jacobrgardner Even after returning a MultitaskMultivariateNormal, the problem persists.

jacobrgardner commented 5 years ago

@the-darklord Could you try using a batch model instead of a multitask model? Here's an example notebook using batch mode to learn 4 functions: https://github.com/cornellius-gp/gpytorch/blob/master/examples/01_Simple_GP_Regression/Simple_Batch_Mode_GP_Regression.ipynb

Basically, try treating each of the 16 outputs as an independent GP (with correlations across outputs learned implicitly by the shared LSTM feature extractor). I think multitask inference in the variational setting will need to be handled in a slightly different way, e.g. by mixing the f outputs directly.
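
A sketch of such a batch layer, written against today's API rather than the one used earlier in this thread (the name BatchGPLayer is hypothetical; it assumes 16 independent outputs and 40 inducing points per output):

import torch
import gpytorch
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution, VariationalStrategy

class BatchGPLayer(ApproximateGP):
    def __init__(self, inducing_points):
        # inducing_points: [16, 40, feature_dim] -- one set of 40 inducing points per output
        batch_shape = inducing_points.shape[:-2]
        variational_distribution = CholeskyVariationalDistribution(
            inducing_points.size(-2), batch_shape=batch_shape
        )
        variational_strategy = VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=batch_shape)
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=batch_shape), batch_shape=batch_shape
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        # A batch of 16 independent GPs, one per output dimension
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)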

shgaurav1 commented 5 years ago

@jacobrgardner Thanks, that really helped!

adam-rysanek commented 5 years ago

@the-darklord Did you figure out a way to solve your problem? I'm interested in learning more about how to implement a GP-LSTM with gpytorch, but I'm not sure I'm starting from the right point. In our case we have a single-output GP, which is easier to handle, but I'm a bit stuck on how to implement the LSTM alongside the GP in gpytorch syntax.

jacobrgardner commented 5 years ago

@adam-rysanek Are there specific problems you're encountering with this? In general, the input to your GP model can be the output of any torch.nn.Module, so GPyTorch doesn't particularly care whether that's a fully connected net, a conv net, or an LSTM, so long as the dimensionality of your inducing points matches the output dimensionality of the feature extractor / module.
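
For instance, the wiring can be as simple as this hypothetical sketch (feature_extractor and gp_layer stand in for whatever modules you define):

import torch
import gpytorch

class GPLSTM(gpytorch.Module):
    """End-to-end model: an LSTM (or any nn.Module) feeding a GP layer."""
    def __init__(self, feature_extractor, gp_layer):
        super().__init__()
        self.feature_extractor = feature_extractor  # e.g. an LSTM ending in [batch, feature_dim]
        self.gp_layer = gp_layer  # its inducing points must also be feature_dim-dimensional

    def forward(self, x):
        features = self.feature_extractor(x)
        return self.gp_layer(features)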

adam-rysanek commented 5 years ago

Hi @jacobrgardner. Yes, I'm realising this now; I first need to understand separately how to set up an LSTM in torch. I've discovered #410, which I'm using as a guide to get started. Thanks for responding. This could be a long-term request, but given gpytorch's potential for time-series problems, it would be a real boost if an iteration of the DKL-GP example tackled a time-series problem with an RNN/LSTM-GP.
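
A minimal sketch of the training loop for such a model, assuming a GPLSTM-style wrapper as above, single-output targets, and today's API (model, train_x with shape [batch, seq_len, input_dim], and train_y with shape [batch] are assumed to be defined elsewhere):

import torch
import gpytorch

likelihood = gpytorch.likelihoods.GaussianLikelihood()
# The ELBO needs the GP layer itself plus the total number of training points
mll = gpytorch.mlls.VariationalELBO(likelihood, model.gp_layer, num_data=train_y.size(0))
optimizer = torch.optim.Adam(
    list(model.parameters()) + list(likelihood.parameters()), lr=0.01
)

model.train()
likelihood.train()
for i in range(100):
    optimizer.zero_grad()
    output = model(train_x)       # a MultivariateNormal over the mini-batch
    loss = -mll(output, train_y)  # negative variational ELBO
    loss.backward()
    optimizer.step()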