cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch
MIT License

[Bug] Multitask Exact GP + RBF kernel possibly malfunctioning #1056

Closed YukiyaSaito closed 4 years ago

YukiyaSaito commented 4 years ago

Context

I have been trying to create a GP emulator for a computer model that maps a 354-dimensional input onto a 30-dimensional output. For this purpose, I tried using a Multitask Exact GP with an RBF kernel. However, when I use a validation data set to check how well the model generalizes, it returns essentially the same output for any input, with only the overall scale varying slightly.

The figure below shows this behaviour. If the GP emulator were perfect, all the points would lie on the y=x line.

[Figure: RBF]

Whereas if I use the Matern kernel, the prediction of the model is reasonably good (or bad):

[Figure: Matern]

I cannot quite understand this behaviour with my limited experience, so I would appreciate any comments or insight regarding this issue.

To reproduce

My model is pretty much identical to the example shown in the Multitask Exact GP notebook. Everything is run on Google Colab with a GPU.

Install GPyTorch on Colab and Load packages

The installed version of GPyTorch is 1.0.1, and the PyTorch version is 1.4.0, on Google Colab.

import math
import torch
!pip install gpytorch
import gpytorch
from matplotlib import pyplot as plt

%matplotlib inline
%load_ext autoreload
%autoreload 2

Load data

My training data is here: https://drive.google.com/open?id=1Tu3IUN6Xryi_wgtbBOhOY5sRCAf9h0-H

import numpy as np
loaded = np.load('data_train.npz') #load data from data_train.npz
lhd = loaded['s1n_var']
abA_short_label = loaded['label']
abA_short_data = loaded['data']

numsize = 10000 #train with 10000 points

train_x = torch.empty(size=(numsize,354))
train_y = torch.empty(size=(numsize, 30))

for i in range(354):
    train_x[:,i] = torch.from_numpy(lhd[0:numsize,i])
for i in range(30):
    train_y[:,i] = torch.from_numpy(abA_short_data[i,0:numsize]*10**3)

print(train_x.shape)
print(train_y.shape)

Model definition

class MultitaskGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(MultitaskGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.MultitaskMean(
            gpytorch.means.ConstantMean(), num_tasks=30
        )
        self.covar_module = gpytorch.kernels.MultitaskKernel(
            gpytorch.kernels.RBFKernel(), num_tasks=30, rank=1
            #gpytorch.kernels.MaternKernel(nu=1.5), num_tasks=30, rank=1
        )
        #self.covar_module.initialize_from_data(train_x, train_y)

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)

likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=30)
model = MultitaskGPModel(train_x, train_y, likelihood)

train_x = train_x.cuda()
train_y = train_y.cuda()
model = model.cuda()
likelihood = likelihood.cuda()

Training

import os
smoke_test = ('CI' in os.environ)
training_iterations = 2 if smoke_test else 500

# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam([
    {'params': model.parameters()},  # Includes GaussianLikelihood parameters
], lr=0.01)

mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

for i in range(training_iterations):
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()
    print('Iter %d/%d - Loss: %.3f  noise: %.3f' % (i + 1, training_iterations, loss.item(), model.likelihood.noise.item()))
    optimizer.step()

Load validation data

My validation data is here: https://drive.google.com/open?id=12sC7uCSxayZEEDK-TtIlF43o5DU-SB5p

import numpy as np
loaded = np.load('./valid.npz') #load data from valid.npz
lhd_valid = loaded['valid_s1n']
valid_label = loaded['valid_label']
valid_data = loaded['valid_data']

test_x = torch.empty(size=(2000,354))
test_y = torch.empty(size=(2000, 30))

for i in range(354):
    test_x[:,i] = torch.from_numpy(lhd_valid[0:2000,i])
for i in range(30):
    test_y[:,i] = torch.from_numpy(valid_data[i,0:2000]*10**3)

test_x = test_x.cuda()
test_y = test_y.cuda()

print(test_x.shape) #validation with 2000 points
print(test_y.shape)

How my plots are generated

fig, axs = plt.subplots(6,5,figsize=(30, 25))
for i in range(6):
  for j in range(5):
    axs[i,j].plot(pred_valid[:,5*i+j].cpu(), test_y[:,5*i+j].cpu(),'o',markersize=2)
    axs[i,j].plot([0,20], [0,20], 'k-', alpha=0.75, zorder=0)
    axs[i,j].set_title('$abundance$, $A = %d$'%valid_label[5*i+j,0])
    axs[i,j].set_xlim(0,0.2)
    axs[i,j].set_ylim(0,0.2)
for ax in axs.flat:
    ax.set(xlabel='Prediction', ylabel='PRISM')

# Hide x labels and tick labels for top plots and y ticks for right plots.
for ax in axs.flat:
    ax.label_outer()
plt.show()

Expected Behavior

Expected output would be something similar to the output with Matern kernel shown in the second figure above.

gpleiss commented 4 years ago

@YukiyaSaito can you describe what kind of data you are trying to model? My guess is that this is not a bug at all, but rather that your data isn't well suited for the RBF kernel. The RBF kernel can only learn functions that are infinitely differentiable. If your data aren't that smooth, then the RBF kernel usually does really badly.

It could also be the case that you'll want to use a different initialization for the RBF kernel lengthscale (sometimes this helps, sometimes it doesn't). You can try adding this line after you initialize the model:

model.covar_module.data_covar_module.initialize(lengthscale=0.1)
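
When trying different lengthscale values, it may help to compare them with the typical distance between two inputs. The sketch below is only an illustration under stated assumptions: it uses uniformly random points in [0, 1]^354 as a stand-in for the actual design, the textbook formulas for the RBF and Matern (nu = 1.5) kernels, and arbitrary example lengthscales. The helper names are hypothetical.

```python
import math
import random

DIM = 354  # input dimension from the issue


def rbf(dist, lengthscale):
    """RBF (squared-exponential) kernel value at Euclidean distance `dist`."""
    return math.exp(-0.5 * (dist / lengthscale) ** 2)


def matern15(dist, lengthscale):
    """Matern kernel with nu = 1.5 at Euclidean distance `dist`."""
    scaled = math.sqrt(3.0) * dist / lengthscale
    return (1.0 + scaled) * math.exp(-scaled)


def typical_distance(dim, n_pairs=200, seed=0):
    """Average Euclidean distance between random pairs of points in [0, 1]^dim."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_pairs):
        a = [rng.random() for _ in range(dim)]
        b = [rng.random() for _ in range(dim)]
        total += math.dist(a, b)
    return total / n_pairs


dist = typical_distance(DIM)
print(f"typical distance between two inputs: {dist:.2f}")
for lengthscale in (0.1, 1.0, 10.0):  # illustrative values only
    print(
        f"lengthscale={lengthscale:>4}: "
        f"RBF={rbf(dist, lengthscale):.3e}  "
        f"Matern-1.5={matern15(dist, lengthscale):.3e}"
    )
```

When the lengthscale is much smaller than the typical distance, both kernel values are close to zero, and the RBF value falls off much faster than the Matern one. A kernel value near zero means the corresponding training point contributes very little to the prediction at that input.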

However, my best guess is that you need the Matern kernel for your particular problem. (In my personal opinion, the Matern kernel is a much better "all-around" kernel than the RBF kernel anyways).

Thank you for the code/data. However, for future reference, it is much better if you package your code in an ipython notebook for us (this makes it much faster for us to debug what's going on).

YukiyaSaito commented 4 years ago

@gpleiss Thank you very much for your prompt response. The data I am trying to model is the response of numerical solutions to a (large) system of ordinary differential equations (ODEs), when input variables are slightly varied.

The training input and output roughly have the properties described below.

Input: 354-dimensional vector. Generated with Latin Hypercube Sampling. The values lie in [0, 1].

Output: 30-dimensional vector. Results of a numerical simulation. The values lie in [1e-6, 5e-4], so they are multiplied by 10^3 for training, for numerical stability. Each component of the 30-dimensional vector has a slightly different scale.
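
Since the output components differ in scale, one alternative to a single 10^3 multiplier would be to standardize each component separately and map predictions back afterwards. This is a sketch of that idea only, not what the notebook does; the function names and the example values are hypothetical.

```python
import statistics


def fit_scaler(columns):
    """Per-task (mean, std) computed from training outputs, one column per task."""
    return [(statistics.fmean(col), statistics.pstdev(col)) for col in columns]


def transform(columns, scaler):
    """Standardize each task to zero mean and unit variance."""
    return [[(v - mean) / std for v in col] for col, (mean, std) in zip(columns, scaler)]


def inverse_transform(columns, scaler):
    """Map standardized values (e.g. GP predictions) back to the original units."""
    return [[v * std + mean for v in col] for col, (mean, std) in zip(columns, scaler)]


# Hypothetical outputs for two tasks on very different scales.
train_outputs = [
    [1.0e-6, 3.0e-6, 2.0e-6, 4.0e-6],
    [1.0e-4, 5.0e-4, 3.0e-4, 2.0e-4],
]

scaler = fit_scaler(train_outputs)             # fit on training data only
scaled = transform(train_outputs, scaler)      # what the GP would be trained on
restored = inverse_transform(scaled, scaler)   # back to the original units
```

The same scaler fitted on the training outputs would then be reused for the validation outputs and for converting predictions back.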

[Edit] I forgot to mention the smoothness of the data. As the output is the solution to the ODEs for slightly varied input variables, I would expect the data to be somewhat smooth. But given the results with the RBF kernel, this assumption might be wrong.

Thank you for your suggestion about the initialization of the lengthscale. I have quickly tried different values for the RBF kernel lengthscale, but it does not seem to work.

I do not have very good intuition as to which kernel to use. While the Matern kernel seems to do a decent job, is there a good approach to fine tune the kernel so that the systematic deviations from the target output can be fixed? I tried adding different kernels (Matern + Linear), but it seems to become computationally heavy very quickly.

I am sorry about how I presented my code and data. Here is the link to my ipython notebook. You should be able to run this without any modification. Please note that the prediction part (2000 points) takes ~5 minutes, therefore you can reduce the number of points down to 100 for a quick check. https://colab.research.google.com/github/YukiyaSaito/PRISM_Emulator/blob/master/GP.ipynb

gpleiss commented 4 years ago

I'm going to go ahead and say this is not a bug. I think that the data you have are probably more suited for the Matern kernel.

"is there a good approach to fine tune the kernel so that the systematic deviations from the target output can be fixed?"

You could try increasing the rank on the multitask kernel to something >1? Otherwise, I think the best approach is trial and error with different kernel types (this is what I would do with the data you have).
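
To illustrate what the rank controls, here is a sketch that assumes the task covariance has the low-rank-plus-diagonal form B = W W^T + diag(v), with W of shape num_tasks x rank. My understanding is that this is how the multitask kernel parametrizes its task covariance, but treat that as an assumption. The helper names and the numbers are made up for the example.

```python
def task_covariance(factor, var):
    """B = W W^T + diag(v), with W given as a list of rows (num_tasks x rank)."""
    n = len(factor)
    return [
        [
            sum(wi * wj for wi, wj in zip(factor[i], factor[j]))
            + (var[i] if i == j else 0.0)
            for j in range(n)
        ]
        for i in range(n)
    ]


def off_diagonal_minor(cov):
    """2x2 minor from rows (0, 3) and columns (1, 2): off-diagonal entries only."""
    return cov[0][1] * cov[2][3] - cov[0][2] * cov[1][3]


var = [0.1, 0.1, 0.1, 0.1]
rank1_factor = [[1.0], [0.5], [-2.0], [3.0]]
rank2_factor = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, -1.0]]

cov_rank1 = task_covariance(rank1_factor, var)
cov_rank2 = task_covariance(rank2_factor, var)

# With rank 1 every inter-task covariance is a product w_i * w_j, so this
# minor vanishes; with rank 2 the inter-task covariances are less constrained.
print("rank 1 minor:", off_diagonal_minor(cov_rank1))
print("rank 2 minor:", off_diagonal_minor(cov_rank2))

# Free parameters in this form: num_tasks * rank (for W) + num_tasks (for v).
for rank in (1, 3):
    print(f"30 tasks, rank {rank}: {30 * rank + 30} task-covariance parameters")
```

Under this parametrization, a higher rank lets the model express more varied correlations between the 30 outputs, at the cost of more parameters to fit.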