cornellius-gp / linear_operator

A LinearOperator implementation to wrap the numerical nuts and bolts of GPyTorch
MIT License

[Bug] Failing to slice the CatLinearOperator when indexes are negative or when using boolean array #79

Open MoiseRousseau opened 1 year ago

MoiseRousseau commented 1 year ago

🐛 Bug

When slicing a CatLinearOperator with a negative index, the shape of the result does not match the expected shape and an error is raised. This fails at least with ToeplitzLinearOperator, DiagLinearOperator, and IdentityLinearOperator.

To reproduce

Code snippet to reproduce

from linear_operator.operators import IdentityLinearOperator as Ops
from linear_operator.operators import cat as cat_ops

N = 8
base = cat_ops([Ops(N) for _ in range(2)], dim=1)
print(base.shape) #should be 8,16
print(base[:,3:base.shape[-1]-3].shape) #should be 8,10
print(base[:,3:-3].shape) #fail...

Stack trace/error message

torch.Size([8, 16])
torch.Size([8, 10])
Traceback (most recent call last):
  File "/home/moise/Program/moise/linear_operator/debug.py", line 8, in <module>
    print(base[:,3:-3].shape) #fail...
  File "/home/moise/Program/moise/linear_operator/linear_operator/operators/_linear_operator.py", line 2870, in __getitem__
    raise RuntimeError(
RuntimeError: CatLinearOperator.__getitem__ failed! Expected a final shape of size torch.Size([8, 10]), got torch.Size([8, 5]). This is a bug with LinearOperator, or your custom LinearOperator.

Expected Behavior

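Slicing with negative indices should behave the same as with the equivalent positive indices, i.e. base[:,3:-3] should have shape torch.Size([8, 10]), just like base[:,3:base.shape[-1]-3].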
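
One way to make 3:-3 behave like 3:base.shape[-1]-3 is to resolve negative and None bounds against the full concatenated length before splitting the slice among the blocks. Below is a minimal pure-Python sketch of that idea using the built-in slice.indices. split_slice_across_blocks is a hypothetical helper, not part of linear_operator, and it is not meant to mirror how CatLinearOperator is actually implemented; it only handles a step of 1.

```python
from typing import List


def split_slice_across_blocks(slc: slice, sizes: List[int]) -> List[slice]:
    """Map a slice over a concatenated dimension to one local slice per block.

    `sizes` holds the length of each concatenated block along that dimension.
    Hypothetical helper for illustration; only step == 1 is handled.
    """
    total = sum(sizes)
    # slice.indices resolves None and negative bounds against the *total* length,
    # e.g. slice(3, -3) over 16 columns becomes (3, 13, 1).
    start, stop, step = slc.indices(total)
    if step != 1:
        raise NotImplementedError("this sketch only handles step == 1")

    local_slices = []
    offset = 0
    for size in sizes:
        # Intersect the global range [start, stop) with this block's global
        # range [offset, offset + size), expressed in local coordinates.
        lo = min(max(start - offset, 0), size)
        hi = min(max(stop - offset, 0), size)
        local_slices.append(slice(lo, max(lo, hi)))
        offset += size
    return local_slices


parts = split_slice_across_blocks(slice(3, -3), [8, 8])
print(parts)  # [slice(3, 8, None), slice(0, 5, None)]
```

For the example above (two blocks of 8 columns), 3:-3 resolves to columns 3 through 12 of the concatenation, i.e. columns 3 to 7 of the first block and 0 to 4 of the second, for 10 columns in total.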

System information

LinearOperator version: 0.5.2
PyTorch version: 2.0.1
OS: Ubuntu 22.04

MoiseRousseau commented 1 year ago

New finding: I also get a similar error when slicing with a boolean array, even without CatLinearOperator. For example:

from linear_operator.operators import IdentityLinearOperator as Ops

N = 4
cond = [True,False,False,True]
ops = Ops(N)
print(ops.shape)
ops[:,cond]

Which gives:

torch.Size([4, 4])
Traceback (most recent call last):
  File "/home/moise/Program/moise/linear_operator/debug.py", line 8, in <module>
    ops[:,cond]
  File "/home/moise/Program/moise/linear_operator/linear_operator/operators/_linear_operator.py", line 2870, in __getitem__
    raise RuntimeError(
RuntimeError: IdentityLinearOperator.__getitem__ failed! Expected a final shape of size torch.Size([4, 4]), got torch.Size([4, 2]). This is a bug with LinearOperator, or your custom LinearOperator.
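
For comparison, here is a minimal pure-Python sketch of the shape that NumPy-style indexing would produce, covering negative slice bounds and a boolean mask. expected_shape is a hypothetical helper and is not the check that LinearOperator.__getitem__ performs; it handles only ints, slices, and a single list-like index. Under these semantics the mask cond = [True, False, False, True] keeps two columns, so ops[:, cond] would have shape (4, 2).

```python
from typing import List, Sequence, Tuple, Union

Index = Union[int, slice, Sequence[bool], Sequence[int]]


def expected_shape(shape: Tuple[int, ...], index: Tuple[Index, ...]) -> Tuple[int, ...]:
    """Shape that NumPy-style indexing would produce (hypothetical helper).

    Handles ints, slices and at most one list-like index (boolean mask or
    integer list). Ellipsis, None and multiple list-like indices are not handled.
    """
    if len(index) > len(shape):
        raise IndexError("too many indices")
    out: List[int] = []
    n_list_like = 0
    for size, idx in zip(shape, index):
        if isinstance(idx, bool):
            raise TypeError("scalar bool indices are not handled in this sketch")
        if isinstance(idx, int):
            # An integer index (negative allowed) removes the dimension.
            if not -size <= idx < size:
                raise IndexError(f"index {idx} out of range for size {size}")
            continue
        if isinstance(idx, slice):
            # slice.indices resolves None and negative bounds against `size`.
            out.append(len(range(*idx.indices(size))))
            continue
        n_list_like += 1
        if n_list_like > 1:
            raise NotImplementedError("multiple list-like indices are not handled")
        idx = list(idx)
        if idx and all(isinstance(v, bool) for v in idx):
            # A boolean mask must match the dimension and keeps its True entries.
            if len(idx) != size:
                raise IndexError("boolean mask length does not match dimension")
            out.append(sum(idx))
        else:
            # An integer list keeps one entry per index.
            out.append(len(idx))
    # Dimensions not mentioned in `index` are kept whole.
    out.extend(shape[len(index):])
    return tuple(out)


print(expected_shape((8, 16), (slice(None), slice(3, -3))))                # (8, 10)
print(expected_shape((4, 4), (slice(None), [True, False, False, True])))  # (4, 2)
```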

Balandat commented 1 year ago

Looks like there may be a number of places where negative indexing isn't properly supported. I'll put up a fix for the CatLinearOperator case, but this should probably be audited more comprehensively.

I also don't think we've given much thought to supporting boolean indexing with linear_operator. @gpleiss, is that right?

gpleiss commented 1 year ago

Boolean indexing sounds tricky with linear operators. @MoiseRousseau do you have a good use case?

MoiseRousseau commented 1 year ago

I found a workaround: call torch.argwhere(bool_array) and then slice with the resulting indices. I was just reporting the error. Maybe this could be a way to implement it, even if it is suboptimal?
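
The mask-to-index step of this workaround can be sketched in pure Python: a 1-D boolean mask is turned into the integer positions of its True entries, which is what torch.argwhere(mask).squeeze(-1) returns for a 1-D tensor mask. mask_to_indices is a hypothetical helper for illustration, not a linear_operator function.

```python
from typing import List, Sequence


def mask_to_indices(mask: Sequence[bool]) -> List[int]:
    """Return the positions of the True entries of a 1-D boolean mask.

    Hypothetical helper; for a 1-D mask this matches
    torch.argwhere(torch.tensor(mask)).squeeze(-1).tolist().
    """
    return [i for i, keep in enumerate(mask) if keep]


cond = [True, False, False, True]
print(mask_to_indices(cond))  # [0, 3]
```

With PyTorch, the equivalent would be idx = torch.argwhere(torch.tensor(cond)).squeeze(-1), followed by ops[:, idx] as in the workaround described above.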