cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch
MIT License

Unable to achieve convergence on both GPU and CPU #2407

Open Rashfu opened 1 year ago

Rashfu commented 1 year ago

Question Description

I am attempting to fit an exact GP regression on a dataset of ~10,000 points. train_x has shape 3×7740×2 (the base 7740×2 input repeated along the batch dimension) and train_y has shape 3×7740, where 3 is the batch size. The inputs are 2-D planar XY positions with real values between -18 and -14. The outputs are normalized RGB colors (3 channels) with values between 0 and 1. The three regressions from XY to R, G and B are independent of each other.

I followed the Batch GP Regression tutorial.

Training on the CPU: the code runs without errors, but it fails to converge and slows down as training proceeds. However, when I multiply train_x by 100, the batch GP converges quickly and performs well.

Training on the GPU: the loss becomes NaN and NumericalWarning: CG terminated is raised. Multiplying train_x by 100 or normalizing the inputs with Min-Max scaling did not help. With the data and model in double precision the NaN loss disappears, but training becomes very slow (as expected for double precision) and still does not converge to the correct solution.
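A plausible explanation for why multiplying the inputs by 100 helps on the CPU (a sketch, not a confirmed diagnosis): with an RBF kernel, the conditioning of K + σ²I depends on how far apart the inputs are relative to the lengthscale. The snippet below draws hypothetical points uniformly from [-18, -14]² to mimic the described inputs, assumes a lengthscale of ln 2 ≈ 0.693 (believed to match GPyTorch's softplus(0) initialization) and a noise of 0.01, and compares the condition number of K + σ²I before and after scaling. Only PyTorch is required; the helper names are hypothetical.

```python
import math

import torch


def rbf_gram(x: torch.Tensor, lengthscale: float) -> torch.Tensor:
    """RBF (squared-exponential) Gram matrix exp(-||x_i - x_j||^2 / (2 l^2))."""
    sq_dist = torch.cdist(x, x).pow(2)
    return torch.exp(-0.5 * sq_dist / lengthscale**2)


def noisy_gram_condition(x: torch.Tensor, lengthscale: float, noise: float) -> float:
    """Condition number of K + noise * I, the matrix the linear solves work against."""
    k_hat = rbf_gram(x, lengthscale) + noise * torch.eye(x.shape[0], dtype=x.dtype)
    return torch.linalg.cond(k_hat).item()


gen = torch.Generator().manual_seed(0)
# Hypothetical stand-in for the real data: 500 points uniform in [-18, -14]^2.
x = -18.0 + 4.0 * torch.rand(500, 2, generator=gen, dtype=torch.float64)

lengthscale = math.log(2.0)  # assumed initial lengthscale, softplus(0) ~= 0.693
noise = 1e-2                 # assumed, similar to the noise levels reached before the NaNs

cond_raw = noisy_gram_condition(x, lengthscale, noise)
cond_x100 = noisy_gram_condition(100.0 * x, lengthscale, noise)
print(f"raw inputs:   cond = {cond_raw:.3e}")
print(f"inputs * 100: cond = {cond_x100:.3e}")
```

On this synthetic data the raw inputs should give a far larger condition number: after scaling, almost every pair of points is many lengthscales apart, so K is close to the identity and easy for CG. The flip side is that, at that lengthscale, the kernel sees almost no correlation between neighbouring points, so a well-conditioned start does not by itself mean a good fit.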

Is there an issue with my input data? This looks like numerical instability. I suspect something goes wrong when computing the log likelihood here:

preconditioner, precond_lt, logdet_p = self._preconditioner()
if precond_lt is None:
    from ..operators.identity_linear_operator import IdentityLinearOperator

    precond_lt = IdentityLinearOperator(
        diag_shape=self.size(-1),
        batch_shape=self.batch_shape,
        dtype=self.dtype,
        device=self.device,
    )
    logdet_p = 0.0

precond_args = precond_lt.representation()
probe_vectors, probe_vector_norms = self._probe_vectors_and_norms()

func = InvQuadLogdet.apply
inv_quad_term, pinvk_logdet = func(
    self.representation_tree(),
    precond_lt.representation_tree(),
    preconditioner,
    len(precond_args),
    (inv_quad_rhs is not None),
    probe_vectors,
    probe_vector_norms,
    *(list(args) + list(precond_args)),
)
logdet_term = pinvk_logdet
logdet_term = logdet_term + logdet_p
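For reference, the quantity assembled above is the exact GP log marginal likelihood, log p(y | X) = -½ [(y - μ)ᵀ K̂⁻¹ (y - μ) + log|K̂| + n log 2π] with K̂ = K + σ²I; inv_quad_term and logdet_term correspond to the first two terms. For large matrices GPyTorch evaluates them iteratively with conjugate gradients and random probe vectors (hence probe_vectors above and the CG warnings below). The sketch below computes the same terms directly with a dense Cholesky factorization, which can serve as a cross-check on a small subset of the data. The function name and the toy inputs are hypothetical.

```python
import math

import torch


def exact_gp_log_marginal(y, mean, covar, noise):
    """log N(y | mean, covar + noise * I), computed with a dense Cholesky factor."""
    n = y.shape[-1]
    k_hat = covar + noise * torch.eye(n, dtype=covar.dtype)
    chol = torch.linalg.cholesky(k_hat)                    # k_hat = L L^T
    diff = (y - mean).unsqueeze(-1)                        # column vector y - mean
    alpha = torch.cholesky_solve(diff, chol)               # k_hat^{-1} (y - mean)
    inv_quad = (diff * alpha).sum()                        # (y - mean)^T k_hat^{-1} (y - mean)
    logdet = 2.0 * torch.log(torch.diagonal(chol)).sum()   # log|k_hat|
    return -0.5 * (inv_quad + logdet + n * math.log(2.0 * math.pi))


# Toy data (hypothetical): RBF kernel with lengthscale 0.3 on 50 random 2-D points.
gen = torch.Generator().manual_seed(0)
x = torch.rand(50, 2, generator=gen, dtype=torch.float64)
covar = torch.exp(-0.5 * torch.cdist(x, x).pow(2) / 0.3**2)
y = torch.rand(50, generator=gen, dtype=torch.float64)
mean = torch.full((50,), 0.5, dtype=torch.float64)

ours = exact_gp_log_marginal(y, mean, covar, noise=0.1)
reference = torch.distributions.MultivariateNormal(
    mean, covariance_matrix=covar + 0.1 * torch.eye(50, dtype=torch.float64)
).log_prob(y)
print(ours.item(), reference.item())
```

GPyTorch's ExactMarginalLogLikelihood additionally divides this value by the number of training points (the "Scale by the amount of data we have" step in the traceback), so its loss is on a per-point scale.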

The data can be downloaded from the attached .zip file. data.zip

Thanks in advance!

Here are the details about the data, code and error on GPU.

Data Example

train_x                      train_y
tensor([-15.0223, -14.0026]) tensor([0.5451, 0.2588, 0.3765])
tensor([-16.1318, -14.1548]) tensor([0.5882, 0.3686, 0.4667])
tensor([-16.7716, -14.5253]) tensor([0.6078, 0.3882, 0.4863])
tensor([-15.9107, -14.8165]) tensor([0.5647, 0.3294, 0.4314])
tensor([-14.9211, -15.1249]) tensor([0.6784, 0.4431, 0.5412])
tensor([-14.0937, -15.4103]) tensor([0.6392, 0.3333, 0.3882])
tensor([-17.4703, -15.1177]) tensor([0.6549, 0.4275, 0.5216])
tensor([-15.7863, -17.1730]) tensor([0.8549, 0.5882, 0.6863])
tensor([-15.3097, -17.5722]) tensor([0.8353, 0.5686, 0.6510])
tensor([-14.6760, -17.8014]) tensor([0.8392, 0.5333, 0.6196])

code

....
# train data torch.Size([3, 7740, 2]) torch.Size([3, 7740]) 
# batch shape 3
batch_shape = train_y.shape[-1]
train_x = train_x.unsqueeze(0).repeat(batch_shape, 1, 1).to(device)
train_y = train_y.transpose(0, 1).to(device)

class BatchGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_inputs, train_targets, likelihood, batch_shape, use_ard=False):
        super(BatchGPModel, self).__init__(train_inputs, train_targets, likelihood)

        ard_num_dims = train_inputs.shape[-1] if use_ard else None

        self.shape = torch.Size([batch_shape])
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=self.shape, constant_constraint=gpytorch.constraints.Interval(0.0, 1.0))
        self.base_kernel = gpytorch.kernels.RBFKernel(batch_shape=self.shape, ard_num_dims=ard_num_dims)
        self.covar_module = gpytorch.kernels.ScaleKernel(self.base_kernel, batch_shape=self.shape)

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

likelihood = gpytorch.likelihoods.GaussianLikelihood(batch_shape=torch.Size([batch_shape])).to(device)
model = BatchGPModel(train_x, train_y, likelihood, batch_shape=batch_shape, use_ard=True).to(device)

training_iter = 100
model.train()
likelihood.train()

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

# Loss for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

for i in range(training_iter):

    optimizer.zero_grad()

    output = model(train_x)

    loss = -mll(output, train_y).sum()

    loss.backward()

    optimizer.step()

    print('Iter %d/%d - Loss: %.3f  mean0: %.3f  mean1: %.3f  mean2: %.3f  noise0: %.3f  noise1: %.3f  noise2: %.3f  ' % (
        i + 1, training_iter, loss.item(),
        model.mean_module.constant[0].item(),
        model.mean_module.constant[1].item(),
        model.mean_module.constant[2].item(),
        model.likelihood.noise[0].item(),
        model.likelihood.noise[1].item(),
        model.likelihood.noise[2].item()
    ))

Error Message

Iter 1/50 - Loss: 2.251  mean0: 0.574  mean1: 0.426  mean2: 0.574  noise0: 0.554  noise1: 0.554  noise2: 0.554  
Iter 2/50 - Loss: 1.910  mean0: 0.643  mean1: 0.415  mean2: 0.583  noise0: 0.437  noise1: 0.437  noise2: 0.437  
Iter 3/50 - Loss: 1.553  mean0: 0.700  mean1: 0.436  mean2: 0.559  noise0: 0.341  noise1: 0.341  noise2: 0.341  
Iter 4/50 - Loss: 1.179  mean0: 0.744  mean1: 0.469  mean2: 0.529  noise0: 0.263  noise1: 0.263  noise2: 0.263  
Iter 5/50 - Loss: 0.792  mean0: 0.775  mean1: 0.502  mean2: 0.515  noise0: 0.201  noise1: 0.201  noise2: 0.201  
Iter 6/50 - Loss: 0.395  mean0: 0.793  mean1: 0.517  mean2: 0.524  noise0: 0.152  noise1: 0.152  noise2: 0.152  
Iter 7/50 - Loss: nan  mean0: 0.799  mean1: 0.506  mean2: 0.544  noise0: 0.114  noise1: 0.114  noise2: 0.114  
Iter 8/50 - Loss: nan  mean0: 0.795  mean1: 0.483  mean2: 0.559  noise0: 0.086  noise1: 0.086  noise2: 0.086  
Iter 9/50 - Loss: nan  mean0: 0.780  mean1: 0.462  mean2: 0.555  noise0: 0.064  noise1: 0.064  noise2: 0.064  
Iter 10/50 - Loss: nan  mean0: 0.756  mean1: 0.461  mean2: 0.538  noise0: 0.047  noise1: 0.048  noise2: 0.047  
Iter 11/50 - Loss: nan  mean0: 0.723  mean1: 0.481  mean2: 0.522  noise0: 0.035  noise1: 0.035  noise2: 0.035  
Iter 12/50 - Loss: nan  mean0: 0.681  mean1: 0.511  mean2: 0.525  noise0: 0.026  noise1: 0.026  noise2: 0.026  
Iter 13/50 - Loss: nan  mean0: 0.633  mean1: 0.532  mean2: 0.546  noise0: 0.019  noise1: 0.019  noise2: 0.019  
Iter 14/50 - Loss: nan  mean0: 0.586  mean1: 0.528  mean2: 0.570  noise0: 0.014  noise1: 0.014  noise2: 0.014  
Iter 15/50 - Loss: nan  mean0: 0.551  mean1: 0.510  mean2: 0.578  noise0: 0.011  noise1: 0.011  noise2: 0.011  
Iter 16/50 - Loss: nan  mean0: 0.534  mean1: 0.504  mean2: 0.565  noise0: 0.008  noise1: 0.008  noise2: 0.008  
Iter 17/50 - Loss: nan  mean0: 0.537  mean1: 0.498  mean2: 0.571  noise0: 0.006  noise1: 0.006  noise2: 0.006  
/home/dell/anaconda3/envs/gpytorch/lib/python3.8/site-packages/linear_operator/utils/linear_cg.py:337: NumericalWarning: CG terminated in 1000 iterations with average residual norm 783129706496.0 which is larger than the tolerance of 1 specified by linear_operator.settings.cg_tolerance. If performance is affected, consider raising the maximum number of CG iterations by running code in a linear_operator.settings.max_cg_iterations(value) context.
  warnings.warn(
Iter 18/50 - Loss: nan  mean0: 0.553  mean1: 0.536  mean2: 0.567  noise0: 0.005  noise1: 0.007  noise2: 0.005  
/home/dell/anaconda3/envs/gpytorch/lib/python3.8/site-packages/linear_operator/utils/linear_cg.py:337: NumericalWarning: CG terminated in 1000 iterations with average residual norm nan which is larger than the tolerance of 1 specified by linear_operator.settings.cg_tolerance. If performance is affected, consider raising the maximum number of CG iterations by running code in a linear_operator.settings.max_cg_iterations(value) context.
  warnings.warn(
/home/dell/anaconda3/envs/gpytorch/lib/python3.8/site-packages/linear_operator/operators/added_diag_linear_operator.py:128: NumericalWarning: NaNs encountered in preconditioner computation. Attempting to continue without preconditioning.
  warnings.warn(
Iter 19/50 - Loss: nan  mean0: 0.590  mean1: 0.500  mean2: nan  noise0: 0.005  noise1: 0.009  noise2: nan  
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[8], line 21
     17 optimizer.zero_grad()
     19 output = model(train_x)
---> 21 loss = -mll(output, train_y).sum()
     23 loss.backward()
     25 optimizer.step()

File ~/anaconda3/envs/gpytorch/lib/python3.8/site-packages/gpytorch/module.py:31, in Module.__call__(self, *inputs, **kwargs)
     30 def __call__(self, *inputs, **kwargs) -> Union[Tensor, Distribution, LinearOperator]:
---> 31     outputs = self.forward(*inputs, **kwargs)
     32     if isinstance(outputs, list):
     33         return [_validate_module_outputs(output) for output in outputs]

File ~/anaconda3/envs/gpytorch/lib/python3.8/site-packages/gpytorch/mlls/exact_marginal_log_likelihood.py:64, in ExactMarginalLogLikelihood.forward(self, function_dist, target, *params)
     62 # Get the log prob of the marginal distribution
     63 output = self.likelihood(function_dist, *params) # input prior: p(f|X) and output: p(y|X)
---> 64 res = output.log_prob(target) # log p(y|X)
     65 res = self._add_other_terms(res, params)
     67 # Scale by the amount of data we have

File ~/anaconda3/envs/gpytorch/lib/python3.8/site-packages/gpytorch/distributions/multivariate_normal.py:195, in MultivariateNormal.log_prob(self, value)
    191 # Get log determininant and first part of quadratic form
    192 # inv_quad = (K+\sigma^2 I)^{-1}
...
    201 # Sometime we're lucky and the preconditioner solves the system right away
    202 # Check for convergence
    203 residual_norm = residual.norm(2, dim=-2, keepdim=True)

RuntimeError: NaNs encountered when trying to perform matrix-vector multiplication
gpleiss commented 1 year ago

Have you tried using a smaller learning rate? I don't think that this is a bug in GPyTorch, since we are using completely identical code for CPU and GPU.

Rashfu commented 1 year ago

Thank you for your reply.

I have tried smaller learning rates, decreasing gradually from 0.3 to 0.01, but the loss still becomes NaN. And as the learning rate decreases, the model fails to learn anything.

I have simplified the example. The CPU and GPU runs now use completely identical code; the only addition is:

if torch.cuda.is_available():
    train_x = train_x.cuda()
    train_y = train_y.cuda()
    model = model.cuda()
    likelihood = likelihood.cuda()

The code runs fine on the CPU, but on the GPU it raises NumericalWarning: CG terminated. Doesn't this mean that GPyTorch has a bug?

This is a simple sample file. I hope you have time to take a look at it. test.zip

Many thanks!

gpleiss commented 1 year ago

I'm not going to open up your zip sample file. If you can post a small reproducible example in the chat here, then I will take a look.

Rashfu commented 1 year ago

Here is a simple example (using a GP for super-resolution). I use the pixel coordinates XY of the image as train_x and the RGB values as train_y. Everything works fine when I remove the .cuda() calls.

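import gpytorch
import numpy as np
import torch
from torchvision import transforms

# ori_image: a 60 × 60 RGB PIL image (loading code not shown)
image_tensor = transforms.ToTensor()(ori_image)
image_tensor = image_tensor.unsqueeze(0)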

b, _, h, w = image_tensor.shape
x = np.arange(w)*2
y = np.arange(h)*2
X, Y = np.meshgrid(x, y)
sample_x = torch.from_numpy(np.stack([X, Y], axis=-1).reshape(-1, 2))
sample_img = image_tensor.squeeze(0).reshape(3, -1).transpose(0, 1)

batch_shape = sample_img.shape[-1]
train_x = sample_x.unsqueeze(0).repeat((batch_shape, 1, 1))
train_y = sample_img.transpose(0, 1)
# [3, 3600, 2], [3, 3600]
print('train_x shape:', train_x.shape, 'train_y shape:', train_y.shape)

class BatchGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_inputs, train_targets, likelihood, batch_shape, use_ard=False):
        super(BatchGPModel, self).__init__(train_inputs, train_targets, likelihood)

        ard_num_dims = train_inputs.shape[-1] if use_ard else None

        self.shape = torch.Size([batch_shape])
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=self.shape, constant_constraint=gpytorch.constraints.Interval(0.0, 1.0))
        self.base_kernel = gpytorch.kernels.RBFKernel(batch_shape=self.shape, ard_num_dims=ard_num_dims)
        self.covar_module = gpytorch.kernels.ScaleKernel(self.base_kernel, batch_shape=self.shape)

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

# initialize the likelihood and prior, batch shape depends on the dimension of y (e.g. RGB image has 3 channels)
likelihood = gpytorch.likelihoods.GaussianLikelihood(batch_shape=torch.Size([batch_shape]))

model = BatchGPModel(train_x, train_y, likelihood, batch_shape=batch_shape, use_ard=True)

if torch.cuda.is_available():
    train_x = train_x.cuda()
    train_y = train_y.cuda()
    model = model.cuda()
    likelihood = likelihood.cuda()

# Find optimal model hyperparameters
model.train()
likelihood.train()
# Use the adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

for i in range(50):

    optimizer.zero_grad()

    output = model(train_x)

    loss = -mll(output, train_y).sum()
    loss.backward()

    optimizer.step()

    print('Iter %d/%d - Loss: %.3f  mean0: %.3f mean1: %.3f  mean2: %.3f noise0: %.3f  noise1: %.3f  noise2: %.3f' % (
        i + 1, 50, loss.item(),
        model.mean_module.constant[0].item(),
        model.mean_module.constant[1].item(),
        model.mean_module.constant[2].item(),
        model.likelihood.noise[0].item(),
        model.likelihood.noise[1].item(),
        model.likelihood.noise[2].item()
    ))
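A possible adjustment to the preprocessing above (a sketch, not tested against the issue's data): np.arange(w) * 2 produces integer coordinates 0, 2, ..., 118, so torch.from_numpy returns an integer tensor (int64 on Linux). The hypothetical helper pixel_grid below builds the same row-major (x, y) grid directly as floats rescaled to [0, 1], which makes the dtype and input range explicit. Whether [0, 1] is a better scale for this problem is not established in this thread. The point order matches image_tensor.reshape(3, -1), because torch.meshgrid with indexing="ij" over (rows, columns) reproduces the layout of np.meshgrid(x, y).

```python
import torch


def pixel_grid(h: int, w: int, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """(h*w, 2) tensor of (x, y) pixel coordinates rescaled to [0, 1], in row-major order."""
    ys = torch.linspace(0.0, 1.0, h, dtype=dtype)  # one value per image row
    xs = torch.linspace(0.0, 1.0, w, dtype=dtype)  # one value per image column
    # indexing="ij" gives (h, w) grids where Y varies down the rows and X across the
    # columns, the same layout as np.meshgrid(x, y) in the code above.
    Y, X = torch.meshgrid(ys, xs, indexing="ij")
    return torch.stack([X, Y], dim=-1).reshape(-1, 2)


grid = pixel_grid(60, 60)
batch_shape = 3
train_x = grid.unsqueeze(0).repeat(batch_shape, 1, 1)  # (3, 3600, 2), as in the example above
print(train_x.shape, train_x.dtype)
```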

With the .cuda() calls in place, I get the following warnings:

/home/dell/anaconda3/envs/gpytorch/lib/python3.8/site-packages/linear_operator/utils/linear_cg.py:338: NumericalWarning: CG terminated in 1000 iterations with average residual norm 81.5545425415039 which is larger than the tolerance of 1 specified by linear_operator.settings.cg_tolerance. If performance is affected, consider raising the maximum number of CG iterations by running code in a linear_operator.settings.max_cg_iterations(value) context.
  warnings.warn(
Iter 1/50 - Loss: 2.826  mean0: 0.525 mean1: 0.475  mean2: 0.475 noise0: 0.644  noise1: 0.744  noise2: 0.744
/home/dell/anaconda3/envs/gpytorch/lib/python3.8/site-packages/linear_operator/utils/linear_cg.py:338: NumericalWarning: CG terminated in 1000 iterations with average residual norm 14.917732238769531 which is larger than the tolerance of 1 specified by linear_operator.settings.cg_tolerance. If performance is affected, consider raising the maximum number of CG iterations by running code in a linear_operator.settings.max_cg_iterations(value) context.
  warnings.warn(
Iter 2/50 - Loss: 2.931  mean0: 0.550 mean1: 0.469  mean2: 0.466 noise0: 0.607  noise1: 0.779  noise2: 0.780
/home/dell/anaconda3/envs/gpytorch/lib/python3.8/site-packages/linear_operator/utils/linear_cg.py:338: NumericalWarning: CG terminated in 1000 iterations with average residual norm 89.217041015625 which is larger than the tolerance of 1 specified by linear_operator.settings.cg_tolerance. If performance is affected, consider raising the maximum number of CG iterations by running code in a linear_operator.settings.max_cg_iterations(value) context.
  warnings.warn(
Iter 3/50 - Loss: 2.931  mean0: 0.574 mean1: 0.475  mean2: 0.465 noise0: 0.581  noise1: 0.772  noise2: 0.809
/home/dell/anaconda3/envs/gpytorch/lib/python3.8/site-packages/linear_operator/utils/linear_cg.py:338: NumericalWarning: CG terminated in 1000 iterations with average residual norm 456.7450866699219 which is larger than the tolerance of 1 specified by linear_operator.settings.cg_tolerance. If performance is affected, consider raising the maximum number of CG iterations by running code in a linear_operator.settings.max_cg_iterations(value) context.
  warnings.warn(
Iter 4/50 - Loss: 2.898  mean0: 0.598 mean1: 0.486  mean2: 0.469 noise0: 0.553  noise1: 0.768  noise2: 0.832
/home/dell/anaconda3/envs/gpytorch/lib/python3.8/site-packages/linear_operator/utils/linear_cg.py:338: NumericalWarning: CG terminated in 1000 iterations with average residual norm 36.63666915893555 which is larger than the tolerance of 1 specified by linear_operator.settings.cg_tolerance. If performance is affected, consider raising the maximum number of CG iterations by running code in a linear_operator.settings.max_cg_iterations(value) context.
  warnings.warn(
Iter 5/50 - Loss: 2.851  mean0: 0.621 mean1: 0.497  mean2: 0.476 noise0: 0.527  noise1: 0.765  noise2: 0.853