pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Applying grad elementwise to tensors of arbitrary shape #1070

Closed EmilienDupont closed 1 year ago

EmilienDupont commented 1 year ago

What is the easiest way to apply the grad of a function elementwise to a tensor of arbitrary shape? For example

import torch
from functorch import grad, vmap

# These functions can be called with a tensor of any shape and will be applied elementwise
sin = torch.sin
cos = torch.cos

# Create cos function by using grad
cos_from_grad = grad(sin)

x = torch.rand([4, 2])

# This is fine
out = sin(x)
out = cos(x)

# This throws error
# Expected f(*args) to return a scalar Tensor, got tensor with 2 dims
out = cos_from_grad(x)

Now in this specific case, where we have a tensor of shape (4, 2), we can use vmap twice

cos_from_grad = vmap(vmap(grad(sin)))

# This now works
out = cos_from_grad(x)

However, if I later need to call cos_from_grad on a tensor of shape (4, 2, 3) for example, then the above code will no longer work as I would need to add an extra vmap. Is there a way to use grad to create a cos function that is equivalent to torch.cos in the sense that it can be applied elementwise to tensors of arbitrary shape?

Thank you!

samdow commented 1 year ago

Yeah, so to clarify: the issue happens when the output of the function is not a scalar tensor (Size []). What the double vmap is doing is treating each element of the input like a scalar tensor, so the output is then seen as a scalar. There are a couple of options here:

(1) If you always want each output value to be treated equally, you can do cos_from_grad = grad(lambda x: sin(x).sum()). This function will always produce a scalar value, so it works for inputs of any shape (a quick check of both options is sketched below).

(2) If you want to treat the output values differently, try using vjp: vjp(sin, x)[1](tangents). Here the tangents are the seed for backprop, so they should be the same size as the output. Specifically, if you want to get the same behavior as the summed version but with vjp:

from functorch import vjp

x = torch.randn([4, 2])
out, vjp_fn = vjp(sin, x)
# vjp_fn returns a tuple with one gradient per primal input
grad, = vjp_fn(torch.ones_like(out))

Then by adjusting what you pass to vjp_fn, you can change the weight each output is given in the backprop computation.
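
For reference, here's a minimal sketch that checks option (1) against torch.cos for a couple of shapes and shows how a custom cotangent weights the result of option (2). This is just an illustration, assuming functorch is installed alongside torch:

import torch
from functorch import grad, vmap, vjp

# Option (1): the summed version works for any input shape and matches torch.cos
cos_from_grad = grad(lambda x: torch.sin(x).sum())
for shape in ([4, 2], [4, 2, 3]):
    x = torch.randn(shape)
    assert torch.allclose(cos_from_grad(x), torch.cos(x))

# Option (2): a non-uniform cotangent weights each output's contribution to the gradient
x = torch.randn([4, 2])
weights = torch.full_like(x, 2.0)  # weight every output by 2
out, vjp_fn = vjp(torch.sin, x)
weighted_grad, = vjp_fn(weights)   # vjp_fn returns a tuple of gradients
assert torch.allclose(weighted_grad, 2 * torch.cos(x))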

Hope that helped and let us know if you have more questions!

EmilienDupont commented 1 year ago

Thank you for your reply! The first option is exactly what I needed :)

I've tested it in my code and it behaves as expected. However, just for my own understanding, I had a question about what exactly happens when defining cos_from_grad = grad(lambda x: torch.sin(x).sum()). If, for example, I pass a tensor x = torch.Tensor([0, 1]) to cos_from_grad, I imagined that grad would return a function equivalent to torch.cos(x).sum(), and hence that the output would be torch.cos(0) + torch.cos(1). However, as this is not what happens, I'm curious to understand what goes on under the hood in this case.

Also, will

  1. cos_from_grad = grad(lambda x: torch.sin(x).sum())
  2. cos = torch.cos

behave in exactly the same way?

Thank you for your help and sorry for my very rudimentary understanding of functorch!

samdow commented 1 year ago

I imagined that grad would return a function equivalent to torch.cos(x).sum(), and hence that the output would be torch.cos(0) + torch.cos(1). However, as this is not what happens, I'm curious to understand what goes on under the hood in this case.

Yeah let's break this down a bit. So grad is getting the gradient with respect to x of torch.sin(x).sum(). By definition, this must be the same shape as x. One thing to remember is that we're doing matrix calculus here instead of scalar calculus. So, we're getting the derivative wrt each input separately. Breaking it down a little more, we're doing

x = torch.Tensor([x0, x1])
y = torch.sin(x).sum()  # torch.sin(x0) + torch.sin(x1)

Then when we get the gradient of y with respect to x, we're actually computing the derivative with respect to each of its elements separately (d_y/d_x0 is the derivative of y with respect to x0)

grad = torch.Tensor([d_y/d_x0, d_y/d_x1]) # note each takes the derivative of y wrt one element of x

# replace y with torch.sin(x0) + torch.sin(x1)
grad = torch.Tensor([d(torch.sin(x0) + torch.sin(x1))/d_x0, d(torch.sin(x0) + torch.sin(x1))/d_x1])

# derivative of torch.sin(x0) + torch.sin(x1) wrt x0 is just torch.cos(x0) since torch.sin(x1) is not affected by x0!
# and vice versa for x1
grad = torch.Tensor([torch.cos(x0), torch.cos(x1)])
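
To make this concrete, here's a quick numeric check of the example above (just an illustration, assuming functorch is importable):

import torch
from functorch import grad

x = torch.tensor([0.0, 1.0])
cos_from_grad = grad(lambda x: torch.sin(x).sum())
print(cos_from_grad(x))  # tensor([1.0000, 0.5403]) -> [cos(0), cos(1)], elementwise, not summed
print(torch.cos(x))      # matches the line above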

Also, will cos_from_grad = grad(lambda x: torch.sin(x).sum()) and cos = torch.cos behave in exactly the same way?

In this case, yes! There are cases where the derivatives are not well defined and we might get NaNs or errors, but sin is a well-behaved function. We can see the derivative for sin defined here, and the rest of this file has the derivatives for other functions.
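
As a quick illustration of the ill-defined case (my own example, not one from the linked file): sqrt's derivative is 0.5 / sqrt(x), which is unbounded at 0, so grad returns inf there rather than a finite value:

import torch
from functorch import grad

# sqrt's derivative is 0.5 / sqrt(x), which blows up at x = 0
d_sqrt = grad(lambda x: torch.sqrt(x).sum())
print(d_sqrt(torch.tensor([4.0, 0.0])))  # tensor([0.2500, inf])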

EmilienDupont commented 1 year ago

This is great!! Thank you so much for the clear explanation - it all makes sense now :)