MoiseRousseau opened 1 year ago
New findings: I also get a similar error when slicing with a boolean array, even without a CatLinearOperator, for example:
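from linear_operator.operators import IdentityLinearOperator as Ops
N = 4
cond = [True, False, False, True]  # boolean mask keeping columns 0 and 3
ops = Ops(N)  # 4 x 4 identity operator
print(ops.shape)
ops[:,cond]  # raises the RuntimeError below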
Which gives:
torch.Size([4, 4])
Traceback (most recent call last):
  File "/home/moise/Program/moise/linear_operator/debug.py", line 8, in <module>
    ops[:,cond]
  File "/home/moise/Program/moise/linear_operator/linear_operator/operators/_linear_operator.py", line 2870, in __getitem__
    raise RuntimeError(
RuntimeError: IdentityLinearOperator.__getitem__ failed! Expected a final shape of size torch.Size([4, 4]), got torch.Size([4, 2]). This is a bug with LinearOperator, or your custom LinearOperator.
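For comparison: under the usual NumPy/PyTorch indexing rules, a boolean mask along a dimension keeps only the positions that are True, so on a dense 4 x 4 tensor `[:, cond]` returns shape `(4, 2)`, the same as the "got" shape in the error. The sketch below computes that result shape in pure Python for 2-D indexing. `dim_length` and `indexed_shape` are hypothetical helpers written for illustration (not part of linear_operator), and they only handle slices, integers, boolean lists and integer lists, with at most one list index.

```python
from typing import Optional, Sequence, Tuple, Union

# One index per dimension: an int, a slice, a list of bools, or a list of ints.
Index = Union[int, slice, Sequence[bool], Sequence[int]]


def dim_length(size: int, index: Index) -> Optional[int]:
    """Length produced along one dimension, or None if the dimension is dropped."""
    if isinstance(index, bool):
        raise TypeError("a bare bool is not a valid index in this sketch")
    if isinstance(index, int):
        if not -size <= index < size:
            raise IndexError(f"index {index} out of range for size {size}")
        return None  # an integer index removes the dimension
    if isinstance(index, slice):
        return len(range(*index.indices(size)))
    items = list(index)
    if items and all(isinstance(i, bool) for i in items):
        if len(items) != size:
            raise IndexError("boolean mask length must match the dimension size")
        return sum(items)  # a boolean mask keeps one entry per True
    return len(items)  # an integer list keeps one entry per listed position


def indexed_shape(shape: Tuple[int, int], index: Tuple[Index, Index]) -> Tuple[int, ...]:
    """Shape of `matrix[index]` for a 2-D shape with at most one list index."""
    if len(shape) != 2 or len(index) != 2:
        raise ValueError("this sketch handles 2-D shapes with one index per dimension")
    list_indices = [i for i in index if not isinstance(i, (int, slice))]
    if len(list_indices) > 1:
        raise NotImplementedError("two list indices broadcast together; not handled here")
    lengths = (dim_length(size, idx) for size, idx in zip(shape, index))
    return tuple(n for n in lengths if n is not None)


if __name__ == "__main__":
    cond = [True, False, False, True]
    print(indexed_shape((4, 4), (slice(None), cond)))  # (4, 2)
    print(indexed_shape((4, 4), (slice(None), [0, 3])))  # (4, 2)
    print(indexed_shape((4, 4), (-1, slice(None))))  # (4,)
```

Restricting the sketch to 2-D keeps it honest: with more dimensions, NumPy-style rules can move advanced-index dimensions to the front, which this helper does not model.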
Looks like there may be a number of places where negative indexing isn't properly supported. I'll put up a fix for the CatLinearOperator case, but this should probably be audited more comprehensively.
I also don't think we've given much thought to supporting boolean indexing with linear_operator. @gpleiss, is that right?
Boolean indexing sounds tricky with linear operators. @MoiseRousseau do you have a good use case?
I found a workaround: convert the mask with torch.argwhere(bool_array) and then slice using the resulting indices. I was just reporting the error. Maybe this could be a way to implement it (even if it is suboptimal)?
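The workaround amounts to turning the boolean mask into the integer positions of its True entries and indexing with those. A minimal pure-Python sketch of that conversion, where a nested list stands in for the operator; `mask_to_indices` and `select_columns` are hypothetical helpers for illustration only:

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def mask_to_indices(mask: Sequence[bool], size: int) -> List[int]:
    """Positions of the True entries: the 1-D analogue of the argwhere workaround."""
    if len(mask) != size:
        raise ValueError(f"mask has length {len(mask)}, expected {size}")
    return [i for i, keep in enumerate(mask) if keep]


def select_columns(rows: Sequence[Sequence[T]], cols: Sequence[int]) -> List[List[T]]:
    """Integer column indexing on a nested list standing in for the operator."""
    return [[row[c] for c in cols] for row in rows]


if __name__ == "__main__":
    N = 4
    cond = [True, False, False, True]
    eye = [[1 if r == c else 0 for c in range(N)] for r in range(N)]  # dense identity
    idx = mask_to_indices(cond, N)  # [0, 3]
    print(select_columns(eye, idx))  # [[1, 0], [0, 0], [0, 0], [0, 1]]
```

With tensors, keep in mind that `torch.argwhere` on a 1-D mask returns a 2-D tensor of shape `(k, 1)`, so it presumably needs `.flatten()` before being used as a column index, e.g. `ops[:, torch.argwhere(torch.tensor(cond)).flatten()]`; `torch.nonzero(torch.tensor(cond), as_tuple=True)[0]` yields the same 1-D index tensor. Treat this as an untested suggestion for linear_operator.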
🐛 Bug
When slicing a CatLinearOperator with a negative index, the final shape of the slice does not match the expected shape and an error is raised. This fails at least for the ToeplitzLinearOperator, the DiagLinearOperator and the IdentityLinearOperator.
To reproduce
Code snippet to reproduce
Stack trace/error message
Expected Behavior
The slice should behave the same way it does with positive indices.
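The expected behaviour amounts to treating a negative index `-k` as `size - k` over the whole concatenated dimension before working out which block it falls in. A minimal sketch of that idea, assuming blocks are concatenated along a single dimension with known sizes; `normalize_index` and `locate_in_cat` are hypothetical helpers, not linear_operator's actual CatLinearOperator code:

```python
from bisect import bisect_right
from itertools import accumulate
from typing import List, Tuple


def normalize_index(index: int, size: int) -> int:
    """Map a possibly negative index to its non-negative equivalent."""
    if not -size <= index < size:
        raise IndexError(f"index {index} out of range for dimension of size {size}")
    return index + size if index < 0 else index


def locate_in_cat(index: int, block_sizes: List[int]) -> Tuple[int, int]:
    """Return (block number, local index) for one position of a concatenated dimension."""
    total = sum(block_sizes)
    i = normalize_index(index, total)  # normalize BEFORE comparing against block offsets
    ends = list(accumulate(block_sizes))  # cumulative end offsets, e.g. [3, 5, 9]
    block = bisect_right(ends, i)  # first block whose end offset exceeds i
    start = ends[block - 1] if block > 0 else 0
    return block, i - start


if __name__ == "__main__":
    sizes = [3, 2, 4]  # e.g. three operators stacked along dim 0
    print(locate_in_cat(-1, sizes))  # (2, 3)
    print(locate_in_cat(8, sizes))  # (2, 3), same position as -1
    print(locate_in_cat(3, sizes))  # (1, 0)
```

Without the normalization step a negative index cannot be compared against the cumulative offsets, so normalizing once at the top level and then dispatching only non-negative local indices to each block keeps negative and positive indexing consistent.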
System information
LinearOperator Version 0.5.2
PyTorch Version 2.0.1
Ubuntu 22.04