cornellius-gp / linear_operator

A LinearOperator implementation to wrap the numerical nuts and bolts of GPyTorch
MIT License

[Bug] torch.cat fails for linear operators #92

Open chrisyeh96 opened 6 months ago

chrisyeh96 commented 6 months ago

🐛 Bug

torch.cat fails when it is called on a list of LinearOperator objects.

To reproduce

Code snippet to reproduce

from linear_operator.operators import DiagLinearOperator
import torch

D = DiagLinearOperator(torch.randn(2, 3, 100))  # Batch of 2 x 3 diagonal operators, each 100 x 100 (shape 2 x 3 x 100 x 100)
torch.cat([D, D], dim=-2)

Stack trace/error message

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ubuntu/miniconda3/envs/env/lib/python3.12/site-packages/linear_operator/operators/_linear_operator.py", line 2948, in __torch_function__
    raise NotImplementedError(f"torch.{name}({arg_classes}, {kwarg_classes}) is not implemented.")
NotImplementedError: torch.cat(list, dim=int) is not implemented.

Expected Behavior

According to the documentation, torch.cat should work on linear operators.
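For reference, the dense computation that the operator version should agree with can be written in plain PyTorch. This is a sketch that assumes `DiagLinearOperator(diag)` corresponds to the dense batch `torch.diag_embed(diag)`, so concatenating along `dim=-2` should stack rows and give a 2 x 3 x 200 x 100 result.

```python
import torch

diag = torch.randn(2, 3, 100)
# Dense equivalent of DiagLinearOperator(diag): a 2 x 3 batch of 100 x 100 diagonal matrices
dense = torch.diag_embed(diag)
# What torch.cat([D, D], dim=-2) would be expected to represent: rows stacked, columns unchanged
expected = torch.cat([dense, dense], dim=-2)
print(expected.shape)
```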


Balandat commented 5 months ago

Hmm, interesting. Yeah, not sure why the docs contain this; it seems like this was never implemented. Support is there for many unary and binary operations, but torch.cat operates on a list of objects rather than on a LinearOperator directly. @gpleiss, have you considered this and similar operations in the past?
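A small experiment makes the point about the list argument concrete. The sketch below uses a hypothetical toy class, `Traced`, which is neither a Tensor nor part of linear_operator. It only records what PyTorch passes to its `__torch_function__` when `torch.cat` is called. Assuming a PyTorch version that supports the `__torch_function__` protocol for non-Tensor classes, the call is still routed to the class, but the whole list arrives as the first positional argument and `dim` as a keyword argument. This matches the `torch.cat(list, dim=int)` signature in the error message.

```python
import torch


class Traced:
    """Hypothetical non-Tensor class that records calls routed to __torch_function__."""

    calls = []

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Record exactly what PyTorch hands over, then return a marker value
        cls.calls.append((func, args, kwargs or {}))
        return "intercepted"


a, b = Traced(), Traced()
result = torch.cat([a, b], dim=-2)
func, args, kwargs = Traced.calls[-1]
print(func.__name__, type(args[0]).__name__, kwargs)
```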

chrisyeh96 commented 5 months ago

There is a CatLinearOperator that implements what is needed. I wish linear_operator could implement LinearOperator so that calling torch.cat on a list of LinearOperators automatically creates a CatLinearOperator.
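One reason a lazy concatenation is attractive is that row-stacking commutes with matrix multiplication: `[A; B] @ x` equals `[A @ x; B @ x]`, so the stacked matrix never has to be built. The hypothetical `LazyDiagRowCat` class below is a toy illustration of that idea for diagonal blocks in plain PyTorch. It is not CatLinearOperator's actual implementation or API, and it assumes all blocks share the same size n and batch shape.

```python
import torch


class LazyDiagRowCat:
    """Hypothetical lazy concatenation of diagonal operators along dim=-2.

    Each block is stored only by its diagonal (shape ... x n); the dense
    (... x k*n x n) matrix is only built if to_dense() is called.
    Assumes every block has the same n and batch shape.
    """

    def __init__(self, *diags):
        self.diags = diags

    @property
    def shape(self):
        n = self.diags[0].shape[-1]
        rows = sum(d.shape[-1] for d in self.diags)
        return torch.Size([*self.diags[0].shape[:-1], rows, n])

    def __matmul__(self, rhs):
        # diag(d) @ rhs == d[..., None] * rhs, and [D1; D2] @ rhs == [D1 @ rhs; D2 @ rhs]
        return torch.cat([d.unsqueeze(-1) * rhs for d in self.diags], dim=-2)

    def to_dense(self):
        # Materialize only on request, for checking against the lazy product
        return torch.cat([torch.diag_embed(d) for d in self.diags], dim=-2)


d = torch.randn(2, 3, 5)
op = LazyDiagRowCat(d, d)
x = torch.randn(2, 3, 5, 4)
print(op.shape, (op @ x).shape)
```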

Balandat commented 5 months ago

Yes, that makes a lot of sense and would be great to have. I'm not sure whether that is easy to do with the __torch_function__ setup that we use for this dispatching under the hood. Let me see if I can get some intel on this.
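One way such dispatching could look is the `HANDLED_FUNCTIONS` registry pattern from the PyTorch documentation on extending torch with `__torch_function__`. The sketch below is a hypothetical toy: `ToyOperator`, `ToyCatOperator`, `toy_cat` and `implements` are illustrative names, not linear_operator code. Its `__torch_function__` looks up `torch.cat` in a registry and returns a lazy concatenation object instead of raising `NotImplementedError`. Whether this fits linear_operator's existing `__torch_function__` setup is the open question in this thread.

```python
import torch

# Hypothetical registry mapping torch functions to custom implementations
HANDLED_FUNCTIONS = {}


def implements(torch_fn):
    """Register a custom implementation for a torch function."""
    def decorator(fn):
        HANDLED_FUNCTIONS[torch_fn] = fn
        return fn
    return decorator


class ToyOperator:
    """Minimal non-Tensor operator wrapping a dense matrix (illustration only)."""

    def __init__(self, dense):
        self.dense = dense

    def to_dense(self):
        return self.dense

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func not in HANDLED_FUNCTIONS:
            # Lets PyTorch raise a TypeError for unsupported functions
            return NotImplemented
        return HANDLED_FUNCTIONS[func](*args, **kwargs)


class ToyCatOperator:
    """Lazy concatenation: keeps the pieces and the dim, materializes on demand."""

    def __init__(self, ops, dim):
        self.ops = list(ops)
        self.dim = dim

    def to_dense(self):
        return torch.cat([op.to_dense() for op in self.ops], dim=self.dim)


@implements(torch.cat)
def toy_cat(tensors, dim=0):
    # torch.cat's first argument is the whole list, not a single operator
    return ToyCatOperator(tensors, dim)


A = ToyOperator(torch.eye(3))
B = ToyOperator(2 * torch.eye(3))
C = torch.cat([A, B], dim=-2)
print(type(C).__name__, C.to_dense().shape)
```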