AlphaBetaGamma96 opened this issue 2 years ago
This is the same issue as #292 I think. I'm planning to put up a fix for that later today
Okay, I put up a fix for the specific error you saw. There's another issue though, which is that the implementation of slogdet backward calls tensor.item() and we are unable to vmap over it:
import torch
from functorch import make_functional, vmap, grad, jacrev
import torch.nn as nn

class model(nn.Module):
    def __init__(self, num_input, num_hidden):
        super(model, self).__init__()
        self.num_input = num_input
        self.num_hidden = num_hidden
        self.layer1 = nn.Linear(2, num_hidden)
        self.layer2 = nn.Linear(num_hidden, num_input)

    def forward(self, x):
        g = x.mean(dim=0, keepdim=True).repeat(self.num_input, 1)
        f = torch.cat((x, g), dim=1)
        y1 = self.layer1(f)
        y2 = self.layer2(y1)
        sgn, logabs = torch.slogdet(y2)
        return sgn, logabs

net = model(2, 32)
func_model, params = make_functional(net)

def loss(params, x):
    sgn, logabs = func_model(params, x)
    return logabs

x = torch.randn(4096, 2, 1)
per_sample_grads = vmap(jacrev(loss), (None, 0))(params, x)
# RuntimeError: vmap: It looks like you're calling .item() on a Tensor. We don't support vmap over calling .item()
# on a Tensor, please try to rewrite what you're doing with other operations.
Yes, after posting I did notice that it looked pretty similar to #292. I only used torch.slogdet within this example as a simplified version of my actual network, which has a custom torch.autograd.Function version of torch.slogdet. As I manually defined the backward there, it shouldn't call .item()? Should I try the following fix with my custom torch.slogdet, or would there be another issue with using custom functions?
autograd.Function is silently incorrect when used with functorch (https://github.com/pytorch/functorch/issues/207). So even if you manually define the backward, functorch just ignores it right now :(
That issue is very tricky to fix, but it's pretty high priority on our radar
Hi @zou3519,
I think I have a solution for removing .item from slogdet_backward. The only problem is that it requires a batching rule for at::equal. The backward I've just written is below:
Tensor slogdet_backward(const Tensor& grad_logabsdet,
                        const Tensor& self,
                        const Tensor& signdet, const Tensor& logabsdet) {
  auto singular_case_backward = [&](const Tensor& grad_logabsdet, const Tensor& self) -> Tensor {
    Tensor u, sigma, vh;
    std::tie(u, sigma, vh) = at::linalg_svd(self, false);
    Tensor v = vh.mH();
    // sigma has all non-negative entries (also with at least one zero entry)
    // so logabsdet = \sum log(abs(sigma))
    // but det = 0, so backward logabsdet = \sum log(sigma)
    auto gsigma = grad_logabsdet.unsqueeze(-1).div(sigma);
    return svd_backward({}, gsigma, {}, u, sigma, vh);
  };

  auto nonsingular_case_backward = [&](const Tensor& grad_logabsdet, const Tensor& self) -> Tensor {
    // TODO: replace self.inverse with linalg_inverse
    return unsqueeze_multiple(grad_logabsdet, {-1, -2}, self.dim()) * self.inverse().mH();
  };

  if (self.dim() == 2) {
    at::Tensor sing = at::zeros(1, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    // bool is_singular = self.is_complex() ? signdet.abs().item<double>() == 0 : signdet.item<double>() == 0; // removed
    bool is_singular = self.is_complex() ? at::equal(signdet.abs(), sing) : at::equal(signdet, sing); // added
    if (is_singular) {
      return singular_case_backward(grad_logabsdet, self);
    } else {
      return nonsingular_case_backward(grad_logabsdet, self);
    }
  } else {
    auto nonzero_signdet_indices = at::native::toListOfOptionalTensors(self.is_complex() ? at::where(signdet.abs()) : at::where(signdet));
    c10::optional<Tensor> first_nonzero_signdet_index = nonzero_signdet_indices[0];

    if (first_nonzero_signdet_index->size(0) == logabsdet.numel()) {  // all log determinants are finite (non-singular)
      return nonsingular_case_backward(grad_logabsdet, self);
    }

    auto zero_signdet_indices = at::native::toListOfOptionalTensors(at::where(signdet == 0));
    c10::optional<Tensor> first_zero_signdet_index = zero_signdet_indices[0];

    if (first_zero_signdet_index->size(0) == logabsdet.numel()) {  // all log determinants are -inf (singular)
      return singular_case_backward(grad_logabsdet, self);
    }

    Tensor grad_slogdet = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

    // invertible case
    grad_slogdet.index_put_(/*indices=*/nonzero_signdet_indices,
                            // NOLINTNEXTLINE(bugprone-argument-comment)
                            /*value=*/nonsingular_case_backward(grad_logabsdet.index(nonzero_signdet_indices),
                                                                self.index(nonzero_signdet_indices)));

    // non-invertible case, uses SVD
    grad_slogdet.index_put_(/*indices=*/zero_signdet_indices,
                            // NOLINTNEXTLINE(bugprone-argument-comment)
                            /*value=*/singular_case_backward(grad_logabsdet.index(zero_signdet_indices),
                                                             self.index(zero_signdet_indices)));

    return grad_slogdet;
  }
}
It works for calculating jacrev of torch.linalg.slogdet; however, it runs into the aforementioned BatchingRule error when using vmap. Here's an example:
import torch
from functorch import vmap, jacrev

def logabs(x):
    return torch.linalg.slogdet(x)[1]

x = torch.randn(100, 4, 4)
jacobian = jacrev(logabs)(x)  # returns Tensor of shape [100, 100, 4, 4]

torch.manual_seed(0)
x = torch.randn(4, 4)
jacobian = jacrev(logabs)(x)  # returns Tensor of shape [4, 4]
"""
tensor([[-0.6138, -0.1843,  0.1957, -0.3358],
        [ 0.2193, -0.0545, -0.0507, -0.3950],
        [ 0.4635, -0.5424,  0.5355, -0.0715],
        [-0.2209,  0.1132,  0.7572, -0.1647]])
"""
EDIT: (helps if I share the error)
>>> vmap(jacrev(logabs), in_dims=(0))(x)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "~/.local/lib/python3.9/site-packages/functorch/_src/vmap.py", line 383, in wrapped
batched_outputs = func(*batched_inputs, **kwargs)
File "~/.local/lib/python3.9/site-packages/functorch/_src/eager_transforms.py", line 440, in wrapper_fn
results = vmap(vjp_fn)(basis)
File "~/.local/lib/python3.9/site-packages/functorch/_src/vmap.py", line 383, in wrapped
batched_outputs = func(*batched_inputs, **kwargs)
File "~/.local/lib/python3.9/site-packages/functorch/_src/eager_transforms.py", line 288, in wrapper
result = _autograd_grad(flat_primals_out, flat_diff_primals, flat_cotangents,
File "~/.local/lib/python3.9/site-packages/functorch/_src/eager_transforms.py", line 99, in _autograd_grad
grad_inputs = torch.autograd.grad(diff_outputs, inputs, grad_outputs,
File "~/pytorch_from_source/pytorch/torch/autograd/__init__.py", line 275, in grad
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: Batching rule not implemented for aten::equal. We could not generate a fallback.
@AlphaBetaGamma96 -- are you still looking to use functorch on a custom autograd.Function? I have a prototype up and running and am looking for test cases :) I remember you had a custom determinant function.
Hi @zou3519, I indeed am! Shall I upgrade to the latest nightly?
Thanks for the quick reply! The prototype is on a branch; we haven't shipped it to nightlies yet. Would it be possible to see the autograd.Function you were interested in using vmap with (and some sample inputs)?
Hi @zou3519, I do have an example in mind. I'll have a look for it and share a minimal reproducible example below as soon as I find it!
Hi @zou3519,
Sorry for this being quite late, but I had to rewrite it. I've attached an example below, which is run with the current versions of torch = 1.13.0.dev20221003 and CUDA = 11.6.
In short, it's a custom torch.linalg.det which can handle singular matrices by using an SVD decomposition within its gradient. I do remember that with an older version of PyTorch, torch.linalg.det would fail with singular matrices, but it seems to be fine with them now. So perhaps the use case of this example has been made redundant by recent updates; however, it's still a nice example.
My main goal would be to vmap over a more complicated function (example on my github here) to compute the Laplacian of the determinant with respect to its input, although that's a little more complicated!
I can add a custom torch.linalg.det function with a custom numerically stable DoubleBackward if you want to test nested torch.autograd.Function too, just let me know!
import torch
from torch import Tensor
_ = torch.manual_seed(0)

import functorch
from functorch import vmap, jacrev

# helper functions here

# check of https://www.tensorflow.org/api_docs/python/tf/math/cumsum
def exclusive_cumsum(x: torch.Tensor, dim: int):
    res = x.cumsum(dim=dim).roll(1)  # roll(1) and res[..., 0] = 0. make it exclusive btw
    res[..., 0] = 0.
    return res

# not the same as exclusive_cumsum().flip()
# see here: https://www.tensorflow.org/api_docs/python/tf/math/cumsum
def reverse_exclusive_cumsum(x: torch.Tensor, dim: int):
    res = x.flip(dim).cumsum(dim=dim)
    res[..., -1] = 0.
    res = res.roll(1).flip(-1)
    return res

def get_log_sigma(sigma: torch.Tensor):
    return torch.log(sigma)

# modified from https://github.com/deepmind/ferminet/blob/tf/ferminet/networks.py#L143-L192
def get_log_gamma(log_sigma: torch.Tensor):
    lower = exclusive_cumsum(log_sigma, dim=-1)
    upper = reverse_exclusive_cumsum(log_sigma, dim=-1)
    log_gamma = lower + upper
    return log_gamma

class CustomDeterminant(torch.autograd.Function):
    @staticmethod
    def forward(ctx, matrix):
        U, S, VT = torch.linalg.svd(matrix)
        detU = torch.linalg.det(U)
        detVT = torch.linalg.det(VT)
        ctx.save_for_backward(U, S, VT, detU, detVT)
        det = detU * detVT * torch.linalg.det(torch.diag_embed(S))
        return det

    @staticmethod
    def backward(ctx, DetBar):
        U, S, VT, detU, detVT = ctx.saved_tensors
        ln_S = get_log_sigma(S)
        gamma = torch.diag_embed(torch.exp(get_log_gamma(ln_S)))
        Abar = DetBar * detU * detVT * U @ gamma @ VT
        return Abar

def pytorch_det(matrix: Tensor) -> Tensor:
    return torch.linalg.det(matrix)

def custom_det(matrix: Tensor) -> Tensor:
    return CustomDeterminant.apply(matrix)

#====================================================================================#
N = 4
matrix = torch.randn(N, N, dtype=torch.float64, requires_grad=True)

backward_check = torch.autograd.gradcheck(func=custom_det, inputs=(matrix), raise_exception=False)
print("grad check: ", backward_check)

naive_jac = torch.autograd.functional.jacobian(func=pytorch_det, inputs=(matrix))
custom_jac = torch.autograd.functional.jacobian(func=custom_det, inputs=(matrix))

print("\nCheck custom function vs naive function")
print("For larger matrices the difference will become more apparent\n")
for i in range(N):
    print("Jacobian check: ", i, torch.allclose(naive_jac[i], custom_jac[i]))
    print(naive_jac[i], custom_jac[i], "\n")
#====================================================================================#
which outputs,
grad check: True
Check custom function vs naive function
For larger matrices the difference will become more apparent
Jacobian check: 0 True
tensor([ 1.4697, 1.2719, -6.0033, 0.9463], dtype=torch.float64) tensor([ 1.4697, 1.2719, -6.0033, 0.9463], dtype=torch.float64)
Jacobian check: 1 True
tensor([-3.6909, -3.2754, 8.0690, -1.1910], dtype=torch.float64) tensor([-3.6909, -3.2754, 8.0690, -1.1910], dtype=torch.float64)
Jacobian check: 2 True
tensor([-2.3848, -1.7151, 6.2400, 0.4697], dtype=torch.float64) tensor([-2.3848, -1.7151, 6.2400, 0.4697], dtype=torch.float64)
Jacobian check: 3 True
tensor([ 0.4287, -1.9187, 1.1866, 1.5338], dtype=torch.float64) tensor([ 0.4287, -1.9187, 1.1866, 1.5338], dtype=torch.float64)
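(Aside for readers: the gamma factor built from get_log_gamma above is the vector of products of all singular values except the i-th one, computed in log space for numerical stability. The following quick check of that is an illustrative snippet reusing the helpers from the example, not part of the original post.)

import torch

S = torch.tensor([3.0, 0.5, 2.0, 4.0], dtype=torch.float64)
# get_log_gamma / get_log_sigma are the helpers defined in the example above
gamma = torch.exp(get_log_gamma(get_log_sigma(S)))

# gamma[i] should equal the product of all entries of S except S[i]
expected = torch.stack([S.prod() / S[i] for i in range(len(S))])
print(torch.allclose(gamma, expected))  # True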
No need to apologize and thank you for the example! Time to test it out and see how it goes
Hi @zou3519, any update on this?
Hey @AlphaBetaGamma96,
Sorry for the delayed reply, I was out on vacation last week. I haven't gotten around to testing your example, but here's a quick update:
There are two types of autograd.Function we want to add support for in functorch:
Your use case is (2).
Implementing (1) is on the way to implementing (2). I'm currently working through shipping (1) to PyTorch (some work is being done over at https://github.com/pytorch/pytorch/pull/88785) and (2) is coming soon after. The goal is to complete this work before the end of the year.
@AlphaBetaGamma96 -- I've been testing your example and everything looks good so far. Here's a quick summary if you're interested in trying it out more: the main change is splitting the forward function into a forward and a separate setup_context staticmethod. Would be curious to hear your feedback before we officially release it in the next release. Documentation will be coming soon, but happy to answer questions.
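For readers following along, here is a rough sketch of what the CustomDeterminant example above might look like after that refactor, based on the setup_context API described in the docs linked later in this thread. It is an illustration under those assumptions rather than the exact code that was tested; for simplicity it saves the input matrix and recomputes the SVD in backward instead of passing intermediates along.

import torch

class CustomDeterminant(torch.autograd.Function):
    # Ask PyTorch to generate a vmap rule for this Function.
    generate_vmap_rule = True

    @staticmethod
    def forward(matrix):
        # New-style forward: no ctx argument, just compute and return the output.
        U, S, VT = torch.linalg.svd(matrix)
        return torch.linalg.det(U) * torch.linalg.det(VT) * torch.linalg.det(torch.diag_embed(S))

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Saving for backward now happens here instead of in forward.
        (matrix,) = inputs
        ctx.save_for_backward(matrix)

    @staticmethod
    def backward(ctx, DetBar):
        (matrix,) = ctx.saved_tensors
        # Recompute the SVD pieces; get_log_gamma / get_log_sigma are the
        # helpers from the example earlier in this thread.
        U, S, VT = torch.linalg.svd(matrix)
        detU, detVT = torch.linalg.det(U), torch.linalg.det(VT)
        gamma = torch.diag_embed(torch.exp(get_log_gamma(get_log_sigma(S))))
        return DetBar * detU * detVT * U @ gamma @ VT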
Hi @zou3519, apologies for the late response. The refactor of the example looks great! Can I ask a few questions?
Is the feature flag represented as the class attribute generate_vmap_rule=True?
For my more complicated example (to represent a 2nd derivative), is it easy to nest this for a manually defined 2nd derivative? I.e., just define another torch.autograd.Function object in the backward method of the first function, like previous implementations, but with the setup_context method added in?
For computing higher-order gradients (for example 3rd order), I would need to define a triple-backward manually, right? Otherwise, the returned gradient would assume that the 3rd-order derivative is zero?
With the setup_context method, I see that the ctx.save_for_backward is within that method rather than the forward method like standard torch.autograd.Function objects. Is this the new default syntax?
Does the forward method need a ctx argument like the setup_context and backward methods, or does it not as it's a static method? (And is the saving of tensors now handled by the setup_context method instead of the forward method?)
Can I have a link to the pytorch slack? I remember you sent it a while back, but I can't find it! Thanks!
Edit:
One function I want to implement contains the use of a functools.lru_cache; do you think this will work with vmap? I do have a way of circumventing the lru_cache with pytorch primitives, but I was just curious (as it works without vmap).
Also, is this torch.autograd.Function fully supported by torch.compile?
We have docs now, that should answer the questions :) https://pytorch.org/docs/master/notes/extending.func.html
Let me reply to them one-by-one:
Is the feature flag represented as the class attribute generate_vmap_rule=True?
No need to worry about the feature flag anymore. That class attribute toggles whether you want to define a custom vmap staticmethod or ask PyTorch to generate one for you.
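To illustrate the two options, here is a small sketch following the staticmethod signatures in the linked docs; ScaledExp and its scale argument are made-up names for this example, not anything from the thread.

import torch

class ScaledExp(torch.autograd.Function):
    # Option 1: uncomment this flag to have PyTorch generate the vmap rule
    # (and delete the vmap staticmethod below).
    # generate_vmap_rule = True

    @staticmethod
    def forward(x, scale: float):
        return scale * torch.exp(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        ctx.save_for_backward(output)

    @staticmethod
    def backward(ctx, grad_out):
        (out,) = ctx.saved_tensors
        # d(scale * exp(x))/dx = scale * exp(x) = out; no gradient for the float scale
        return grad_out * out, None

    # Option 2: a hand-written batching rule. It receives the batched args plus
    # which dim of each arg is the batch dim, and returns (output, out_dims).
    @staticmethod
    def vmap(info, in_dims, x, scale):
        x_bdim, _ = in_dims
        # the op is elementwise in x, so the batch dim passes straight through
        x = x.movedim(x_bdim, 0)
        return ScaledExp.apply(x, scale), 0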
For my more complicated example (to represent a 2nd derivative), is it easy to nest this for a manually defined 2nd derivative? I.e., just define another torch.autograd.Function object in the backward method of the first function, like previous implementations, but with the setup_context method added in?
Yes, exactly.
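To make the nesting concrete, here is a toy sketch in the new syntax, using exp as a stand-in op so the formulas stay short; the structure (an inner autograd.Function applied inside the outer backward) is the point, not the math.

import torch

class MyOpBackward(torch.autograd.Function):
    # Inner Function: computes the gradient of MyOp, and its own backward
    # is the manually defined second derivative of MyOp.
    generate_vmap_rule = True

    @staticmethod
    def forward(x, grad_out):
        # first-derivative formula: d/dx exp(x) = exp(x)
        return grad_out * torch.exp(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, grad_out = inputs
        ctx.save_for_backward(x, grad_out)

    @staticmethod
    def backward(ctx, gg):
        x, grad_out = ctx.saved_tensors
        # manually written second-derivative formulas for x and grad_out
        return gg * grad_out * torch.exp(x), gg * torch.exp(x)

class MyOp(torch.autograd.Function):
    generate_vmap_rule = True

    @staticmethod
    def forward(x):
        return torch.exp(x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # Nest the inner Function here so its backward defines MyOp's 2nd derivative.
        return MyOpBackward.apply(x, grad_out)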
For computing higher-order gradients (for example 3rd order), I would need to define a triple-backward manually, right? Otherwise, the returned gradient would assume that the 3rd-order derivative is zero?
It depends. If your second order backward involves PyTorch operations, then the 3rd-order derivative will be computed by PyTorch backpropping through those operations.
With the setup_context method, I see that the ctx.save_for_backward is within that method rather than the forward method like standard torch.autograd.Function objects. Is this the new default syntax?
Yes, the syntax is different. We are pushing this forward as a "new default syntax" for autograd.Function, but will support both syntaxes going forward.
Can I have a link to the pytorch slack? I remember you sent it a while back, but I can't find it! Thanks!
Can you email me your email and I'll manually add you? Mine is my github username @ gmail.com
Does the forward method need a ctx argument like the setup_context and backward methods, or does it not as it's a static method? (And is the saving of tensors now handled by the setup_context method instead of the forward method?)
Forward method does not need a ctx argument. The saving of Tensors is now handled by setup_context.
One function I want to implement contains the use of a functools.lru_cache; do you think this will work with vmap? I do have a way of circumventing the lru_cache with pytorch primitives, but I was just curious (as it works without vmap).
Could you provide more details as to what you're doing with lru_cache? Just memoizing the inputs of the function?
Also, is this torch.autograd.Function fully supported by torch.compile?
torch.autograd.Function is not optimized by torch.compile today in most cases. This new autograd.Function has the same support for torch.compile as the old one. Support for all autograd.Function with torch.compile is somewhere on our roadmap.
Apologies for the late reply @zou3519!
We have docs now, that should answer the questions :) https://pytorch.org/docs/master/notes/extending.func.html
I shall give this a read!
Yes, the syntax is different. We are pushing this forward as a "new default syntax" for autograd.Function, but will support both syntaxes going forward.
I have attempted to port over my custom function to the new default syntax, although I have noticed a broadcasting issue when computing gradients via vmap or .backward(). (I can share example code for clarity if needed!)
Can you email me your email and I'll manually add you? Mine is my github username @ gmail.com
Email has been sent now.
Could you provide more details as to what you're doing with lru_cache? Just memoizing the inputs of the function?
I've managed to create a pytorch version of the code that circumvents the need for the lru_cache. In short, it was used to speed up the calculation of a matrix (involved in the derivatives of a determinant function) that is constructed via permutations of a vector. It turns out, weirdly, that it is quicker to use pytorch operations to compute it directly rather than to use an lru_cache.
Hi All,
I've been trying to get per-sample gradients of a simple feed-forward network that contains a torch.slogdet as its final layer. When I go to apply vmap to jacrev, the number of arguments of the function changes and I have no idea why. I followed the example here with vmap and jacrev, yet it seems to fail for some reason. The minimal reproduction example is here.
The Traceback is below,