runame opened this issue 10 months ago (status: Open)
Chipping in to say that I am almost done writing a library to compute matrix-free SVD/eigh using random matrices, and this was one of my main issues with curvlinops: my library is PyTorch "native" and does make use of the Hessian from curvlinops, but not of the SciPy routines, so the conversion to NumPy was a redundant step for me.
A few thoughts on this; the wrapper mixin I currently use as a workaround is below. Thoughts?
```python
import torch


class TorchLinOpWrapper:
    """
    Since this library operates mainly with PyTorch tensors, but some useful
    LinOps interface with NumPy arrays instead, this mixin class acts as a
    wrapper around the ``__matmul__`` and ``__rmatmul__`` operators,
    so that the operator expects and returns torch tensors, even when the
    wrapped operator interfaces with NumPy. Usage example::

        # extend NumPy linear operator via multiple inheritance
        class TorchWrappedLinOp(TorchLinOpWrapper, LinOp):
            pass

        lop = TorchWrappedLinOp(...)  # instantiate normally
        w = lop @ v  # now v can be a PyTorch tensor
    """

    @staticmethod
    def _input_wrapper(x):
        """Convert a torch tensor to a NumPy array, remembering its device."""
        if isinstance(x, torch.Tensor):
            return x.cpu().numpy(), x.device
        else:
            return x, None

    @staticmethod
    def _output_wrapper(x, torch_device=None):
        """Convert a NumPy result back to a torch tensor on the original device."""
        if torch_device is not None:
            return torch.from_numpy(x).to(torch_device)
        else:
            return x

    def __matmul__(self, x):
        """Right multiplication that accepts and returns torch tensors."""
        x, device = self._input_wrapper(x)
        result = self._output_wrapper(super().__matmul__(x), device)
        return result

    def __rmatmul__(self, x):
        """Left multiplication that accepts and returns torch tensors."""
        x, device = self._input_wrapper(x)
        result = self._output_wrapper(super().__rmatmul__(x), device)
        return result
```
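For completeness, here is a small self-contained sketch of how the mixin behaves, using a toy NumPy-based `LinearOperator` (the `DiagonalLinOp` name is made up for illustration) in place of a curvlinops operator:

```python
import numpy as np
import torch
from scipy.sparse.linalg import LinearOperator


class DiagonalLinOp(LinearOperator):
    """Toy NumPy-based operator: multiplication by a fixed diagonal."""

    def __init__(self, diag):
        super().__init__(dtype=diag.dtype, shape=(diag.size, diag.size))
        self._diag = diag

    def _matvec(self, x):
        return self._diag * x


# combine the mixin with the NumPy-based operator via multiple inheritance
class TorchDiagonalLinOp(TorchLinOpWrapper, DiagonalLinOp):
    pass


lop = TorchDiagonalLinOp(np.arange(1.0, 4.0, dtype=np.float32))
v = torch.ones(3)   # torch tensor in ...
w = lop @ v         # ... goes through NumPy internally ...
print(w, w.device)  # ... and comes back as a torch tensor on v's device
```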
This is a fair point, and I would be happy to offer a PyTorch-only mode for the linear operators. The main focus of this library so far has been bridging fancy methods for linear operators in SciPy (such as `eigsh`) with PyTorch curvature matrices. For many scenarios, I believe the GPU-CPU syncs should not be a bottleneck (e.g. the bottleneck will be multiplying with the Hessian on a larger data set), but of course this depends on the use case.
I agree with @andres-fr that there should be a clean way to define a sub-class that works purely in PyTorch. One downside is that we will have to replicate most of the `scipy.sparse.linalg.LinearOperator` interface (or find an existing implementation of linear operators in PyTorch).
Actually I didn't mention that option, but now that you mention it, it's clear that it's the best one: extend the linops. What about making them NumPy/PyTorch agnostic? Would that reduce code? I think a good target would be to reach this agnosticism with minimal code overhead, but we really want to avoid running into the "TensorFlow backend" problem, which leads to an extremely messy API. Just a couple of ideas to bounce around.
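To make the "agnostic" idea slightly more concrete, one possible direction (purely a sketch with made-up names, not a proposed API) is a single operator whose product dispatches on the input type, so torch tensors never leave their device; the downside is exactly the per-backend branching that the "TensorFlow backend" worry points at:

```python
import numpy as np
import torch


class AgnosticDiagonalLinOp:
    """Hypothetical sketch of a NumPy/PyTorch-agnostic operator: one
    operator serves both backends, returning results in the input's type."""

    def __init__(self, diag: torch.Tensor):
        self._diag = diag  # internals live in torch (possibly on GPU)

    def __matmul__(self, x):
        if isinstance(x, torch.Tensor):
            # torch path: stay on the input's device, no CPU sync
            return self._diag.to(device=x.device, dtype=x.dtype) * x
        # NumPy path: mirror the existing SciPy-facing behaviour
        return self._diag.detach().cpu().numpy() * np.asarray(x)


lop = AgnosticDiagonalLinOp(torch.arange(1.0, 4.0))
w_torch = lop @ torch.ones(3)  # torch in, torch out
w_numpy = lop @ np.ones(3)     # NumPy in, NumPy out
```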
In the `_postprocess` method, the result of a matrix-vector product will always be transferred to the CPU. While this is consistent with the SciPy interface, in many use cases where we only operate with PyTorch GPU tensors this is not desirable, as it creates unnecessary overhead.
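By contrast, the PyTorch-only behaviour this issue asks for would avoid that transfer entirely; a hypothetical sketch (none of these names exist in curvlinops) of what such an operator interface could look like:

```python
import torch


class TorchOnlyLinearOperator:
    """Hypothetical sketch of a torch-native operator interface."""

    def __init__(self, matvec_fn, dim: int, device: torch.device):
        self._matvec_fn = matvec_fn  # callable: torch tensor -> torch tensor
        self.shape = (dim, dim)
        self.device = device

    def __matmul__(self, v: torch.Tensor) -> torch.Tensor:
        # no .cpu()/.numpy() round trip: input and result stay on self.device
        return self._matvec_fn(v.to(self.device))


# usage sketch with a trivial matvec (identity); a real operator would
# plug in a Hessian- or GGN-vector product here
op = TorchOnlyLinearOperator(lambda v: v, dim=4, device=torch.device("cpu"))
result = op @ torch.ones(4)
```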