cornellius-gp / gpytorch

A highly efficient implementation of Gaussian Processes in PyTorch
MIT License

[Bug] Exact GP Regression with Multiple GPUs and Kernel Partitioning #2440

Open XiankangTang opened 10 months ago

XiankangTang commented 10 months ago

🐛 Bug

I ran the multi-GPU example notebook (https://github.com/cornellius-gp/gpytorch/blob/master/examples/02_Scalable_Exact_GPs/Simple_MultiGPU_GP_Regression.ipynb) on one A100 and three RTX 3090s, and training fails with the RuntimeError below.

To reproduce

Code snippet to reproduce

n_devices = 4
print('Planning to run on {} GPUs.'.format(n_devices))

Stack trace/error message

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[8], line 1
----> 1 model, likelihood = train(train_x, train_y,
      2                           n_devices=n_devices, output_device=output_device,
      3                           preconditioner_size=100,
      4                           n_training_iter=20)

Cell In[7], line 43, in train(train_x, train_y, n_devices, output_device, preconditioner_size, n_training_iter)
     40     loss = -mll(output, train_y)
     41     return loss
---> 43 loss = closure()
     44 loss.backward()
     46 for i in range(n_training_iter):

Cell In[7], line 40, in train.<locals>.closure()
     38 output = model(train_x)
     39 print(output)
---> 40 loss = -mll(output, train_y)
     41 return loss

File ~/anaconda3/envs/drprobe/lib/python3.11/site-packages/gpytorch/module.py:31, in Module.__call__(self, *inputs, **kwargs)
     30 def __call__(self, *inputs, **kwargs) -> Union[Tensor, Distribution, LinearOperator]:
---> 31     outputs = self.forward(*inputs, **kwargs)
     32     if isinstance(outputs, list):
     33         return [_validate_module_outputs(output) for output in outputs]

File ~/anaconda3/envs/drprobe/lib/python3.11/site-packages/gpytorch/mlls/exact_marginal_log_likelihood.py:64, in ExactMarginalLogLikelihood.forward(self, function_dist, target, *params)
     62 # Get the log prob of the marginal distribution
     63 output = self.likelihood(function_dist, *params)
---> 64 res = output.log_prob(target)
     65 res = self._add_other_terms(res, params)
     67 # Scale by the amount of data we have

File ~/anaconda3/envs/drprobe/lib/python3.11/site-packages/gpytorch/distributions/multivariate_normal.py:193, in MultivariateNormal.log_prob(self, value)
    191 # Get log determininant and first part of quadratic form
    192 covar = covar.evaluate_kernel()
--> 193 inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=diff.unsqueeze(-1), logdet=True)
    195 res = -0.5 * sum([inv_quad, logdet, diff.size(-1) * math.log(2 * math.pi)])
    196 return res

File ~/anaconda3/envs/drprobe/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py:1748, in LinearOperator.inv_quad_logdet(self, inv_quad_rhs, logdet, reduce_inv_quad)
   1745 if inv_quad_rhs is not None:
   1746     args = [inv_quad_rhs] + list(args)
-> 1748 preconditioner, precond_lt, logdet_p = self._preconditioner()
   1749 if precond_lt is None:
   1750     from ..operators.identity_linear_operator import IdentityLinearOperator

File ~/anaconda3/envs/drprobe/lib/python3.11/site-packages/linear_operator/operators/added_diag_linear_operator.py:126, in AddedDiagLinearOperator._preconditioner(self)
    124 if self._q_cache is None:
    125     max_iter = settings.max_preconditioner_size.value()
--> 126     self._piv_chol_self = self._linear_op.pivoted_cholesky(rank=max_iter)
    127     if torch.any(torch.isnan(self._piv_chol_self)).item():
    128         warnings.warn(
    129             "NaNs encountered in preconditioner computation. Attempting to continue without preconditioning.",
    130             NumericalWarning,
    131         )

File ~/anaconda3/envs/drprobe/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py:1965, in LinearOperator.pivoted_cholesky(self, rank, error_tol, return_pivots)
   1944 r"""
   1945 Performs a partial pivoted Cholesky factorization of the (positive definite) LinearOperator.
   1946 :math:`\mathbf L \mathbf L^\top = \mathbf K`.
   (...)
   1962     https://www.sciencedirect.com/science/article/pii/S0168927411001814
   1963 """
   1964 func = PivotedCholesky.apply
-> 1965 res, pivots = func(self.representation_tree(), rank, error_tol, *self.representation())
   1967 if return_pivots:
   1968     return res, pivots

File ~/anaconda3/envs/drprobe/lib/python3.11/site-packages/torch/autograd/function.py:506, in Function.apply(cls, *args, **kwargs)
    503 if not torch._C._are_functorch_transforms_active():
    504     # See NOTE: [functorch vjp and autograd interaction]
    505     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 506     return super().apply(*args, **kwargs)  # type: ignore[misc]
    508 if cls.setup_context == _SingleLevelFunction.setup_context:
    509     raise RuntimeError(
    510         'In order to use an autograd.Function with functorch transforms '
    511         '(vmap, grad, jvp, jacrev, ...), it must override the setup_context '
    512         'staticmethod. For more details, please see '
    513         'https://pytorch.org/docs/master/notes/extending.func.html')

File ~/anaconda3/envs/drprobe/lib/python3.11/site-packages/linear_operator/functions/_pivoted_cholesky.py:78, in PivotedCholesky.forward(ctx, representation_tree, max_iter, error_tol, *matrix_args)
     75 # Populater L[... m:, m] with L[..., m:, m] * L[..., m, m].sqrt()
     76 if m + 1 < matrix_shape[-1]:
     77     # Get next row of the permuted matrix
---> 78     row = apply_permutation(matrix, pi_m.unsqueeze(-1), right_permutation=None).squeeze(-2)
     79     pi_i = permutation[..., m + 1 :].contiguous()
     81     L_m_new = row.gather(-1, pi_i)

File ~/anaconda3/envs/drprobe/lib/python3.11/site-packages/linear_operator/utils/permutation.py:80, in apply_permutation(matrix, left_permutation, right_permutation)
     76     right_permutation = torch.arange(matrix.size(-1), device=matrix.device)
     78 # Apply permutations
     79 return to_dense(
---> 80     matrix.__getitem__(
     81         (
     82             *batch_idx,
     83             left_permutation.unsqueeze(-1),
     84             right_permutation.unsqueeze(-2),
     85         )
     86     )
     87 )

File ~/anaconda3/envs/drprobe/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py:2847, in LinearOperator.__getitem__(self, index)
   2841 # Convert all indices into tensor indices
   2842 (
   2843     *new_batch_indices,
   2844     new_row_index,
   2845     new_col_index,
   2846 ) = _convert_indices_to_tensors(self, flattened_orig_indices)
-> 2847 res = self._get_indices(new_row_index, new_col_index, *new_batch_indices)
   2848 # Now un-flatten tensor indices
   2849 if len(tensor_index_shape) > 1:  # Do we need to unflatten?

File ~/anaconda3/envs/drprobe/lib/python3.11/site-packages/linear_operator/operators/cat_linear_operator.py:210, in CatLinearOperator._get_indices(self, row_index, col_index, *batch_indices)
    207 for linear_op_idx, sub_index in zip(linear_op_indices, sub_indices):
    208     sub_index[self.cat_dim] = sub_index[self.cat_dim] - self.cat_dim_cum_sizes[linear_op_idx]
--> 210 res_list = [
    211     linear_op._get_indices(sub_index[-2], sub_index[-1], *sub_index[:-2])
    212     for linear_op, sub_index in zip(linear_ops, sub_indices)
    213 ]
    214 if len(res_list) == 1:
    215     return res_list[0].view(target_shape).to(self.device)

File ~/anaconda3/envs/drprobe/lib/python3.11/site-packages/linear_operator/operators/cat_linear_operator.py:211, in <listcomp>(.0)
    207 for linear_op_idx, sub_index in zip(linear_op_indices, sub_indices):
    208     sub_index[self.cat_dim] = sub_index[self.cat_dim] - self.cat_dim_cum_sizes[linear_op_idx]
    210 res_list = [
--> 211     linear_op._get_indices(sub_index[-2], sub_index[-1], *sub_index[:-2])
    212     for linear_op, sub_index in zip(linear_ops, sub_indices)
    213 ]
    214 if len(res_list) == 1:
    215     return res_list[0].view(target_shape).to(self.device)

File ~/anaconda3/envs/drprobe/lib/python3.11/site-packages/linear_operator/operators/dense_linear_operator.py:50, in DenseLinearOperator._get_indices(self, row_index, col_index, *batch_indices)
     48 def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> torch.Tensor:
     49     # Perform the __getitem__
---> 50     res = self.tensor[(*batch_indices, row_index, col_index)]
     51     return res

RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:1)
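The last frames suggest where the device mix-up may happen: the kernel matrix is held as a `CatLinearOperator` whose blocks live on different GPUs, and `CatLinearOperator._get_indices` shifts the global index tensors into block-local indices and passes them to each block's `DenseLinearOperator` without moving them to that block's device. The block on `cuda:1` then presumably receives indices that still live on the output device. The toy, stdlib-only model below mimics that logic under this assumption. `Block`, `cat_get_indices` and the `move_to_block_device` switch are hypothetical names, not GPyTorch or linear_operator API, and the device strings are only labels.

```python
from dataclasses import dataclass
from itertools import accumulate


@dataclass
class Block:
    """Hypothetical stand-in for one kernel block held on a single GPU."""
    device: str  # a label such as "cuda:1"; nothing is actually placed on a GPU
    rows: list   # dense block stored as a list of row lists

    def get_indices(self, row_idx, col_idx, idx_device):
        # Mimic PyTorch's rule behind the error: index tensors must be on the CPU
        # or on the same device as the tensor being indexed.
        if idx_device not in ("cpu", self.device):
            raise RuntimeError(
                "indices should be either on cpu or on the same device "
                f"as the indexed tensor ({self.device})"
            )
        return [self.rows[r][c] for r, c in zip(row_idx, col_idx)]


def cat_get_indices(blocks, row_idx, col_idx, idx_device, move_to_block_device=False):
    """Gather entries from blocks concatenated along the row dimension.

    Global row indices are shifted into block-local ones by subtracting each
    block's starting offset. With move_to_block_device=False the indices keep
    idx_device (the suspected failure); with True they are "moved" to each
    block's device first (a hypothetical fix).
    """
    sizes = [len(block.rows) for block in blocks]
    starts = [0, *accumulate(sizes)][:-1]
    result = [None] * len(row_idx)
    for block, start, size in zip(blocks, starts, sizes):
        positions = [i for i, r in enumerate(row_idx) if start <= r < start + size]
        if not positions:
            continue  # this block is not touched by the requested indices
        local_rows = [row_idx[i] - start for i in positions]
        local_cols = [col_idx[i] for i in positions]
        device = block.device if move_to_block_device else idx_device
        for i, value in zip(positions, block.get_indices(local_rows, local_cols, device)):
            result[i] = value
    return result


if __name__ == "__main__":
    # Two 2x2 blocks on different "GPUs"; the indices live on the output device cuda:0.
    blocks = [Block("cuda:0", [[1, 2], [3, 4]]), Block("cuda:1", [[5, 6], [7, 8]])]
    try:
        cat_get_indices(blocks, [0, 3], [1, 0], idx_device="cuda:0")
    except RuntimeError as err:
        print("unpatched:", err)
    print("patched:", cat_get_indices(blocks, [0, 3], [1, 0], idx_device="cuda:0", move_to_block_device=True))
    # -> patched: [2, 7]
```

In the toy model only requests that reach a block on another device fail, which would match the error naming `cuda:1` rather than the output device. In the real code the analogous change would be moving each sub-index to its sub-operator's device (for example with `Tensor.to(linear_op.device)`) before calling `_get_indices`; whether that is the right upstream fix is an open question for the maintainers.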

JoachimSchaeffer commented 7 months ago

Is there any further information available on why this occurs and how it can be fixed? I ran into the same issue.