cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch

[Bug] Serious bug in "gpytorch.distributions" #2582

Closed yzshi5 closed 1 month ago

yzshi5 commented 1 month ago

🐛 Bug

When I created a Gaussian Process (GP) with a Matern kernel and sampled from it, I found the samples were incorrect. I validated them against a ground-truth implementation.

To reproduce

Code snippet to reproduce


import gpytorch # version = 1.12
import torch 
import numpy as np

# for comparison
from sklearn.gaussian_process.kernels import Matern 
from scipy.stats import binned_statistic
import matplotlib.pyplot as plt

## Here, we want to create a 2D Gaussian random field and sample from it
## The generated samples should have shape [N, 1, 64, 64], where N is the number of samples and 64x64 is the resolution

n_x = 64 # resolution 

kernel_length = 0.01  # also tried 0.1, 0.05
kernel_variance = 1
nu = 0.5  # default

train_samples = 1000
n_channels = 1
dims = [n_x, n_x]

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
## #query_points = make_grid(dims)

x = np.linspace(0, 1, n_x)
y = np.linspace(0, 1, n_x)
XX, YY = np.meshgrid(x, y)
XX = XX.reshape(-1, 1)
YY = YY.reshape(-1, 1)

query_points = torch.Tensor(np.concatenate([XX, YY], axis=1))

eps = 1e-10
#eps = 0

#### Define the gpytorch kernel
base_kernel = gpytorch.kernels.MaternKernel(nu, eps=eps)
base_kernel.lengthscale = kernel_length

mean_module = gpytorch.means.ConstantMean()
#covar_module = gpytorch.kernels.ScaleKernel(base_kernel)
covar_module = base_kernel
#covar_module.outputscale = 1

mean_x = mean_module(query_points)
covar_x = covar_module(query_points)

gpy_dist = gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
gpy_samples = gpy_dist.sample(sample_shape=torch.Size([train_samples * n_channels]))
gpy_samples = gpy_samples.reshape(train_samples, n_channels, *dims)

plt.imshow(gpy_samples[0].reshape(dims))

####################################### comparison ###############################

def matern_kernel_cov(ndim, length_scale, nu):
    x = np.linspace(0, 1, ndim)
    y = np.linspace(0, 1, ndim)
    XX, YY = np.meshgrid(x, y)
    XX = XX.reshape(-1, 1)
    YY = YY.reshape(-1, 1)
    X = np.concatenate([XX, YY], axis=1)
    print(X.shape)
    kernel = 1.0 * Matern(length_scale=length_scale, length_scale_bounds="fixed", nu=nu)
    return kernel(X)

matern_ker = matern_kernel_cov(64, kernel_length, nu)

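## Baseline: draw the same random field from a plain torch MultivariateNormal,
## parameterized by the Cholesky factor of the sklearn Matern covariance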
base_mu = torch.zeros(n_x*n_x).float()
base_cov = torch.tensor(matern_ker).float()
base_dist = torch.distributions.MultivariateNormal(base_mu.to(device), scale_tril=torch.linalg.cholesky_ex(base_cov.to(device))[0])

sklearn_grf_samples = base_dist.sample([train_samples])

print("log_p :{}".format(base_dist.log_prob(sklearn_grf_samples[:5])))

plt.imshow(sklearn_grf_samples[0].reshape([n_x,n_x]).detach().cpu())

## statistical comparison
def compute_acovf(z):
    # compute the autocovariance based on FFT method
    # z shape : [n, ndim, ndim]
    res = z.shape[-1]
    z_hat = torch.fft.rfft2(z)
    acf = torch.fft.irfft2(torch.conj(z_hat) * z_hat)
    acf = torch.fft.fftshift(acf).mean(dim=0) / z[0].numel() # ndim*ndim
    acf_r = acf.view(-1).cpu().detach().numpy()
    lags_x, lags_y = torch.meshgrid(torch.arange(res) - res // 2, torch.arange(res) - res // 2, indexing='ij')
    lags_r = torch.sqrt(lags_x**2 + lags_y**2).view(-1).cpu().detach().numpy()

    idx = np.argsort(lags_r)
    lags_r = lags_r[idx]
    acf_r = acf_r[idx]

    bin_means, bin_edges, binnumber = binned_statistic(lags_r, acf_r, 'mean', bins=np.linspace(0.0, res, 50))
    return bin_edges[:-1], bin_means

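## Compare gpytorch samples against the baseline: example images, radial autocovariance, and value histograms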
gpy_samples = gpy_samples.squeeze()
baseline_samples = sklearn_grf_samples.detach().cpu().reshape(-1, *dims)
with torch.no_grad():

    X_hat = gpy_samples[:5]
    X_ground_truth = baseline_samples[:5].squeeze()

    bin_center, x_acovf = compute_acovf(gpy_samples.squeeze())
    _, x_acovf_true = compute_acovf(baseline_samples.squeeze())

    x_hist, bin_edges_alt = gpy_samples.histogram(range=[-4,4], density=True)
    x_hist_true, bin_edges = baseline_samples.histogram(range=[-4, 4], density=True)

    fig, ax = plt.subplots(1,5, figsize=(15,3))
    for i in range(5):
        x = X_hat[i,:,:].squeeze()

        ax[i].imshow(x)#, vmin=-2, vmax=2)
        if i == 0:
            ax[i].set_ylabel('Gpytorch', fontsize=16)

    #cb = fig.colorbar(ax[4].imshow(), , orientation='vertical')
    bar = ax[4].imshow(x)#, vmin=-2, vmax=2)
    fig.colorbar(bar, ax=ax)
    #plt.show()

    fig, ax = plt.subplots(1,5, figsize=(15,3))    
    for i in range(5):
        x_ground_truth = X_ground_truth[i,:,:].squeeze()
        ax[i].imshow(x_ground_truth)#, vmin=-2, vmax=2)
        if i == 0:
            ax[i].set_ylabel('Ground Truth', fontsize=16)

    bar = ax[4].imshow(x_ground_truth) #, vmin=-2, vmax=2)
    fig.colorbar(bar, ax=ax)

    fig, ax = plt.subplots(1,2, figsize=(12,4))
    ax[0].plot(bin_center, x_acovf_true, c='k', lw=3)
    ax[0].plot(bin_center, x_acovf, c='r',ls='--', lw=3)
    #ax[0].set_ylim(0.2, 0.45)
    ax[0].set_title('Autocov, l={}, nu={}'.format(kernel_length,nu))
    ax[0].set_xlabel('Position', fontsize='large')
    ax[1].plot((bin_edges[1:]+bin_edges[:-1])/2, x_hist_true, c='k', lw=3, label='Ground Truth')
    ax[1].plot((bin_edges_alt[1:]+bin_edges_alt[:-1])/2, x_hist, c='r', ls='--', lw=3, label='Gpytorch')
    ax[1].set_title('Histogram')
    ax[1].legend(loc='upper right')
    ax[1].set_xlabel('Value', fontsize='large')

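As an extra sanity check (not part of the original snippet), the two covariance matrices can also be compared directly, reusing gpy_dist and matern_ker from above:

K_gpy = gpy_dist.covariance_matrix.detach().cpu().numpy()   # densified 4096 x 4096 gpytorch covariance
print(np.allclose(K_gpy, matern_ker, atol=1e-5))            # True if the kernels agree (up to float32 precision)
print(np.abs(K_gpy - matern_ker).max())                     # largest elementwise difference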

Expected Behavior

The autocovariance at lag 0 is supposed to be 1 (the kernel variance), and when the lengthscale is small enough relative to the grid spacing we expect the samples to look similar to white noise.
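Since the autocovariance at lag 0 is just the pointwise variance of the zero-mean field, a quick check of this expectation (not in the original report, reusing gpy_samples and sklearn_grf_samples from the snippet above) is:

print(gpy_samples.var().item())                         # gpytorch samples, expected to be close to 1
print(sklearn_grf_samples.detach().cpu().var().item())  # baseline samples, expected to be close to 1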


gpleiss commented 1 month ago

There is a bug in your reproducible example:

base_dist = torch.distributions.MultivariateNormal(base_mu.to(device), scale_tril=torch.linalg.cholesky_ex(base_cov.to(device))[0])

You shouldn't have a [0] after the Cholesky factor; that only grabs the first row of the Cholesky factor. If you remove the [0], your samples should look much more similar.

Given that you are computing samples on the domain [0, 1] x [0, 1], I would expect the values that GPyTorch produces, even with a lengthscale of 0.01. The samples shouldn't look like white noise until the lengthscale is much smaller. (You can verify this by noticing that the gpytorch covariance matrix, which is identical to the sklearn covariance matrix, is far from the identity and therefore not white noise.)
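For concreteness (this arithmetic is not in the original comment): with nu = 0.5 the Matern kernel reduces to exp(-d / l), the nearest-neighbor spacing on the 64 x 64 grid over [0, 1] x [0, 1] is 1/63 ≈ 0.016, and so adjacent pixels keep a correlation of roughly 0.2 even at l = 0.01:

import numpy as np

l = 0.01               # lengthscale from the report
d = 1.0 / 63           # spacing between adjacent grid points (64 points on [0, 1])
print(np.exp(-d / l))  # Matern nu=0.5 correlation between neighboring pixels, ~0.20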

yzshi5 commented 1 month ago

Hi Geoff, thanks for your timely reply. torch.linalg.cholesky_ex(base_cov.to(device)) returns a lower-triangular factor L and an info tensor:

torch.linalg.cholesky_ex(base_cov.to(device))
L=tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [9.6847e-01, 2.4913e-01, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [8.9432e-01, 4.1083e-01, 1.7725e-01,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [8.4983e-10, 7.9180e-10, 7.7648e-10,  ..., 9.7107e-02, 0.0000e+00,
         0.0000e+00],
        [7.0670e-10, 6.6394e-10, 6.5171e-10,  ..., 9.4278e-02, 9.7856e-02,
         0.0000e+00],
        [5.8675e-10, 5.5572e-10, 5.4600e-10,  ..., 6.0951e-02, 1.0021e-01,
         1.0890e-01]], device='cuda:1'),
info=tensor(0, device='cuda:1', dtype=torch.int32))

so [0] just extracts the L matrix. If we remove the [0], we get an error instead, since scale_tril expects a tensor rather than the whole (L, info) tuple.
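For reference, torch.linalg.cholesky_ex returns an (L, info) named tuple, so indexing with [0] (or unpacking) is the intended way to get the factor; a minimal sketch:

import torch

A = torch.eye(3) * 2.0
out = torch.linalg.cholesky_ex(A)           # named tuple (L, info)
L, info = out                               # equivalent to out[0], out[1]
print(torch.equal(L, out.L), info.item())   # True 0 -- info == 0 means the factorization succeeded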

The other thing is that if we look at the histogram of the pointwise distribution of the generated samples, it shouldn't be bounded to [-1, 1]. And for any Matern kernel with variance 1, we expect the autocovariance at lag 0 to be 1.

gpleiss commented 1 month ago

Your bug report needs to be more specific. Do you think the issue is with the sampling code, or with computing the covariance matrix? From what I can see, the sklearn and gpytorch Matern kernel functions return identical covariance matrices, so that's not the source of the bug. Please try to make your reproducible example as small as possible.
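One way to shrink the example and isolate the sampling step (a sketch of my own, assuming kernel(x).to_dense() is available as in recent gpytorch versions, not code from this thread): sample from the same small covariance matrix with both gpytorch's and torch's MultivariateNormal and compare the empirical covariances.

import torch
import gpytorch

# A small 1D grid keeps the check fast while still going through gpytorch's sampling path.
x = torch.linspace(0, 1, 32).unsqueeze(-1)

kernel = gpytorch.kernels.MaternKernel(nu=0.5)
kernel.lengthscale = 0.05
K_lazy = kernel(x)               # gpytorch's (lazy) covariance
K = K_lazy.to_dense().detach()   # dense 32 x 32 matrix for the torch baseline

mean = torch.zeros(32)
gpy_dist = gpytorch.distributions.MultivariateNormal(mean, K_lazy)
torch_dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=K + 1e-6 * torch.eye(32))  # small jitter for the Cholesky

n = 20000
s_gpy = gpy_dist.sample(torch.Size([n]))      # [n, 32]
s_torch = torch_dist.sample(torch.Size([n]))  # [n, 32]

# If the sampling step is correct, both empirical covariances should be close to K.
print((s_gpy.T @ s_gpy / n - K).abs().max().item())
print((s_torch.T @ s_torch / n - K).abs().max().item())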