f-dangel / curvlinops

PyTorch linear operators for curvature matrices (Hessian, Fisher/GGN, KFAC, ...)
https://curvlinops.readthedocs.io/en/latest/
MIT License

Unnecessary device transfer in `_postprocess` method #71

Open runame opened 10 months ago

runame commented 10 months ago

In the `_postprocess` method, the result of a matrix-vector product is always transferred to the CPU. While this is consistent with the scipy interface, in many use cases where we only operate with torch GPU tensors this is not desirable, as it creates unnecessary overhead.
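
To illustrate the round trip (a minimal self-contained sketch with a dense stand-in matrix, not curvlinops' actual code):

import torch

A = torch.randn(10, 10, device="cuda")              # stand-in for a curvature matrix on the GPU
v = torch.randn(10, device="cuda")
result_np = (A @ v).cpu().numpy()                   # transfer forced by the NumPy/scipy interface
result = torch.from_numpy(result_np).to(v.device)   # and back again, if the caller is torch-native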

andres-fr commented 10 months ago

Chipping in to say that I am almost done writing a library to compute matrix-free SVD/eigh using random matrices, and this was one of my main issues with curvlinops: my library is PyTorch-"native" and does make use of the Hessian from curvlinops, but not of the scipy routines, so the CPU transfer step was redundant.

A few thoughts:

  1. For most DL scenarios, the overhead is minimal because the matrices involved are small. And, in my case, if the matrices are big enough to cause noticeable overhead, they are also too big to fit on the GPU. So in the end this was a bit of a non-issue for me. That said, there are scenarios where this is definitely an issue.
  2. In case it helps someone, I dealt with it by writing a mixin class and inheriting from it (see the snippet below). The use case is the following:
    • You have a routine that is PyTorch-native
    • You have a curvlinops/scipy linear operator that expects NumPy
    • You'd like to make it agnostic to NumPy/PyTorch input and output
  3. At the core of this issue is the fundamental question: what is curvlinops for? Originally, it was to enable scipy routines for DL. But as we develop our own routines, this may require reconsidering some fundamental aspects, and potentially a redesign.
    • If curvlinops stays true to its origin, the wrapper class I mentioned in point 2 is IMO the way to go.
    • Alternatively, an extra API layer can be added, where the "routines" are either defined (in PyTorch or whatever) or imported (from scipy or whatever). Then there is an API layer on top, which simply wraps the routines in an agnostic manner, using something like the wrapper I provide here.

Thoughts?

Snippet

import torch


class TorchLinOpWrapper:
    """
    Since this library operates mainly with PyTorch tensors, but some useful
    LinOps interface with NumPy arrays instead, this mixin class acts as a
    wrapper around the ``__matmul__`` and ``__rmatmul__`` operators,
    so that the operator expects and returns torch tensors, even when the
    wrapped operator interfaces with NumPy. Usage example::

      # extend NumPy linear operator via multiple inheritance
      class TorchWrappedLinOp(TorchLinOpWrapper, LinOp):
          pass
      lop = TorchWrappedLinOp(...)  # instantiate normally
      w = lop @ v  # now v can be a PyTorch tensor
    """

    @staticmethod
    def _input_wrapper(x):
        """ """
        if isinstance(x, torch.Tensor):
            return x.cpu().numpy(), x.device
        else:
            return x, None

    @staticmethod
    def _output_wrapper(x, torch_device=None):
        """ """
        if torch_device is not None:
            return torch.from_numpy(x).to(torch_device)
        else:
            return x

    def __matmul__(self, x):
        """ """
        x, device = self._input_wrapper(x)
        result = self._output_wrapper(super().__matmul__(x), device)
        return result

    def __rmatmul__(self, x):
        """ """
        x, device = self._input_wrapper(x)
        result = self._output_wrapper(super().__rmatmul__(x), device)
        return result
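
For concreteness, a usage sketch of the mixin with a curvlinops operator (model, loss_function, params, and data are placeholders; the constructor signature is taken from the curvlinops docs, please double-check it there):

from curvlinops import HessianLinearOperator

# extend the scipy-style curvlinops operator via multiple inheritance,
# exactly as in the docstring example above
class TorchWrappedHessian(TorchLinOpWrapper, HessianLinearOperator):
    pass

H = TorchWrappedHessian(model, loss_function, params, data)  # placeholder arguments
v = torch.randn(H.shape[1], device="cuda")
Hv = H @ v  # comes back as a torch tensor on v's device
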
f-dangel commented 10 months ago

This is a fair point, and I would be happy to offer a PyTorch-only mode of linear operators. The main focus of this library so far is bridging fancy methods for linear operators in scipy (such as eigsh) with torch curvature matrices. For many scenarios, I believe the GPU-CPU syncs should not be a bottleneck (e.g. the bottleneck will be multiplying with the Hessian on a larger data set). But of course, this in general depends on the use case.
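
As a minimal self-contained sketch of that bridging use case (a toy model and data set, assuming the constructor signature from the curvlinops docs; eigsh consumes the operator through the scipy LinearOperator interface, with NumPy arrays as inputs and outputs):

import torch
from scipy.sparse.linalg import eigsh
from curvlinops import HessianLinearOperator

# toy setup so the sketch is self-contained
model = torch.nn.Linear(5, 1)
loss_function = torch.nn.MSELoss()
params = [p for p in model.parameters() if p.requires_grad]
data = [(torch.randn(8, 5), torch.randn(8, 1))]

H = HessianLinearOperator(model, loss_function, params, data)
top_eigenvalues, _ = eigsh(H, k=2, which="LA")  # largest algebraic eigenvalues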

I agree with @andres-fr that there should be a clean way to define a sub-class that works purely in PyTorch. One downside is that we will have to replicate most of the scipy.sparse.LinearOperator interface (or find an existing implementation for linear operators in PyTorch).
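
To give a sense of the scope, a minimal sketch of what a PyTorch-native operator interface could look like (names are hypothetical, this is not an existing curvlinops API):

import torch

class TorchLinearOperator:
    """Minimal PyTorch analogue of scipy.sparse.linalg.LinearOperator."""

    def __init__(self, shape, device, dtype):
        self.shape, self.device, self.dtype = shape, device, dtype

    def _matvec(self, v: torch.Tensor) -> torch.Tensor:
        # subclasses implement the matrix-vector product here
        raise NotImplementedError

    def __matmul__(self, v: torch.Tensor) -> torch.Tensor:
        # stays on v's device and dtype, no CPU round trip
        return self._matvec(v)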

andres-fr commented 10 months ago

Actually, I didn't mention that option, but now that you bring it up it's clear that it's the best one: extend the linops. What about making them NumPy/PyTorch agnostic? Would that reduce code?

I think a good target would be to achieve this agnosticism with minimal code overhead.

But we really want to avoid running into the "TensorFlow backend" problem, which leads to an extremely messy API. Just a couple of ideas to bounce around.
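
One way such agnostic operators could look (purely illustrative, not an agreed-upon design; _matvec_torch stands for a hypothetical torch-native code path):

import numpy as np
import torch

class AgnosticMatmulMixin:
    """Dispatch __matmul__ on the input type: NumPy in, NumPy out; torch in, torch out."""

    def __matmul__(self, x):
        if isinstance(x, torch.Tensor):
            return self._matvec_torch(x)  # hypothetical torch-native path, no CPU sync
        return super().__matmul__(np.asarray(x))  # existing scipy/NumPy path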