**Open** · XiankangTang opened this issue 10 months ago
## 🐛 Bug

I ran the multi-GPU example notebook: https://github.com/cornellius-gp/gpytorch/blob/master/examples/02_Scalable_Exact_GPs/Simple_MultiGPU_GP_Regression.ipynb. I used 1 A100 and 3 RTX 3090s to run the code, but training crashes with a RuntimeError.
## To reproduce

**Code snippet to reproduce**
```python
n_devices = 4
print('Planning to run on {} GPUs.'.format(n_devices))
```
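For context, the crash happens inside the training loop from that notebook, where the kernel matrix is sharded across the GPUs. Paraphrasing the notebook's model setup (a sketch of the linked notebook's cell, not a verbatim copy; `output_device` is `cuda:0` in my run):

```python
import torch
import gpytorch

output_device = torch.device('cuda:0')

class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, n_devices):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        base_covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
        # MultiDeviceKernel splits the kernel matrix into blocks, places them
        # on the listed devices, and gathers results on output_device.
        self.covar_module = gpytorch.kernels.MultiDeviceKernel(
            base_covar_module, device_ids=range(n_devices), output_device=output_device
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
```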
**Stack trace/error message**
```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[8], line 1
----> 1 model, likelihood = train(train_x, train_y,
      2                           n_devices=n_devices, output_device=output_device,
      3                           preconditioner_size=100,
      4                           n_training_iter=20)

Cell In[7], line 43, in train(train_x, train_y, n_devices, output_device, preconditioner_size, n_training_iter)
     40     loss = -mll(output, train_y)
     41     return loss
---> 43 loss = closure()
     44 loss.backward()
     46 for i in range(n_training_iter):

Cell In[7], line 40, in train.<locals>.closure()
     38 output = model(train_x)
     39 print(output)
---> 40 loss = -mll(output, train_y)
     41 return loss

File ~/anaconda3/envs/drprobe/lib/python3.11/site-packages/gpytorch/module.py:31, in Module.__call__(self, *inputs, **kwargs)
     30 def __call__(self, *inputs, **kwargs) -> Union[Tensor, Distribution, LinearOperator]:
---> 31     outputs = self.forward(*inputs, **kwargs)
     32     if isinstance(outputs, list):
     33         return [_validate_module_outputs(output) for output in outputs]

File ~/anaconda3/envs/drprobe/lib/python3.11/site-packages/gpytorch/mlls/exact_marginal_log_likelihood.py:64, in ExactMarginalLogLikelihood.forward(self, function_dist, target, *params)
     62 # Get the log prob of the marginal distribution
     63 output = self.likelihood(function_dist, *params)
---> 64 res = output.log_prob(target)
     65 res = self._add_other_terms(res, params)
     67 # Scale by the amount of data we have

File ~/anaconda3/envs/drprobe/lib/python3.11/site-packages/gpytorch/distributions/multivariate_normal.py:193, in MultivariateNormal.log_prob(self, value)
    191 # Get log determininant and first part of quadratic form
    192 covar = covar.evaluate_kernel()
--> 193 inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=diff.unsqueeze(-1), logdet=True)
    195 res = -0.5 * sum([inv_quad, logdet, diff.size(-1) * math.log(2 * math.pi)])
    196 return res

File ~/anaconda3/envs/drprobe/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py:1748, in LinearOperator.inv_quad_logdet(self, inv_quad_rhs, logdet, reduce_inv_quad)
   1745 if inv_quad_rhs is not None:
   1746     args = [inv_quad_rhs] + list(args)
-> 1748 preconditioner, precond_lt, logdet_p = self._preconditioner()
   1749 if precond_lt is None:
   1750     from ..operators.identity_linear_operator import IdentityLinearOperator

File ~/anaconda3/envs/drprobe/lib/python3.11/site-packages/linear_operator/operators/added_diag_linear_operator.py:126, in AddedDiagLinearOperator._preconditioner(self)
    124 if self._q_cache is None:
    125     max_iter = settings.max_preconditioner_size.value()
--> 126     self._piv_chol_self = self._linear_op.pivoted_cholesky(rank=max_iter)
    127     if torch.any(torch.isnan(self._piv_chol_self)).item():
    128         warnings.warn(
    129             "NaNs encountered in preconditioner computation. Attempting to continue without preconditioning.",
    130             NumericalWarning,
    131         )

File ~/anaconda3/envs/drprobe/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py:1965, in LinearOperator.pivoted_cholesky(self, rank, error_tol, return_pivots)
   1944 r"""
   1945 Performs a partial pivoted Cholesky factorization of the (positive definite) LinearOperator.
   1946 :math:`\mathbf L \mathbf L^\top = \mathbf K`.
   (...)
   1962     https://www.sciencedirect.com/science/article/pii/S0168927411001814
   1963 """
   1964 func = PivotedCholesky.apply
-> 1965 res, pivots = func(self.representation_tree(), rank, error_tol, *self.representation())
   1967 if return_pivots:
   1968     return res, pivots

File ~/anaconda3/envs/drprobe/lib/python3.11/site-packages/torch/autograd/function.py:506, in Function.apply(cls, *args, **kwargs)
    503 if not torch._C._are_functorch_transforms_active():
    504     # See NOTE: [functorch vjp and autograd interaction]
    505     args = _functorch.utils.unwrap_dead_wrappers(args)
--> 506     return super().apply(*args, **kwargs)  # type: ignore[misc]
    508 if cls.setup_context == _SingleLevelFunction.setup_context:
    509     raise RuntimeError(
    510         'In order to use an autograd.Function with functorch transforms '
    511         '(vmap, grad, jvp, jacrev, ...), it must override the setup_context '
    512         'staticmethod. For more details, please see '
    513         'https://pytorch.org/docs/master/notes/extending.func.html')

File ~/anaconda3/envs/drprobe/lib/python3.11/site-packages/linear_operator/functions/_pivoted_cholesky.py:78, in PivotedCholesky.forward(ctx, representation_tree, max_iter, error_tol, *matrix_args)
     75 # Populater L[... m:, m] with L[..., m:, m] * L[..., m, m].sqrt()
     76 if m + 1 < matrix_shape[-1]:
     77     # Get next row of the permuted matrix
---> 78     row = apply_permutation(matrix, pi_m.unsqueeze(-1), right_permutation=None).squeeze(-2)
     79     pi_i = permutation[..., m + 1 :].contiguous()
     81 L_m_new = row.gather(-1, pi_i)

File ~/anaconda3/envs/drprobe/lib/python3.11/site-packages/linear_operator/utils/permutation.py:80, in apply_permutation(matrix, left_permutation, right_permutation)
     76     right_permutation = torch.arange(matrix.size(-1), device=matrix.device)
     78 # Apply permutations
     79 return to_dense(
---> 80     matrix.__getitem__(
     81         (
     82             *batch_idx,
     83             left_permutation.unsqueeze(-1),
     84             right_permutation.unsqueeze(-2),
     85         )
     86     )
     87 )

File ~/anaconda3/envs/drprobe/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py:2847, in LinearOperator.__getitem__(self, index)
   2841 # Convert all indices into tensor indices
   2842 (
   2843     *new_batch_indices,
   2844     new_row_index,
   2845     new_col_index,
   2846 ) = _convert_indices_to_tensors(self, flattened_orig_indices)
-> 2847 res = self._get_indices(new_row_index, new_col_index, *new_batch_indices)
   2848 # Now un-flatten tensor indices
   2849 if len(tensor_index_shape) > 1:  # Do we need to unflatten?

File ~/anaconda3/envs/drprobe/lib/python3.11/site-packages/linear_operator/operators/cat_linear_operator.py:210, in CatLinearOperator._get_indices(self, row_index, col_index, *batch_indices)
    207 for linear_op_idx, sub_index in zip(linear_op_indices, sub_indices):
    208     sub_index[self.cat_dim] = sub_index[self.cat_dim] - self.cat_dim_cum_sizes[linear_op_idx]
--> 210 res_list = [
    211     linear_op._get_indices(sub_index[-2], sub_index[-1], *sub_index[:-2])
    212     for linear_op, sub_index in zip(linear_ops, sub_indices)
    213 ]
    214 if len(res_list) == 1:
    215     return res_list[0].view(target_shape).to(self.device)

File ~/anaconda3/envs/drprobe/lib/python3.11/site-packages/linear_operator/operators/cat_linear_operator.py:211, in <listcomp>(.0)
    207 for linear_op_idx, sub_index in zip(linear_op_indices, sub_indices):
    208     sub_index[self.cat_dim] = sub_index[self.cat_dim] - self.cat_dim_cum_sizes[linear_op_idx]
    210 res_list = [
--> 211     linear_op._get_indices(sub_index[-2], sub_index[-1], *sub_index[:-2])
    212     for linear_op, sub_index in zip(linear_ops, sub_indices)
    213 ]
    214 if len(res_list) == 1:
    215     return res_list[0].view(target_shape).to(self.device)

File ~/anaconda3/envs/drprobe/lib/python3.11/site-packages/linear_operator/operators/dense_linear_operator.py:50, in DenseLinearOperator._get_indices(self, row_index, col_index, *batch_indices)
     48 def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> torch.Tensor:
     49     # Perform the __getitem__
---> 50     res = self.tensor[(*batch_indices, row_index, col_index)]
     51     return res

RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:1)
```
## Expected Behavior

## System information

Please complete the following information:

## Additional context

Add any other context about the problem here.
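Reading the trace: `MultiDeviceKernel` yields a `CatLinearOperator` whose dense blocks live on different GPUs, and during the pivoted-Cholesky preconditioner computation `apply_permutation` builds index tensors that end up on a different device than the block being indexed (here `cuda:1`). A hypothetical, untested monkeypatch that moves the indices onto each block's own device before indexing (a sketch only; the real fix belongs in `linear_operator`):

```python
# Hypothetical workaround, untested: move index tensors to the device of
# the dense block they index, which is what the RuntimeError asks for.
from linear_operator.operators.dense_linear_operator import DenseLinearOperator

_orig_get_indices = DenseLinearOperator._get_indices

def _patched_get_indices(self, row_index, col_index, *batch_indices):
    device = self.tensor.device  # e.g. cuda:1 for the shard that fails
    return _orig_get_indices(
        self,
        row_index.to(device),
        col_index.to(device),
        *(idx.to(device) for idx in batch_indices),
    )

DenseLinearOperator._get_indices = _patched_get_indices
```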
---

I ran into the same issue. Is there any further information available on why this occurs and how it can be fixed?
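One thing that may be worth trying (untested): the failure originates while building the pivoted-Cholesky preconditioner (`AddedDiagLinearOperator._preconditioner`), so disabling preconditioning should avoid that code path entirely, at the cost of slower CG convergence. With the notebook's `train(...)` helper, which wraps training in `gpytorch.settings.max_preconditioner_size(preconditioner_size)`, that would look like:

```python
# Untested: preconditioner_size=0 runs training under
# gpytorch.settings.max_preconditioner_size(0), so the pivoted-Cholesky
# preconditioner (where the crash occurs) is never built.
model, likelihood = train(
    train_x, train_y,
    n_devices=n_devices, output_device=output_device,
    preconditioner_size=0,
    n_training_iter=20,
)
```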