albanD / subclass_zoo


How to retain the grad via __torch_dispatch__ for torch.Tensor methods #29

Open wang-chen opened 2 years ago

wang-chen commented 2 years ago

I have a question that might be very simple, but I have no idea how to solve it.

I am trying to subclass torch.Tensor and want to retain the gradients of the original torch.Tensor methods.

Here is my code:

```python
import torch
from torch.utils._pytree import tree_map

class MyTensor(torch.Tensor):

    @staticmethod
    def __new__(cls, tensor):
        return torch.Tensor.as_subclass(tensor, cls)

    def __init__(self, tensor):
        self.tensor = tensor

    __torch_function__ = torch._C._disabled_torch_function_impl

    def __repr__(self):
        return self.__class__.__name__ +':\n'+ self.tensor.__repr__()

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(t):
            return t.tensor if isinstance(t, cls) else t

        def wrap(t):
            return cls(t) if isinstance(t, torch.Tensor) and not isinstance(t, cls) else t

        return tree_map(wrap, (super().__torch_dispatch__(func, types, args, kwargs)))

    def my_method(self):
        return self.tensor.exp()
```

If I use __torch_function__, the gradient is retained. How can I retain the gradient when using __torch_dispatch__? Thank you so much!
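
A minimal sketch of the `__torch_function__` behaviour described here, assuming a subclass that keeps the default `__torch_function__` (the name `TFTensor` is hypothetical). `Tensor.as_subclass` returns an alias that stays attached to the autograd graph, and the default `__torch_function__` converts op results back to the subclass, so gradients reach the original tensor:

```python
import torch


class TFTensor(torch.Tensor):
    # Hypothetical subclass that keeps the default __torch_function__,
    # so torch operations return TFTensor and are recorded by autograd.
    pass


base = torch.randn(3, requires_grad=True)
x = base.as_subclass(TFTensor)  # aliases base and stays attached to its graph

print(type(x.exp()).__name__)  # TFTensor
x.exp().sum().backward()

# d/dx sum(exp(x)) = exp(x), accumulated on the original tensor.
print(base.grad)
```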

albanD commented 2 years ago

Hi,

When you do x.exp().sum().backward(), do you expect x.grad to be populated? Or x.tensor.grad to be populated?

wang-chen commented 2 years ago

I would prefer x.grad, but I'd like to know how to do both, since both might be useful in the future.

albanD commented 2 years ago

The high-level idea is that you have to choose: either x is tracked by autograd or x.tensor is, but it can't be both.
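
A minimal sketch of the first choice, where autograd tracks the wrapper `x` and `x.tensor` is an ordinary tensor that autograd never sees. The `Wrapped` class is hypothetical, and unlike the code in this thread its `__torch_dispatch__` explicitly unwraps the arguments, runs `func` on the inner tensors, and re-wraps the outputs instead of calling `super().__torch_dispatch__`:

```python
import torch
from torch.utils._pytree import tree_map


class Wrapped(torch.Tensor):
    """Hypothetical wrapper: autograd tracks the wrapper, never .tensor."""

    @staticmethod
    def __new__(cls, tensor, *, requires_grad=False):
        # _make_subclass detaches `tensor`, so the wrapper is a fresh leaf.
        return torch.Tensor._make_subclass(cls, tensor, require_grad=requires_grad)

    def __init__(self, tensor, *, requires_grad=False):
        self.tensor = tensor  # plain tensor, invisible to autograd

    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # We are below autograd here: compute on the plain inner tensors...
        def unwrap(t):
            return t.tensor if isinstance(t, cls) else t

        def wrap(t):
            return cls(t) if isinstance(t, torch.Tensor) and not isinstance(t, cls) else t

        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        # ...then re-wrap the results so autograd sees Wrapped outputs.
        return tree_map(wrap, out)


x = Wrapped(torch.randn(3), requires_grad=True)
x.exp().sum().backward()

print(x.grad is not None)  # the gradient lands on the wrapper
print(x.tensor.grad)       # None: the inner tensor is never tracked
```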

Here is an extension of your script that shows how to do some of these things:

```python
import torch
from torch.utils._pytree import tree_map

class MyTensorWithGrad(torch.Tensor):
    @staticmethod
    def __new__(cls, tensor, *, requires_grad=False):
        assert not tensor.requires_grad, "Only the wrapper should require gradients"
        return torch.Tensor._make_subclass(cls, tensor, require_grad=requires_grad)

    def __init__(self, tensor, *, requires_grad=False):
        self.tensor = tensor

    __torch_function__ = torch._C._disabled_torch_function_impl

    def __repr__(self):
        autograd_info = f"grad_fn={self.grad_fn}" if self.grad_fn else \
            f"requires_grad={self.requires_grad}"
        return f"{self.__class__.__name__}({self.tensor.__repr__()}, {autograd_info})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(t):
            return t.tensor if isinstance(t, cls) else t

        def wrap(t):
            return cls(t) if isinstance(t, torch.Tensor) and not isinstance(t, cls) else t

        return tree_map(wrap, (super().__torch_dispatch__(func, types, args, kwargs)))

    def my_method(self):
        # This method lives "above" autograd, so we should NOT access the ".tensor"
        # attribute, which is not differentiable.
        # Use a custom Function to make this differentiable.
        class MyMethod(torch.autograd.Function):
            @staticmethod
            def forward(ctx, inp):
                # here it is ok to access tensor in a non-differentiable way!
                out = inp.tensor.exp()
                ctx.save_for_backward(out)
                return out
            @staticmethod
            def backward(ctx, gO):
                out, = ctx.saved_tensors
                # d/dx exp(x) = exp(x): scale the incoming gradient by the saved output
                return gO * out
        return MyMethod.apply(self)

    # If you don't want to write a custom Function for everything,
    # you can create a way to get the `.tensor` in a differentiable way,
    # similar to .values() on a sparse Tensor.
    def get_tensor_attr(self):
        class MyAccessor(torch.autograd.Function):
            @staticmethod
            def forward(ctx, inp):
                return inp.tensor
            @staticmethod
            def backward(ctx, gO):
                return gO
        return MyAccessor.apply(self)

    def my_other_method(self):
        return self.get_tensor_attr().exp()

x = MyTensorWithGrad(torch.randn(3), requires_grad=True)
print(x)

print("my_method")
print(x.my_method())

print("tensor")
print(x.tensor)

print("get_tensor_attr")
print(x.get_tensor_attr())

print("my_other_method")
print(x.my_other_method())
```

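
A hand-written `backward` such as the one in `my_method` can be checked against finite differences with `torch.autograd.gradcheck`. This sketch uses a hypothetical plain-tensor stand-in, `ExpFn`; since d/dx exp(x) = exp(x), its `backward` scales the incoming gradient by the saved output (gradcheck expects double-precision inputs):

```python
import torch


class ExpFn(torch.autograd.Function):
    """Hypothetical plain-tensor stand-in for MyMethod: y = exp(x)."""

    @staticmethod
    def forward(ctx, inp):
        out = inp.exp()
        ctx.save_for_backward(out)  # exp(x) is also its own derivative
        return out

    @staticmethod
    def backward(ctx, gO):
        out, = ctx.saved_tensors
        return gO * out


inp = torch.randn(3, dtype=torch.double, requires_grad=True)
# Returns True, or raises if the analytical and numerical gradients disagree.
ok = torch.autograd.gradcheck(ExpFn.apply, (inp,))
print(ok)
```

Returning `inp * gO` from `backward` instead would make this check raise, because it computes x·g rather than exp(x)·g.
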
wang-chen commented 2 years ago

Thank you so much for your reply. Things are a little bit complicated on our side.

We are developing an open-source project, PyPose, with PyTorch, and we subclass torch.Tensor to represent Lie algebras and Lie groups.

One of our developers asked a question (712) about using vmap and jacrev to compute the Jacobian. We previously used __torch_function__ for subclassing; after seeing your reply, we are considering __torch_dispatch__, but we are not sure how to handle it.

Basically, we have the following objective.

You can see our current implementation here.

Any suggestions for this? Thank you so much!

@zou3519, for your questions about our use case, you can also refer to the link above.

albanD commented 2 years ago

> The subclass constructor needs to retain the gradient, so in __new__() we use torch.Tensor.as_subclass(tensor, cls) instead of torch.Tensor._make_subclass(cls, tensor), since the input tensor can be the output of a neural network, which needs to track the grad for training.

This can be done in a similar way to the "differentiable accessor" above, but with a "differentiable constructor":

```python
# Add to the MyTensorWithGrad class from above:
    @staticmethod
    def from_tensor(t):
        class MyConst(torch.autograd.Function):
            @staticmethod
            def forward(ctx, t):
                # Detach so the wrapper's assert passes; MyConst records the link back to t.
                return MyTensorWithGrad(t.detach())
            @staticmethod
            def backward(ctx, gO):
                return gO
        return MyConst.apply(t)

inp = torch.rand(3, requires_grad=True)
x = MyTensorWithGrad.from_tensor(inp)
print(x)
```
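
A self-contained sketch of the differentiable-constructor idea, assuming a hypothetical subclass `Sub` that keeps the default `__torch_function__` rather than the `__torch_dispatch__` wrapper above. Calling `_make_subclass` directly cuts the link to `inp`, while wrapping it in an `autograd.Function` (the hypothetical `FromTensor`) whose backward is the identity lets gradients flow back:

```python
import torch


class Sub(torch.Tensor):
    # Hypothetical subclass that keeps the default __torch_function__.
    pass


class FromTensor(torch.autograd.Function):
    """Differentiable constructor: plain tensor -> Sub, gradient passed through."""

    @staticmethod
    def forward(ctx, t):
        # _make_subclass detaches `t`; the Function records the link instead.
        return torch.Tensor._make_subclass(Sub, t.detach())

    @staticmethod
    def backward(ctx, gO):
        # The constructor does not change values, so the gradient is the identity.
        return gO


inp = torch.rand(3, requires_grad=True)

# Without the Function, the new subclass instance is cut off from inp's graph.
cut = torch.Tensor._make_subclass(Sub, inp)

# With the Function, gradients flow back into inp.
x = FromTensor.apply(inp)
x.exp().sum().backward()
print(inp.grad)
```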

> Use vmap and jacrev to compute the Jacobian matrix.

How big are these tensors? You can also get the Jacobian with vanilla PyTorch functions, via torch.autograd.functional.jacobian.
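
For example, with a hypothetical function `f` that applies `exp` and scales the result by 2, `torch.autograd.functional.jacobian` returns the full Jacobian, which here is diagonal with entries 2 * exp(x):

```python
import torch
from torch.autograd.functional import jacobian


def f(x):
    # Hypothetical example function: elementwise exp, scaled by 2.
    return x.exp() * 2.0


x = torch.randn(3)
J = jacobian(f, x)  # J[i, j] = d f(x)[i] / d x[j], shape (3, 3)
print(J)
```

For an elementwise function like this one only the diagonal is non-zero; for larger inputs, the cost grows with the number of outputs, since the default strategy computes one backward pass per output element.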