pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Can't vmap(jacrev(slogdet)) #296

Open AlphaBetaGamma96 opened 2 years ago

AlphaBetaGamma96 commented 2 years ago

Hi All,

I've been trying to get per-sample gradients of a simple feed-forward network that has a torch.slogdet as its final layer. When I apply vmap to jacrev, the number of arguments of the function seems to change and I have no idea why. I followed the per-sample-gradients example here with vmap and jacrev, yet it fails for some reason.

The minimal reproduction example is below:

import torch
from functorch import make_functional, vmap, grad, jacrev
import torch.nn as nn

class model(nn.Module):

  def __init__(self, num_input, num_hidden):
    super(model, self).__init__()

    self.num_input = num_input
    self.num_hidden = num_hidden

    self.layer1 = nn.Linear(2, num_hidden)
    self.layer2 = nn.Linear(num_hidden, num_input)

  def forward(self, x):
    g = x.mean(dim=0, keepdim=True).repeat(self.num_input, 1)
    f = torch.cat((x,g), dim=1)

    y1 = self.layer1(f)
    y2 = self.layer2(y1)

    sgn, logabs = torch.slogdet(y2)
    return sgn, logabs    

net = model(2, 32)

func_model, params = make_functional(net)

def loss(params, x):
  sgn, logabs = func_model(params, x)
  return logabs

x = torch.randn(4096, 2, 1)

per_sample_grads = vmap(jacrev(loss), (None, 0))(params, x)

The traceback is below:

Traceback (most recent call last):
  File "jacrev_fail_example.py", line 36, in <module>
    per_sample_grads = vmap(jacrev(loss), (None, 0))(params, x)
  File "~/.local/lib/python3.8/site-packages/functorch/_src/vmap.py", line 319, in wrapped
    batched_outputs = func(*batched_inputs, **kwargs)
  File "~/.local/lib/python3.8/site-packages/functorch/_src/eager_transforms.py", line 358, in wrapper_fn
    output, vjp_fn = vjp(f_wrapper, *primals)
  File "~/.local/lib/python3.8/site-packages/functorch/_src/eager_transforms.py", line 229, in vjp
    primals_out = f(*diff_primals)
  File "~/.local/lib/python3.8/site-packages/functorch/_src/eager_transforms.py", line 407, in f_wrapper
    replaced_args = _replace_args(args, wrapper_args, argnums)
  File "~/.local/lib/python3.8/site-packages/functorch/_src/eager_transforms.py", line 380, in _replace_args
    raise RuntimeError(f'new_args should be of size 1, was of size {len(new_args)}')
RuntimeError: new_args should be of size 1, was of size 4
zou3519 commented 2 years ago

This is the same issue as #292 I think. I'm planning to put up a fix for that later today

zou3519 commented 2 years ago

Okay, I put up a fix for the specific error you saw. There's another issue though, which is that the implementation of slogdet backward calls tensor.item() and we are unable to vmap over it:

import torch
from functorch import make_functional, vmap, grad, jacrev
import torch.nn as nn

class model(nn.Module):

  def __init__(self, num_input, num_hidden):
    super(model, self).__init__()

    self.num_input = num_input
    self.num_hidden = num_hidden

    self.layer1 = nn.Linear(2, num_hidden)
    self.layer2 = nn.Linear(num_hidden, num_input)

  def forward(self, x):
    g = x.mean(dim=0, keepdim=True).repeat(self.num_input, 1)
    f = torch.cat((x,g), dim=1)

    y1 = self.layer1(f)
    y2 = self.layer2(y1)

    sgn, logabs = torch.slogdet(y2)
    return sgn, logabs    

net = model(2, 32)

func_model, params = make_functional(net)

def loss(params, x):
  sgn, logabs = func_model(params, x)
  return logabs

x = torch.randn(4096, 2, 1)

per_sample_grads = vmap(jacrev(loss), (None, 0))(params, x)
# RuntimeError: vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item()
# on a Tensor, please try to rewrite what you're doing with other operations.
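
For context, the limitation is general rather than specific to slogdet: .item() pulls a single Python number out of a tensor, and there is no single number that can stand in for every sample of a batch. A minimal sketch, with a purely illustrative normalize function:

import torch
from functorch import vmap

def normalize(x):
  total = x.sum().item()  # extracts a Python float from the tensor
  return x / total

normalize(torch.randn(4))  # fine on a single sample

# Under vmap, x is a batched tensor and .item() has no per-sample meaning,
# so this raises the same RuntimeError quoted above:
# vmap(normalize)(torch.randn(8, 4))
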
AlphaBetaGamma96 commented 2 years ago

Yes, after posting I did notice that it looked pretty similar to #292

I only used torch.slogdet within this example as a simplified version of my actual network, which has a custom torch.autograd.Function version of torch.slogdet. Since I manually defined the backward there, it shouldn't call .item(), right?

Should I try that fix with my custom torch.slogdet? Or would there be another issue with using custom functions?

zou3519 commented 2 years ago

autograd.Function is silently incorrect when used with functorch (https://github.com/pytorch/functorch/issues/207). So even if you manually define the backward, functorch just ignores it right now :(

That issue is very tricky to fix, but it's pretty high priority on our radar
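
To make "silently incorrect" concrete, here is a hedged sketch of the failure mode being described (the Square function is purely illustrative): regular autograd honours the custom backward, while functorch transforms at the time traced through the forward's operations instead, so the two results can disagree without any error.

import torch
from functorch import grad

class Square(torch.autograd.Function):
  @staticmethod
  def forward(ctx, x):
    ctx.save_for_backward(x)
    return x ** 2

  @staticmethod
  def backward(ctx, g):
    # Deliberately "wrong" gradient so any mismatch is easy to spot.
    x, = ctx.saved_tensors
    return torch.zeros_like(x)

x = torch.tensor(3.0, requires_grad=True)
Square.apply(x).backward()
print(x.grad)  # tensor(0.) -- regular autograd uses the custom backward

# On functorch versions from the time of this issue, grad() differentiated
# through the forward's ops and ignored the custom backward, so this could
# silently disagree with the value above:
print(grad(Square.apply)(torch.tensor(3.0)))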

AlphaBetaGamma96 commented 2 years ago

Hi @zou3519,

I think I have a solution for removing .item() from slogdet_backward. The only problem is that it requires a batching rule for at::equal. The backward I've just written is below:

Tensor slogdet_backward(const Tensor& grad_logabsdet,
                        const Tensor& self,
                        const Tensor& signdet, const Tensor& logabsdet) {
  auto singular_case_backward = [&](const Tensor& grad_logabsdet, const Tensor& self) -> Tensor {
    Tensor u, sigma, vh;
    std::tie(u, sigma, vh) = at::linalg_svd(self, false);
    Tensor v = vh.mH();
    // sigma has all non-negative entries (also with at least one zero entry)
    // so logabsdet = \sum log(abs(sigma))
    // but det = 0, so backward logabsdet = \sum log(sigma)
    auto gsigma = grad_logabsdet.unsqueeze(-1).div(sigma);
    return svd_backward({}, gsigma, {}, u, sigma, vh);
  };

  auto nonsingular_case_backward = [&](const Tensor& grad_logabsdet, const Tensor& self) -> Tensor {
    // TODO: replace self.inverse with linalg_inverse
    return unsqueeze_multiple(grad_logabsdet, {-1, -2}, self.dim()) * self.inverse().mH();
  };

  if (self.dim() == 2) {
    at::Tensor sing = at::zeros(1, LEGACY_CONTIGUOUS_MEMORY_FORMAT); 
    //bool is_singular = self.is_complex() ? signdet.abs().item<double>() == 0 : signdet.item<double>() == 0; //removed
    bool is_singular = self.is_complex() ? at::equal(signdet.abs(), sing) : at::equal(signdet, sing); //added
    if (is_singular) {
      return singular_case_backward(grad_logabsdet, self);
    } else {
      return nonsingular_case_backward(grad_logabsdet, self);
    }
   } else {
    auto nonzero_signdet_indices = at::native::toListOfOptionalTensors(self.is_complex() ? at::where(signdet.abs()) : at::where(signdet));
    c10::optional<Tensor> first_nonzero_signdet_index = nonzero_signdet_indices[0];

    if (first_nonzero_signdet_index->size(0) == logabsdet.numel()) {  // all log determinants are finite (non-singular)
     return nonsingular_case_backward(grad_logabsdet, self);
    }

    auto zero_signdet_indices = at::native::toListOfOptionalTensors(at::where(signdet == 0));
    c10::optional<Tensor> first_zero_signdet_index = zero_signdet_indices[0];

    if (first_zero_signdet_index->size(0) == logabsdet.numel()) {  // all log determinants are -inf (singular)
      return singular_case_backward(grad_logabsdet, self);
    }

    Tensor grad_slogdet = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

    // invertible case
    grad_slogdet.index_put_(/*indices=*/nonzero_signdet_indices,
                          //NOLINTNEXTLINE(bugprone-argument-comment)
                          /*value=*/nonsingular_case_backward(grad_logabsdet.index(nonzero_signdet_indices),
                           self.index(nonzero_signdet_indices)));

    // non-invertible case, uses SVD
    grad_slogdet.index_put_(/*indices=*/zero_signdet_indices,
                           // NOLINTNEXTLINE(bugprone-argument-comment)
                           /*value=*/singular_case_backward(grad_logabsdet.index(zero_signdet_indices),
                                                            self.index(zero_signdet_indices)));

    return grad_slogdet;
  }
}

It works for calculating jacrev of torch.linalg.slogdet; however, it runs into the aforementioned missing batching rule for at::equal when using vmap. Here's an example:

import torch
from functorch import vmap, jacrev

def logabs(x):
  return torch.linalg.slogdet(x)[1]

x = torch.randn(100,4,4)
jacobian = jacrev(logabs)(x) #return Tensor of shape [100,100,4,4]

torch.manual_seed(0)

x = torch.randn(4,4)
jacobian = jacrev(logabs)(x) #returns Tensor of shape [4,4]
"""
tensor([[-0.6138, -0.1843,  0.1957, -0.3358],
        [ 0.2193, -0.0545, -0.0507, -0.3950],
        [ 0.4635, -0.5424,  0.5355, -0.0715],
        [-0.2209,  0.1132,  0.7572, -0.1647]])
"""

EDIT: (helps if I share the error)

>>> vmap(jacrev(logabs), in_dims=(0))(x)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "~/.local/lib/python3.9/site-packages/functorch/_src/vmap.py", line 383, in wrapped
    batched_outputs = func(*batched_inputs, **kwargs)
  File "~/.local/lib/python3.9/site-packages/functorch/_src/eager_transforms.py", line 440, in wrapper_fn
    results = vmap(vjp_fn)(basis)
  File "~/.local/lib/python3.9/site-packages/functorch/_src/vmap.py", line 383, in wrapped
    batched_outputs = func(*batched_inputs, **kwargs)
  File "~/.local/lib/python3.9/site-packages/functorch/_src/eager_transforms.py", line 288, in wrapper
    result = _autograd_grad(flat_primals_out, flat_diff_primals, flat_cotangents,
  File "~/.local/lib/python3.9/site-packages/functorch/_src/eager_transforms.py", line 99, in _autograd_grad
    grad_inputs = torch.autograd.grad(diff_outputs, inputs, grad_outputs,
  File "~/pytorch_from_source/pytorch/torch/autograd/__init__.py", line 275, in grad
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Batching rule not implemented for aten::equal. We could not generate a fallback.
zou3519 commented 2 years ago

@AlphaBetaGamma96 -- are you still looking to use functorch on a custom autograd.Function? I have a prototype up and running and am looking for test cases :). I remember you had a custom determinant function.

AlphaBetaGamma96 commented 2 years ago

Hi @zou3519, I indeed am! Shall I upgrade to the latest nightly?

zou3519 commented 2 years ago

Thanks for the quick reply! The prototype is on a branch; we haven't shipped it to nightlies yet. Would it be possible to see the autograd.Function you were interested in using vmap with (and some sample inputs)?

AlphaBetaGamma96 commented 2 years ago

Hi @zou3519, I do have an example in mind. I'll have a look for it and share a minimal reproducible example below as soon as I find it!

AlphaBetaGamma96 commented 2 years ago

Hi @zou3519,

Sorry this is quite late, but I had to rewrite it. I've attached an example below, which was run with torch = 1.13.0.dev20221003 and CUDA = 11.6.

In short, it's a custom torch.linalg.det which can handle singular matrices by using an SVD decomposition within its gradient. I do remember that with an older version of PyTorch, torch.linalg.det would fail on singular matrices, but it seems to be fine with them now. So perhaps recent updates have made this use case redundant; it's still a nice example, though.

My main goal would be to vmap over a more complicated function (example on my GitHub here) to compute the Laplacian of the determinant with respect to its input, although that's a little more involved!

I can add a custom torch.linalg.det function with a custom numerically stable DoubleBackward if you want to test nested torch.autograd.Function too, just let me know!

import torch
from torch import Tensor
_ = torch.manual_seed(0)

import functorch
from functorch import vmap, jacrev

#helper functions here

#check of https://www.tensorflow.org/api_docs/python/tf/math/cumsum
def exclusive_cumsum(x: torch.Tensor, dim: int):
  res = x.cumsum(dim=dim).roll(1) #roll(1) and res[...,0]=0. make it exclusive btw
  res[...,0]=0.
  return res

#not the same as exclusive_cumsum().flip() 
#see here: https://www.tensorflow.org/api_docs/python/tf/math/cumsum
def reverse_exclusive_cumsum(x: torch.Tensor, dim: int):
  res = x.flip(dim).cumsum(dim=dim)
  res[...,-1]=0.
  res=res.roll(1).flip(-1) 
  return res

def get_log_sigma(sigma: torch.Tensor):
  return torch.log(sigma)

#modified from https://github.com/deepmind/ferminet/blob/tf/ferminet/networks.py#L143-L192  
def get_log_gamma(log_sigma: torch.Tensor):
  lower = exclusive_cumsum(log_sigma, dim=-1)
  upper = reverse_exclusive_cumsum(log_sigma, dim=-1)
  log_gamma = lower + upper

  return log_gamma

class CustomDeterminant(torch.autograd.Function):

  @staticmethod
  def forward(ctx, matrix):
    U, S, VT = torch.linalg.svd(matrix)
    detU = torch.linalg.det(U)
    detVT = torch.linalg.det(VT)

    ctx.save_for_backward(U, S, VT, detU, detVT)

    det =  detU * detVT * torch.linalg.det(torch.diag_embed(S))

    return det

  @staticmethod
  def backward(ctx, DetBar):
    U, S, VT, detU, detVT = ctx.saved_tensors

    ln_S = get_log_sigma(S)
    gamma = torch.diag_embed( torch.exp(get_log_gamma(ln_S)) )

    Abar = DetBar * detU * detVT * U @ gamma @ VT

    return Abar

def pytorch_det(matrix: Tensor) -> Tensor:
  return torch.linalg.det(matrix)

def custom_det(matrix: Tensor) -> Tensor:
  return CustomDeterminant.apply(matrix)

#====================================================================================#

N = 4
matrix = torch.randn(N, N, dtype=torch.float64, requires_grad=True)

backward_check = torch.autograd.gradcheck(func=custom_det, inputs=(matrix), raise_exception=False)

print("grad check: ",backward_check)

naive_jac = torch.autograd.functional.jacobian(func=pytorch_det, inputs=(matrix))
custom_jac = torch.autograd.functional.jacobian(func=custom_det, inputs=(matrix))

print("\nCheck custom function vs naive function")
print("For larger matrices the difference will become more apparent\n")

for i in range(N):
  print("Jacobian check: ", i, torch.allclose(naive_jac[i], custom_jac[i]))
  print(naive_jac[i], custom_jac[i], "\n")

#====================================================================================#

which outputs:

grad check:  True

Check custom function vs naive function
For larger matrices the difference will become more apparent

Jacobian check:  0 True
tensor([ 1.4697,  1.2719, -6.0033,  0.9463], dtype=torch.float64) tensor([ 1.4697,  1.2719, -6.0033,  0.9463], dtype=torch.float64) 

Jacobian check:  1 True
tensor([-3.6909, -3.2754,  8.0690, -1.1910], dtype=torch.float64) tensor([-3.6909, -3.2754,  8.0690, -1.1910], dtype=torch.float64) 

Jacobian check:  2 True
tensor([-2.3848, -1.7151,  6.2400,  0.4697], dtype=torch.float64) tensor([-2.3848, -1.7151,  6.2400,  0.4697], dtype=torch.float64) 

Jacobian check:  3 True
tensor([ 0.4287, -1.9187,  1.1866,  1.5338], dtype=torch.float64) tensor([ 0.4287, -1.9187,  1.1866,  1.5338], dtype=torch.float64) 
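
Continuing the script above, the transform this example is ultimately meant to exercise would look something like the following (this is the part that needs the autograd.Function prototype; on releases without it, functorch either errors or silently ignores the custom backward, as discussed earlier):

batch = torch.randn(8, N, N, dtype=torch.float64)

# Per-matrix Jacobian of the custom determinant; expected shape (8, N, N).
per_matrix_jac = vmap(jacrev(custom_det))(batch)
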
zou3519 commented 2 years ago

No need to apologize and thank you for the example! Time to test it out and see how it goes

AlphaBetaGamma96 commented 1 year ago

Hi @zou3519, any update on this?

zou3519 commented 1 year ago

Hey @AlphaBetaGamma96,

Sorry for the delayed reply, I was out on vacation last week. I haven't gotten around to testing your example, but here's a quick update:

There are two types of autograd.Function support we want to add for functorch:

  1. Users should be able to define a "vmap" rule on an autograd.Function. This is useful if the autograd.Function does not use PyTorch operations and instead calls some C++/CUDA code.
  2. Users should be able to take an autograd.Function that uses only PyTorch operations and use it with functorch transforms without defining a vmap rule.

Your use case is (2).

Implementing (1) is a stepping stone to implementing (2). I'm currently working through shipping (1) to PyTorch (some of the work is being done over at https://github.com/pytorch/pytorch/pull/88785), and (2) is coming soon after. The goal is to complete this work before the end of the year.
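
For anyone reading this later, option (1) corresponds roughly to the following shape of code in the API that eventually shipped (a sketch based on the released autograd.Function docs rather than the prototype branch discussed here; torch.sin stands in for an opaque C++/CUDA kernel, and MySin is just an illustrative name):

import torch

class MySin(torch.autograd.Function):
  @staticmethod
  def forward(x):
    # In the real use case this would call into custom C++/CUDA code
    # that vmap cannot look inside.
    return torch.sin(x)

  @staticmethod
  def setup_context(ctx, inputs, output):
    x, = inputs
    ctx.save_for_backward(x)

  @staticmethod
  def backward(ctx, grad_out):
    x, = ctx.saved_tensors
    return grad_out * torch.cos(x)

  @staticmethod
  def vmap(info, in_dims, x):
    # Option (1): a hand-written vmap rule. Move the batch dim to the front
    # (assume x is batched here), run the op once over the whole batch, and
    # report where the batch dim of the output lives.
    x_bdim, = in_dims
    x = x.movedim(x_bdim, 0)
    return torch.sin(x), 0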

zou3519 commented 1 year ago

@AlphaBetaGamma96 -- I've been testing your example and everything looks good so far. Here's a quick summary if you're interested in trying it out more:

I'd be curious to hear your feedback before it officially ships in the next release. Documentation will be coming soon, but I'm happy to answer questions.

AlphaBetaGamma96 commented 1 year ago

Hi @zou3519, apologies for the late response. The refactor of the example looks great! Can I ask a few questions?

  1. Is the feature flag represented as the class attribute generate_vmap_rule=True?

  2. For my more complicated example (to represent a 2nd derivative), is it easy to nest this for a manually defined 2nd derivative? I.e., just define another torch.autograd.Function object in the backward method of the first function, like previous implementations, but add in the setup_context method?

  3. For computing higher-order gradients (for example 3rd order), I would need to define a triple-backward manually, right? Otherwise, the returned gradient would assume that the 3rd-order derivative is zero?

  4. With the setup_context method, I see that the ctx.save_for_backward is within that method rather than the forward method like standard torch.autograd.Function objects. Is this the new default syntax?

  5. Does the forward method need a ctx argument like the setup_context and backward methods, or does it not, as it's a static method? (And is the saving of tensors now handled by the setup_context method instead of the forward method?)

  6. Can I have a link to the pytorch slack? I remember you sent it a while back, but I can't find it! Thanks!

Edit:

  1. One function I want to implement makes use of a functools.lru_cache; do you think this will work with vmap? I do have a way of circumventing the lru_cache with PyTorch primitives, but I was just curious (as it works without vmap).

  2. Also, is this torch.autograd.Function fully supported with torch.compile?

zou3519 commented 1 year ago

We have docs now, that should answer the questions :) https://pytorch.org/docs/master/notes/extending.func.html

Let me reply to them one-by-one:

Is the feature flag represented as the class attribute generate_vmap_rule=True?

No need to worry about the feature flag anymore. That class attribute toggles whether you define a custom vmap staticmethod yourself or ask PyTorch to generate one for you.

For my more complicated example (to represent a 2nd derivative), is it easy to nest this for a manually defined 2nd derivative? I.e., just define another torch.autograd.Function object in the backward method of the first function, like previous implementations, but add in the setup_context method?

Yes, exactly.

For computing higher-order gradients (for example 3rd order), I would need to define a triple-backward manually, right? Otherwise, the returned gradient would assume that the 3rd-order derivative is zero?

It depends. If your second order backward involves PyTorch operations, then the 3rd-order derivative will be computed by PyTorch backpropping through those operations.

With the setup_context method, I see that the ctx.save_for_backward is within that method rather than the forward method like standard torch.autograd.Function objects. Is this the new default syntax?

Yes, the syntax is different. We are pushing this forward as a "new default syntax" for autograd.Function, but will support both syntaxes going forward.

Can I have a link to the pytorch slack? I remember you sent it a while back, but I can't find it! Thanks!

Can you send me your email address and I'll manually add you? Mine is my GitHub username @ gmail.com

Does the forward method need a ctx argument like the setup_context and backward methods, or does it not, as it's a static method? (And is the saving of tensors now handled by the setup_context method instead of the forward method?)

Forward method does not need a ctx argument. The saving of Tensors is now handled by setup_context.

One function I want to implement makes use of a functools.lru_cache; do you think this will work with vmap? I do have a way of circumventing the lru_cache with PyTorch primitives, but I was just curious (as it works without vmap).

Could you provide more details about what you're doing with lru_cache? Just memoizing the inputs of the function?

Also, is this torch.autograd.Function fully supported with torch.compile?

torch.autograd.Function is not optimized by torch.compile today in most cases. This new autograd.Function has the same support for torch.compile as the old one. Support for all autograd.Function with torch.compile is somewhere on our roadmap.
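
To tie the answers above back to the earlier example, here is a rough port of the custom determinant to the new syntax (a sketch only: CustomDeterminantV2 is an illustrative name, it recomputes things in backward, and it uses the plain non-singular gradient det(A) * inverse(A)^T rather than the SVD-based one from earlier, purely to show where generate_vmap_rule, the ctx-less forward, and setup_context go):

import torch
from functorch import vmap, jacrev

class CustomDeterminantV2(torch.autograd.Function):
  # Question 1: this class attribute asks PyTorch to generate the vmap rule.
  generate_vmap_rule = True

  @staticmethod
  def forward(matrix):
    # Question 5: the new-style forward takes no ctx argument.
    return torch.linalg.det(matrix)

  @staticmethod
  def setup_context(ctx, inputs, output):
    # Question 4: saving for backward now happens here, not in forward.
    matrix, = inputs
    ctx.save_for_backward(matrix, output)

  @staticmethod
  def backward(ctx, det_bar):
    matrix, det = ctx.saved_tensors
    # Non-singular gradient only, for brevity; the SVD-based gradient from
    # earlier in the thread would slot in here to handle singular matrices.
    return det_bar * det * torch.linalg.inv(matrix).transpose(-2, -1)

x = torch.randn(8, 4, 4, dtype=torch.float64)
jac = vmap(jacrev(CustomDeterminantV2.apply))(x)  # shape (8, 4, 4)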

AlphaBetaGamma96 commented 1 year ago

Apologies for the late reply, @zou3519!

We have docs now, that should answer the questions :) https://pytorch.org/docs/master/notes/extending.func.html

I shall give this a read!

Yes, the syntax is different. We are pushing this forward as a "new default syntax" for autograd.Function, but will support both syntaxes going forward.

I have attempted to port my custom function over to the new default syntax, although I have noticed a broadcasting issue when computing gradients via vmap or .backward(). (I can share example code for clarity if needed!)

Can you email me your email and I'll manually add you? Mine is my github username @ gmail.com

Email has been sent now.

Could you provide more details about what you're doing with lru_cache? Just memoizing the inputs of the function?

I've managed to create a PyTorch version of the code that circumvents the need for the lru_cache. In short, it was used to speed up the calculation of a matrix (involved in the derivatives of a determinant function) that is constructed via permutations of a vector. Oddly, it turns out to be quicker to use PyTorch operations to compute it directly than to go through an lru_cache.