cornellius-gp / linear_operator

A LinearOperator implementation to wrap the numerical nuts and bolts of GPyTorch
MIT License
95 stars 28 forks source link

[Bug] SumBatchLinearOperator fails for high-order tensor #100

Open lmao14 opened 2 months ago

lmao14 commented 2 months ago

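🐛 Bug

Calling `.sum(dim)` over a batch dimension of a kernel `LinearOperator` raises a `RuntimeError` when the operator has three batch dimensions (`[5, 4, 3]`). With two batch dimensions (`[4, 3]`), no error is raised.

The failure happens in `BlockLinearOperator.__init__`. That method calls `_permute_batch` to move the summed dimension into the last batch position. `_permute_batch` then applies the 3-element batch permutation to a component tensor that has only 2 dimensions, presumably the un-batched input `X` of shape `[2, 5]`.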

To reproduce

Code snippet to reproduce

import torch
import gpytorch
import linear_operator

kern = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(batch_shape=torch.Size([4, 3]),),
                                        batch_shape=torch.Size([4, 3]))
X = torch.randn([2, 5])
kxx = kern(X)

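Output with two batch dimensions (`[4, 3]`), where no error is raised. The `print` calls that produced this output are not shown in the snippet above:

torch.Size([4, 3, 2, 2]) torch.Size([3, 2, 2]) torch.Size([4, 5, 5])

With three batch dimensions (`[5, 4, 3]`), summing over a batch dimension fails: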

kern = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel(batch_shape=torch.Size([5, 4, 3]),),
                                        batch_shape=torch.Size([5, 4, 3]))
X = torch.randn([2, 5])
kxx = kern(X)

Stack trace/error message (raised by `print(kxx.sum(0).to_dense().shape)` after the three-batch-dimension snippet above)

RuntimeError                              Traceback (most recent call last)
Cell In[65], line 5
      3 X = torch.randn([2, 5])
      4 kxx = kern(X)
----> 5 print(kxx.sum(0).to_dense().shape)

File ~/miniconda3/lib/python3.8/site-packages/linear_operator/operators/, in LinearOperator.sum(self, dim)
   2515 # Otherwise: it's a batch dimension
   2516 elif dim < self.dim():
-> 2517     return self._sum_batch(dim)
   2518 else:
   2519     raise ValueError("Invalid dim ({}) for LinearOperator of size {}".format(orig_dim, self.shape))

File ~/miniconda3/lib/python3.8/site-packages/linear_operator/operators/, in LinearOperator._sum_batch(self, dim)
    850 """
    851 Sum the LinearOperator across a batch dimension (supplied as a positive number).
    857 :param dim: The (positive valued) dimension to sum
    858 """
    859 from linear_operator.operators.sum_batch_linear_operator import SumBatchLinearOperator
--> 861 return SumBatchLinearOperator(self, block_dim=dim)

File ~/miniconda3/lib/python3.8/site-packages/gpytorch/lazy/, in deprecated_lazy_tensor.<locals>.__init__(self, *args, **kwargs)
     43     else:
     44         new_kwargs[name] = val
---> 46 return __orig_init__(self, *args, **new_kwargs)

File ~/miniconda3/lib/python3.8/site-packages/gpytorch/lazy/, in deprecated_lazy_tensor.<locals>.__init__(self, *args, **kwargs)
     43     else:
     44         new_kwargs[name] = val
---> 46 return __orig_init__(self, *args, **new_kwargs)

File ~/miniconda3/lib/python3.8/site-packages/linear_operator/operators/, in BlockLinearOperator.__init__(self, base_linear_op, block_dim)
     48 if block_dim != -3:
     49     positive_block_dim = base_linear_op.dim() + block_dim
---> 50     base_linear_op = base_linear_op._permute_batch(
     51         *range(positive_block_dim),
     52         *range(positive_block_dim + 1, base_linear_op.dim() - 2),
     53         positive_block_dim,
     54     )
     55 super(BlockLinearOperator, self).__init__(to_linear_operator(base_linear_op))
     56 self.base_linear_op = base_linear_op

File ~/miniconda3/lib/python3.8/site-packages/linear_operator/operators/, in LinearOperator._permute_batch(self, *dims)
    246 if torch.is_tensor(component):
    247     extra_dims = range(len(dims), component.dim())
--> 248     components.append(component.permute(*dims, *extra_dims))
    249 elif isinstance(component, LinearOperator):
    250     components.append(component._permute_batch(*dims))

RuntimeError: permute(sparse_coo): number of dimensions in the tensor input does not match the length of the desired ordering of dimensions i.e. input.dim() = 2 is not equal to len(dims) = 3

System information

Please complete the following information: