Status: open — opened by lmao14 2 months ago.
Code snippet to reproduce
```python
import torch
import gpytorch
import linear_operator

kern = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(batch_shape=torch.Size([4, 3]),), batch_shape=torch.Size([4, 3]))
X = torch.randn([2, 5])
kxx = kern(X)
print(kxx.shape)
print(kxx.to_dense().sum(0).shape)
print(kxx.sum(0).to_dense().shape)
```
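```
torch.Size([4, 3, 2, 2])
torch.Size([3, 2, 2])
torch.Size([4, 5, 5])
```
The first two shapes are correct. The third is wrong: summing the lazy kernel matrix over batch dimension 0 should give `torch.Size([3, 2, 2])`, the same as summing the dense matrix, but it gives `torch.Size([4, 5, 5])`.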
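With a three-dimensional batch shape, the lazy `sum(0)` does not return a wrong shape; it raises an error instead:

```python
kern = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(batch_shape=torch.Size([5, 4, 3]),), batch_shape=torch.Size([5, 4, 3]))
X = torch.randn([2, 5])
kxx = kern(X)
print(kxx.sum(0).to_dense().shape)
```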
Stack trace/error message
```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[65], line 5
      3 X = torch.randn([2, 5])
      4 kxx = kern(X)
----> 5 print(kxx.sum(0).to_dense().shape)

File ~/miniconda3/lib/python3.8/site-packages/linear_operator/operators/_linear_operator.py:2517, in LinearOperator.sum(self, dim)
   2515 # Otherwise: it's a batch dimension
   2516 elif dim < self.dim():
-> 2517     return self._sum_batch(dim)
   2518 else:
   2519     raise ValueError("Invalid dim ({}) for LinearOperator of size {}".format(orig_dim, self.shape))

File ~/miniconda3/lib/python3.8/site-packages/linear_operator/operators/_linear_operator.py:861, in LinearOperator._sum_batch(self, dim)
    850 """
    851 Sum the LinearOperator across a batch dimension (supplied as a positive number).
    852 
   (...)
    857 :param dim: The (positive valued) dimension to sum
    858 """
    859 from linear_operator.operators.sum_batch_linear_operator import SumBatchLinearOperator
--> 861 return SumBatchLinearOperator(self, block_dim=dim)

File ~/miniconda3/lib/python3.8/site-packages/gpytorch/lazy/lazy_tensor.py:46, in deprecated_lazy_tensor.<locals>.__init__(self, *args, **kwargs)
     43 else:
     44     new_kwargs[name] = val
---> 46 return __orig_init__(self, *args, **new_kwargs)

File ~/miniconda3/lib/python3.8/site-packages/gpytorch/lazy/lazy_tensor.py:46, in deprecated_lazy_tensor.<locals>.__init__(self, *args, **kwargs)
     43 else:
     44     new_kwargs[name] = val
---> 46 return __orig_init__(self, *args, **new_kwargs)

File ~/miniconda3/lib/python3.8/site-packages/linear_operator/operators/block_linear_operator.py:50, in BlockLinearOperator.__init__(self, base_linear_op, block_dim)
     48 if block_dim != -3:
     49     positive_block_dim = base_linear_op.dim() + block_dim
---> 50     base_linear_op = base_linear_op._permute_batch(
     51         *range(positive_block_dim),
     52         *range(positive_block_dim + 1, base_linear_op.dim() - 2),
     53         positive_block_dim,
     54     )
     55 super(BlockLinearOperator, self).__init__(to_linear_operator(base_linear_op))
     56 self.base_linear_op = base_linear_op

File ~/miniconda3/lib/python3.8/site-packages/linear_operator/operators/_linear_operator.py:248, in LinearOperator._permute_batch(self, *dims)
    246 if torch.is_tensor(component):
    247     extra_dims = range(len(dims), component.dim())
--> 248     components.append(component.permute(*dims, *extra_dims))
    249 elif isinstance(component, LinearOperator):
    250     components.append(component._permute_batch(*dims))

RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 2 is not equal to len(dims) = 3
```
Please complete the following information:
🐛 Bug
To reproduce
Code snippet to reproduce
torch.Size([4, 3, 2, 2]) torch.Size([3, 2, 2]) torch.Size([4, 5, 5])
Stack trace/error message
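System information
Please complete the following information: (not filled in by the reporter; the traceback paths show Python 3.8 in a Miniconda environment, with `gpytorch` and `linear_operator` installed in its site-packages)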