Closed: EmilienDupont closed this issue 1 year ago
Yeah, so to clarify: the issue happens when the output of the function is not a scalar tensor (i.e. shape `torch.Size([])`). The double `vmap` treats the input like a scalar tensor, so the output is then seen as a scalar. There are a couple of options here:
(1) If you always want each output value to be treated equally, you can do
cos_from_grad = grad(lambda x: torch.sin(x).sum())
this function will always produce a scalar value
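For instance, here is a quick check that the summed version recovers cos elementwise (a sketch assuming PyTorch 2.x's `torch.func`, the successor to functorch; adjust the import for older versions):

```python
import torch
from torch.func import grad  # in older releases: from functorch import grad

# Summing first makes the output a scalar, so grad returns a function
# whose output has the same shape as its input.
cos_from_grad = grad(lambda x: torch.sin(x).sum())

x = torch.tensor([0.0, 1.0])
result = cos_from_grad(x)  # elementwise cos of x
```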
(2) If you want to treat the output values differently, try using vjp
vjp(torch.sin, x)[1](tangents)
Here the tangents will be the seed for backprop, so it should be the same size as the output. Specifically, if you want to get the same behavior as the summed version but with vjp:
x = torch.randn([4, 2])
out, vjp_fn = vjp(torch.sin, x)
grads, = vjp_fn(torch.ones_like(out))  # vjp_fn returns a tuple, one entry per input
Then, by adjusting what you pass to `vjp_fn`, you can change the weight each output is given in the backprop computation.
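As a sketch of that weighting (assuming `torch.func`; the variable names below are my own):

```python
import torch
from torch.func import vjp

x = torch.randn(4, 2)
out, vjp_fn = vjp(torch.sin, x)

# All-ones cotangents reproduce the gradient of sin(x).sum()
(g_ones,) = vjp_fn(torch.ones_like(out))

# Doubling the cotangents doubles each output's weight in backprop
(g_weighted,) = vjp_fn(2.0 * torch.ones_like(out))
```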
Hope that helped and let us know if you have more questions!
Thank you for your reply! The first option is exactly what I needed :)
I've tested it in my code and it behaves as expected. However, just for my own understanding, I had a question about what exactly happens when defining `cos_from_grad = grad(lambda x: torch.sin(x).sum())`. If I, for example, pass a tensor `x = torch.Tensor([0, 1])` to `cos_from_grad`, I imagined that `grad` would return a function equivalent to `torch.cos(x).sum()`, and hence that the output would be `torch.cos(0) + torch.cos(1)`. However, as this is not what happens, I'm curious to understand what goes on under the hood in this case?
Also, will
cos_from_grad = grad(lambda x: torch.sin(x).sum())
cos = torch.cos
behave in exactly the same way? Thank you for your help, and sorry for my very rudimentary understanding of functorch!
I imagined that `grad` would return a function equivalent to `torch.cos(x).sum()`, and hence that the output would be `torch.cos(0) + torch.cos(1)`. However, as this is not what happens, I'm curious to understand what goes on under the hood in this case?
Yeah, let's break this down a bit. So `grad` is getting the gradient, with respect to `x`, of `torch.sin(x).sum()`; by definition, this must have the same shape as `x`. One thing to remember is that we're doing matrix calculus here instead of scalar calculus, so we're getting the derivative with respect to each input element separately. Breaking it down a little more, we're doing:
x = torch.Tensor([x0, x1])
y = torch.sin(x).sum() # torch.sin(x0) + torch.sin(x1)
Then, when we get the gradient of y with respect to x, we're actually computing the derivative with respect to each of x's elements separately (d_y/d_x0 is the derivative of y with respect to x0):
grad = torch.Tensor([d_y/d_x0, d_y/d_x1]) # note each takes the derivative of y wrt one element of x
# replace y with torch.sin(x0) + torch.sin(x1)
grad = torch.Tensor([d(torch.sin(x0) + torch.sin(x1))/d_x0, d(torch.sin(x0) + torch.sin(x1))/d_x1])
# derivative of torch.sin(x0) + torch.sin(x1) wrt x0 is just torch.cos(x0) since torch.sin(x1) is not affected by x0!
# and vice versa for x1
grad = torch.Tensor([torch.cos(x0), torch.cos(x1)])
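The derivation above can be double-checked with plain autograd (a sketch, not from the thread):

```python
import torch

x = torch.tensor([0.0, 1.0], requires_grad=True)
y = torch.sin(x).sum()  # sin(x0) + sin(x1)
y.backward()
# x.grad is [cos(0), cos(1)], i.e. the elementwise cos of x
```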
Also, will cos_from_grad = grad(lambda x: torch.sin(x).sum()) and cos = torch.cos behave in exactly the same way?
In this case, yes! There are cases where the derivatives are not well defined and we might get NaNs or errors, but sin is a well-behaved function. We can see the derivative for sin defined here, and the rest of this file has the derivatives for other functions.
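For example (a sketch assuming `torch.func`; `sqrt` at 0 is one case where the derivative blows up, while `sin` is smooth everywhere):

```python
import torch
from torch.func import grad

# sin is smooth everywhere, so its derivative is always well defined
g_sin = grad(torch.sin)(torch.tensor(0.0))   # cos(0) = 1

# sqrt's derivative 1/(2*sqrt(x)) is unbounded at 0: grad returns inf
g_sqrt = grad(torch.sqrt)(torch.tensor(0.0))
```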
This is great!! Thank you so much for the clear explanation - it all makes sense now :)
What is the easiest way to apply the grad of a function elementwise to a tensor of arbitrary shape? For example, in this specific case, where we have a tensor of shape (4, 2), we can use `vmap` twice. However, if I later need to call `cos_from_grad` on a tensor of shape (4, 2, 3), for example, then the above code will no longer work, as I would need to add an extra `vmap`. Is there a way to use `grad` to create a `cos` function that is equivalent to `torch.cos`, in the sense that it can be applied elementwise to tensors of arbitrary shape? Thank you!
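For reference, the double-vmap version being described might look something like this (my reconstruction, assuming `torch.func`; the original snippet was not preserved in this thread):

```python
import torch
from torch.func import grad, vmap

# Per-element grad of sin via two nested vmaps: works only for 2-D inputs
cos_from_grad = vmap(vmap(grad(torch.sin)))
x2 = torch.randn(4, 2)
out2 = cos_from_grad(x2)

# A 3-D input needs a third vmap, which is the inconvenience in question
cos_3d = vmap(vmap(vmap(grad(torch.sin))))
x3 = torch.randn(4, 2, 3)
out3 = cos_3d(x3)
```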