cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch
MIT License

[Bug] "fast_computations" going slower #2078

Closed: jimmyrisk closed this issue 1 year ago

jimmyrisk commented 2 years ago

🐛 Bug

I am finding that wrapping training in with gpytorch.settings.fast_computations(covar_root_decomposition=False, log_prob=False, solves=False): unexpectedly improves runtime by 5x (and produces a different MLL value).

I will provide the full reproducible code at the bottom, but here is a rough explanation of what I am encountering. For reference, train_x is 1050x3, and train_y is 1050x1.

Normal Settings

start_time = time.time()

likelihood_fit = gpytorch.likelihoods.GaussianLikelihood(noise_constraint = Interval(1e-5, 1, initial_value = 1e-3))
model_fit = ExactGPModel(train_x, train_y, likelihood_fit, covar_module)
model_fit.train()

hypers = {
    'mean_module.weights': torch.tensor([[3.4]]).cuda(),
    'mean_module.bias': torch.tensor(-5.0).cuda()
}
model_fit.initialize(**hypers)

mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood_fit, model_fit)
optimizer = torch.optim.Adam(model_fit.parameters(), lr=0.2)  # Includes GaussianLikelihood parameters of model_fit
optimizer.param_groups[0]['capturable'] = True

training_iter = 100
for i in range(training_iter):
    optimizer.zero_grad()
    output = model_fit(train_x)
    loss = -mll(output, train_y)
    loss.backward()
    optimizer.step()

print(time.time() - start_time)

> 5.186566114425659

With gpytorch.settings.fast_computations(covar_root_decomposition=False, log_prob=False, solves=False):

start_time = time.time()

with gpytorch.settings.fast_computations(covar_root_decomposition=False, log_prob=False, solves=False):
    likelihood_fit = gpytorch.likelihoods.GaussianLikelihood(noise_constraint = Interval(1e-5, 1, initial_value = 1e-3))
    model_fit = ExactGPModel(train_x, train_y, likelihood_fit, covar_module)
    model_fit.train()

    hypers = {
        'mean_module.weights': torch.tensor([[3.4]]).cuda(),
        'mean_module.bias': torch.tensor(-5.0).cuda()
    }
    model_fit.initialize(**hypers)

    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood_fit, model_fit)
    optimizer = torch.optim.Adam(model_fit.parameters(), lr=0.2)  # Includes GaussianLikelihood parameters of model_fit
    optimizer.param_groups[0]['capturable'] = True

    training_iter = 100
    for i in range(training_iter):
        optimizer.zero_grad()
        output = model_fit(train_x)
        loss = -mll(output, train_y)
        loss.backward()
        optimizer.step()

print(time.time() - start_time)
> 1.3557474613189697
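A caveat on these timings: CUDA kernels launch asynchronously, so time.time() around GPU code can miss queued work or pick up work left over from earlier cells unless the device is synchronized first. A small helper along these lines makes runs easier to compare. `timed` is a hypothetical name, not a GPyTorch utility. Separately, both runs above pass the same covar_module object, so the second fit starts from the kernel hyperparameters the first fit left behind. Constructing a fresh kernel for each run keeps the comparison fair.

```python
import time

import torch


def timed(fn, *args, repeats=1, **kwargs):
    """Return (fastest wall-clock seconds over `repeats` runs, result of the last run)."""
    best = float("inf")
    result = None
    for _ in range(repeats):
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # finish earlier GPU work before starting the clock
        start = time.perf_counter()
        result = fn(*args, **kwargs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()  # wait for kernels launched by fn to complete
        best = min(best, time.perf_counter() - start)
    return best, result
```

For example, you could wrap the training loop in a function and call `timed` once per setting. Training mutates the model, so `repeats > 1` only makes sense for side-effect-free work such as a single MLL evaluation.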

Differences in mll

print(-mll(output, train_y))

with gpytorch.settings.fast_computations(covar_root_decomposition=False, log_prob=False, solves=False):
    print(-mll(output, train_y))

> tensor(-2.1248, grad_fn=<NegBackward0>)
> tensor(-2.2670, grad_fn=<NegBackward0>)

Expected Behavior

As documented, gpytorch.settings.fast_computations(covar_root_decomposition=False, log_prob=False, solves=False) makes GPyTorch use Cholesky decompositions, which (I believe) should increase accuracy at the expense of a longer runtime.
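For reference, the Cholesky path computes the exact marginal log likelihood from one factorization of K + σ²I. Below is a minimal plain-PyTorch sketch of that computation. It assumes a zero mean and uses an RBF kernel as a stand-in for the issue's composite Matérn kernel. `rbf_kernel` and `cholesky_mll` are hypothetical helpers, not GPyTorch functions. The division by n mirrors how, as far as I know, ExactMarginalLogLikelihood reports a per-data-point value. That normalization is why the losses above are around -2.

```python
import math

import torch


def rbf_kernel(x1, x2, lengthscale=0.3, outputscale=1.0):
    # Squared-exponential kernel: a stand-in for the composite Matern kernel in the issue.
    sqdist = (x1.unsqueeze(-2) - x2.unsqueeze(-3)).pow(2).sum(-1)
    return outputscale * torch.exp(-0.5 * sqdist / lengthscale**2)


def cholesky_mll(K, y, noise):
    """Exact log marginal likelihood of a zero-mean GP, divided by the number of points."""
    n = y.shape[0]
    A = K + noise * torch.eye(n, dtype=K.dtype)
    L = torch.linalg.cholesky(A)  # A = L @ L.T
    alpha = torch.cholesky_solve(y.unsqueeze(-1), L).squeeze(-1)  # A^{-1} y
    quad = y @ alpha  # data-fit term y^T A^{-1} y
    logdet = 2.0 * torch.log(torch.diagonal(L)).sum()  # log|A| from the Cholesky diagonal
    return (-0.5 * quad - 0.5 * logdet - 0.5 * n * math.log(2 * math.pi)) / n


torch.manual_seed(0)
x = torch.rand(200, 3, dtype=torch.float64)
y = torch.randn(200, dtype=torch.float64)
K = rbf_kernel(x, x)
print(cholesky_mll(K, y, noise=1e-2))
```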


Additional context

The purpose of this simulation is to generate a training set akin to the ones we use in our mortality modelling research (hence the age, year, and cohort inputs). We then pick a plausible kernel, simulate synthetic mortality y's from the prior, and try to recover that kernel by comparing likelihoods.

We noticed varying MLL computations and tried a few fixes, as documented here:

Eventually, I tried gpytorch.settings.fast_computations(covar_root_decomposition=False, log_prob=False, solves=False). To my astonishment, it not only ran much faster but also produced different MLL results, even when I used very large values for other settings (e.g. with gpytorch.settings.num_trace_samples(1050) and with gpytorch.settings.max_preconditioner_size(1050)).

This was also tested with/without torch.backends.cuda.matmul.allow_tf32 = True as recommended in https://github.com/cornellius-gp/gpytorch/issues/1960.
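Some context on the num_trace_samples setting mentioned above. GPyTorch's docs describe it as the number of samples drawn when stochastically estimating the trace of a matrix. The sketch below shows the basic idea, Hutchinson's estimator with Rademacher probe vectors, in plain PyTorch. It illustrates the technique and is not GPyTorch's implementation. `hutchinson_trace` is a hypothetical helper, and the test matrix is arbitrary. Each probe gives an unbiased estimate, so more probes shrink the variance without removing the randomness entirely.

```python
import torch


def hutchinson_trace(matmul, n, num_probes, generator=None, dtype=torch.float64):
    """Estimate tr(A) using only products A @ Z with random +/-1 probe vectors."""
    # Rademacher probes: each column z has E[z z^T] = I, hence E[z^T A z] = tr(A).
    Z = torch.randint(0, 2, (n, num_probes), generator=generator).to(dtype) * 2 - 1
    AZ = matmul(Z)  # one batched product for all probes
    return (Z * AZ).sum(dim=0).mean()  # average of the per-probe estimates z^T A z


g = torch.Generator().manual_seed(0)
B = torch.randn(50, 50, dtype=torch.float64, generator=g)
A = B @ B.T / 50 + torch.eye(50, dtype=torch.float64)
print("exact trace:", torch.trace(A).item())
for num_probes in (10, 100, 10000):
    est = hutchinson_trace(lambda Z: A @ Z, 50, num_probes, generator=g)
    print(f"{num_probes:>6} probes:", est.item())
```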

Questions

  1. I assume this runtime change is not expected, or am I missing something? Any idea what is going on?
  2. Beyond the runtime change, I am also seeing different MLL values. Should I trust the one produced purely by Cholesky (i.e. with gpytorch.settings.fast_computations(covar_root_decomposition=False, log_prob=False, solves=False)), or the one from the normal settings (with large trace_samples and preconditioner_size)?
  3. Does the answer to 2 depend on the case, or is it universal? I suspect some of our issues have to do with small noise values.

Full code for reproducibility

Sorry for the massive amount of code, but I knew this example specifically gives the error, so I tried to make it self-contained.

#%%
import gpytorch
import torch
import time

from gpytorch.kernels import ScaleKernel, MaternKernel
from gpytorch.means import Mean
from gpytorch.constraints import Interval

torch.set_default_tensor_type('torch.cuda.FloatTensor')

#%%
class MinMaxScaler:
    def __init__(self, mins = None, maxs = None):
        self.mins = mins
        self.maxs = maxs

    def fit(self, X):
        self.mins, _ = torch.min(X, dim=0)
        self.maxs, _ = torch.max(X, dim=0)

    def scale(self, X):
        X_scaled = (X - self.mins) / (self.maxs - self.mins)
        return X_scaled

    def unscale(self, X_scaled):
        X = X_scaled * (self.maxs - self.mins) + self.mins
        return X

#%%
ag_start = 50
ag_end = 84
yr_start = 1990
yr_end = 2019

seed = 0

torch.manual_seed(seed)

# generate age year grid

ags = torch.linspace(ag_start, ag_end, steps=ag_end - ag_start + 1).cuda()
yrs = torch.linspace(yr_start, yr_end, steps=yr_end - yr_start + 1).cuda()

x_dat = torch.cartesian_prod(ags, yrs).cuda()

cohorts = x_dat[:, 1] - x_dat[:, 0]
cohorts = cohorts.reshape(-1, 1)

sim_x = torch.cat((x_dat, cohorts), dim=-1)
sim_x = sim_x.cuda()

scaler = MinMaxScaler()
scaler.fit(sim_x)
sim_x_std = scaler.scale(sim_x)

print(sim_x.shape)  # 1050x3

#%%
covar_module = (
    ScaleKernel(MaternKernel(nu=0.5, active_dims = torch.tensor([0]))) +
    ScaleKernel(MaternKernel(nu=0.5, active_dims = torch.tensor([1])) *
                MaternKernel(nu=1.5, active_dims = torch.tensor([0]))) +
    ScaleKernel(MaternKernel(nu=1.5, active_dims = torch.tensor([2])))
)

#%%
class LinearMean_1d(Mean):
    def __init__(self, input_size, active_dim, batch_shape=torch.Size(), bias=True):
        super().__init__()
        self.active_dim = active_dim
        self.register_parameter(name="weights", parameter=torch.nn.Parameter(torch.randn(*batch_shape, input_size, 1).cuda()))
        if bias:
            self.register_parameter(name="bias", parameter=torch.nn.Parameter(torch.randn(*batch_shape, 1).cuda()))
        else:
            self.bias = None

    def forward(self, x):
        x = x.index_select(-1, self.active_dim)
        res = x.matmul(self.weights).squeeze(-1)
        if self.bias is not None:
            res = res + self.bias
        return res

class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, covar_module):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = LinearMean_1d(active_dim = torch.tensor([0]).cuda(), input_size = 1)
        self.covar_module = covar_module

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

#%%
# Simulate data
d = 3
init_x = torch.randn(0, d)
init_y = torch.randn(0)

likelihood = gpytorch.likelihoods.GaussianLikelihood(noise_constraint = Interval(1e-5, 1, initial_value = 1e-3))
model = ExactGPModel(init_x, init_y, likelihood, covar_module)

hypers = {
    'likelihood.noise_covar.noise': torch.tensor([0.0003]),
    'mean_module.weights': torch.tensor([[3.5]]).cuda(),
    'mean_module.bias': torch.tensor(-6.0).cuda()
}

#%%
model.covar_module.kernels[0].outputscale = 0.1
model.covar_module.kernels[0].base_kernel.lengthscale = 4.0
model.covar_module.kernels[1].outputscale = 0.2
model.covar_module.kernels[1].base_kernel.kernels[0].lengthscale = 4.5
model.covar_module.kernels[1].base_kernel.kernels[1].lengthscale = 1.0
model.covar_module.kernels[2].outputscale = 0.1
model.covar_module.kernels[2].base_kernel.lengthscale = 0.5

model.initialize(**hypers)

#%%
model.eval()
mvn_y = likelihood(model(sim_x_std))
with torch.no_grad():
    sim_y = mvn_y.rsample()
    sim_y = sim_y.cuda()

#%%
# Set for training
train_x_raw = sim_x
train_x = sim_x_std
train_y = sim_y

#%%
start_time = time.time()

likelihood_fit = gpytorch.likelihoods.GaussianLikelihood(noise_constraint = Interval(1e-5, 1, initial_value = 1e-3))
model_fit = ExactGPModel(train_x, train_y, likelihood_fit, covar_module)
model_fit.train()

hypers = {
    'mean_module.weights': torch.tensor([[3.4]]).cuda(),
    'mean_module.bias': torch.tensor(-5.0).cuda()
}
model_fit.initialize(**hypers)

mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood_fit, model_fit)
optimizer = torch.optim.Adam(model_fit.parameters(), lr=0.2)  # Includes GaussianLikelihood parameters of model_fit
optimizer.param_groups[0]['capturable'] = True

training_iter = 100
for i in range(training_iter):
    optimizer.zero_grad()
    output = model_fit(train_x)
    loss = -mll(output, train_y)
    loss.backward()
    optimizer.step()

print(time.time() - start_time)
print(loss)

#%%
start_time = time.time()

with gpytorch.settings.fast_computations(covar_root_decomposition=False, log_prob=False, solves=False):
    likelihood_fit = gpytorch.likelihoods.GaussianLikelihood(noise_constraint = Interval(1e-5, 1, initial_value = 1e-3))
    model_fit = ExactGPModel(train_x, train_y, likelihood_fit, covar_module)
    model_fit.train()

    hypers = {
        'mean_module.weights': torch.tensor([[3.4]]).cuda(),
        'mean_module.bias': torch.tensor(-5.0).cuda()
    }
    model_fit.initialize(**hypers)

    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood_fit, model_fit)
    optimizer = torch.optim.Adam(model_fit.parameters(), lr=0.2)  # Includes GaussianLikelihood parameters of model_fit
    optimizer.param_groups[0]['capturable'] = True

    training_iter = 100
    for i in range(training_iter):
        optimizer.zero_grad()
        output = model_fit(train_x)
        loss = -mll(output, train_y)
        loss.backward()
        optimizer.step()

print(time.time() - start_time)
print(loss)

#%%
print(-mll(output, train_y))

with gpytorch.settings.fast_computations(covar_root_decomposition=False, log_prob=False, solves=False):
    print(-mll(output, train_y))
#%%

Thanks!

gpleiss commented 2 years ago

A couple of thoughts:

All of the behavior that you report is somewhat expected. A 5x slowdown is a bit more than what I've seen in the past, but you are using a composite kernel, which generally slows down CG.

Once you scale up beyond n=1000, you will notice that CG starts becoming faster than Cholesky. See what happens when I bump up the data size in your experiment (and set the preconditioner size to 1000):

Data shape: torch.Size([7050, 3])
Time with fast computations off:  7.5905442237854
Cholesky loss (trial 1): -2.459498167037964
Time with fast computations on:  3.1095833778381348
Fast computations loss (trial 1): -2.460662841796875
Fast computations loss (trial 2): -2.4602770805358887
Fast computations loss (trial 3): -2.460109233856201
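The preconditioner size mentioned above refers, as far as I understand GPyTorch's max_preconditioner_size setting, to the rank of a partial pivoted Cholesky factor L of the kernel matrix. P = L L^T + σ²I is then used to precondition CG. Below is a minimal plain-PyTorch sketch of that idea. `pivoted_cholesky` and `make_preconditioner` are hypothetical helpers for illustration, and the RBF kernel and noise level are assumptions. GPyTorch's own implementation differs in its details.

```python
import torch


def pivoted_cholesky(K, rank):
    """Partial pivoted Cholesky: returns L (n x rank) with K approximately L @ L.T."""
    n = K.shape[0]
    diag = torch.diagonal(K).clone()  # diagonal of the residual K - L @ L.T
    L = torch.zeros(n, rank, dtype=K.dtype)
    perm = torch.arange(n)
    for m in range(rank):
        # Greedily pivot on the remaining index with the largest residual variance.
        i = m + int(torch.argmax(diag[perm[m:]]))
        perm[[m, i]] = perm[[i, m]]
        piv = int(perm[m])
        rest = perm[m + 1:]
        L[piv, m] = torch.sqrt(diag[piv])
        L[rest, m] = (K[rest, piv] - L[rest, :m] @ L[piv, :m]) / L[piv, m]
        diag[rest] -= L[rest, m] ** 2
    return L


def make_preconditioner(L, noise):
    """Return v -> P^{-1} v for P = L @ L.T + noise * I, via the Woodbury identity."""
    k = L.shape[1]
    # Only a small k x k matrix, noise * I + L^T L, is ever factored.
    inner = torch.linalg.cholesky(noise * torch.eye(k, dtype=L.dtype) + L.T @ L)

    def apply_inverse(v):  # v has shape (n, c)
        w = torch.cholesky_solve(L.T @ v, inner)
        return (v - L @ w) / noise

    return apply_inverse


g = torch.Generator().manual_seed(0)
x = torch.rand(100, 2, dtype=torch.float64, generator=g)
K = torch.exp(-0.5 * (x.unsqueeze(-2) - x.unsqueeze(-3)).pow(2).sum(-1) / 0.2**2)
noise = 0.1
for rank in (5, 10, 20):
    L = pivoted_cholesky(K, rank)
    print(rank, "residual trace:", torch.trace(K - L @ L.T).item())
```

With the Woodbury form, applying P^{-1} only requires factoring a rank × rank matrix. That keeps a larger preconditioner affordable as n grows.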

The difference between Cholesky and the CG-based approaches is due to 1) the stochasticity of the CG approach, and 2) the bias that's introduced by CG (see this paper). We find that these differences can be significant if you're running for very few CG iterations, but usually don't have a huge impact when you let CG run for a while.
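To illustrate the bias point: CG stopped after a fixed number of iterations returns an approximate solve, and that error carries over into the MLL. Below is a minimal, unpreconditioned CG in plain PyTorch that compares a few iterations against many. `conjugate_gradients` is a hypothetical helper, and the 1-D RBF test matrix is an assumption. GPyTorch's own CG is batched, preconditioned, and also tracks Lanczos coefficients.

```python
import torch


def conjugate_gradients(A, b, max_iter, rel_tol=1e-10):
    """Plain CG for a symmetric positive-definite A; stops after at most max_iter steps."""
    x = torch.zeros_like(b)
    r = b.clone()  # residual b - A x for the initial guess x = 0
    p = r.clone()
    rs_old = r @ r
    stop = (rel_tol * torch.linalg.norm(b)) ** 2
    for _ in range(max_iter):
        Ap = A @ p
        alpha = rs_old / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        if rs_new < stop:
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x


g = torch.Generator().manual_seed(0)
t = torch.linspace(0, 1, 50, dtype=torch.float64).unsqueeze(-1)
K = torch.exp(-0.5 * (t - t.T).pow(2) / 0.3**2)
A = K + 0.1 * torch.eye(50, dtype=torch.float64)
b = torch.randn(50, dtype=torch.float64, generator=g)
exact = torch.linalg.solve(A, b)
for iters in (2, 5, 10, 50):
    err = torch.linalg.norm(conjugate_gradients(A, b, iters) - exact) / torch.linalg.norm(exact)
    print(f"{iters:>3} CG iterations: relative error {err.item():.2e}")
```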

The data that you are simulating seem to be ill-conditioned and do not respond well to the default pivoted Cholesky preconditioner. I think this could be due to the low dimensionality and the small amount of observational noise.
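One way to see the effect of small observational noise is to compute the condition number of K + σ²I as σ² shrinks. The issue simulates with noise 3e-4 and allows values down to 1e-5. The sketch below assumes an RBF kernel with lengthscale 0.5 on a 35 × 30 grid rescaled to [0, 1]^2, shaped like the issue's age/year inputs. The issue itself uses a sum and product of Matérn kernels, so the exact numbers will differ. `rbf_kernel` and `condition_number` are hypothetical helpers.

```python
import torch


def rbf_kernel(x1, x2, lengthscale):
    # Squared-exponential kernel; the issue itself uses a sum/product of Matern kernels.
    sqdist = (x1.unsqueeze(-2) - x2.unsqueeze(-3)).pow(2).sum(-1)
    return torch.exp(-0.5 * sqdist / lengthscale**2)


def condition_number(A):
    # For a symmetric positive-definite matrix, cond(A) = largest / smallest eigenvalue.
    evals = torch.linalg.eigvalsh(A)  # eigenvalues in ascending order
    return (evals[-1] / evals[0]).item()


# A 35 x 30 grid rescaled to [0, 1]^2 (1050 points), like the age/year inputs.
ages = torch.linspace(0, 1, 35, dtype=torch.float64)
years = torch.linspace(0, 1, 30, dtype=torch.float64)
x = torch.cartesian_prod(ages, years)
K = rbf_kernel(x, x, lengthscale=0.5)
eye = torch.eye(x.shape[0], dtype=torch.float64)

for noise in (1e-1, 1e-2, 3e-4, 1e-5):
    print(f"noise={noise:g}: cond(K + noise*I) = {condition_number(K + noise * eye):.2e}")
```

Because a smooth kernel on a dense low-dimensional grid has many near-zero eigenvalues, the smallest eigenvalue of K + σ²I is roughly σ². As a result, the condition number grows roughly in proportion to 1/σ², and CG needs correspondingly more iterations.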

All in all, the CG-based inference code is among the best-tested and best-maintained parts of our codebase, so I don't think this behavior represents a bug. However, I do think we should do the following:

@jacobrgardner thoughts?

laurence-kobold commented 1 year ago

I've run into a similar issue. I'm seeing much faster predictions when fast_computations is turned off. Here's a minimal test case to reproduce the issue:

import gpytorch
import torch
import time

# Test data
torch.manual_seed(0)
train_x = torch.randn(4000, 2)
train_y = torch.randn(4000)
test_x = torch.randn(5100, 2)

# Construct model
class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ZeroMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
model = ExactGPModel(train_x, train_y, gpytorch.likelihoods.GaussianLikelihood())

# Train model
model.train()
model.likelihood.train()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model)
for i in range(50):
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()
    optimizer.step()
model.eval()

start_time = time.time()
preds = model.likelihood(model(test_x))
print(time.time() - start_time)

start_time = time.time()
with gpytorch.settings.fast_computations(solves=False):
    preds = model.likelihood(model(test_x))
print(time.time() - start_time)

With fast computations turned off, predictions are computed approximately 15 times faster. Like @jimmyrisk's example, this one uses a small number of input dimensions (2), but a larger number of training points (4000).
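This sketch gives some intuition for why the Cholesky path can be quick at prediction time. Once the training covariance has been factored, the predictive mean needs one cached solve, and the predictive variances need one triangular solve against the train/test cross-covariance. No iterative method is involved. The sketch below is the textbook exact-GP predictive in plain PyTorch. It is not GPyTorch's code path, and it is not the LOVE method that fast_pred_var enables. Assumptions: zero mean, an RBF kernel with unit outputscale, latent (noise-free) variances, and a hypothetical `CholeskyGPPredictor` class.

```python
import torch


def rbf_kernel(x1, x2, lengthscale=0.5):
    sqdist = (x1.unsqueeze(-2) - x2.unsqueeze(-3)).pow(2).sum(-1)
    return torch.exp(-0.5 * sqdist / lengthscale**2)


class CholeskyGPPredictor:
    """Exact zero-mean GP predictions that reuse one Cholesky factor of K + noise * I."""

    def __init__(self, train_x, train_y, noise, lengthscale=0.5):
        self.train_x = train_x
        self.lengthscale = lengthscale
        n = train_x.shape[0]
        A = rbf_kernel(train_x, train_x, lengthscale) + noise * torch.eye(n, dtype=train_x.dtype)
        self.L = torch.linalg.cholesky(A)  # O(n^3), done once
        self.alpha = torch.cholesky_solve(train_y.unsqueeze(-1), self.L)  # cached A^{-1} y

    def predict(self, test_x):
        Ks = rbf_kernel(self.train_x, test_x, self.lengthscale)  # n x m cross-covariance
        mean = (Ks.T @ self.alpha).squeeze(-1)
        V = torch.linalg.solve_triangular(self.L, Ks, upper=False)  # L^{-1} K_*
        var = 1.0 - V.pow(2).sum(dim=0)  # prior variance of this RBF kernel is 1
        return mean, var


g = torch.Generator().manual_seed(0)
train_x = torch.randn(300, 2, dtype=torch.float64, generator=g)
train_y = torch.randn(300, dtype=torch.float64, generator=g)
test_x = torch.randn(50, 2, dtype=torch.float64, generator=g)
predictor = CholeskyGPPredictor(train_x, train_y, noise=0.1)
mean, var = predictor.predict(test_x)
```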

gpleiss commented 1 year ago

Hmm okay maybe we should re-think some of our conditions for when we switch to using CG.

gpleiss commented 1 year ago

I'm going to open up another issue to document this.

gpleiss commented 1 year ago

Ahhh, @laurence-kobold, part of the problem is that you are not using the fast_pred_var context manager (when fast solves is off). Using it makes the two times much more comparable, but we should still adjust our internal logic for when we use Cholesky versus when we use iterative methods.