pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Get intermediate derivatives with nested jacobian and has_aux #1037

Closed · i-a-morozov closed this issue 1 year ago

i-a-morozov commented 1 year ago

Is it possible to get intermediate results with a nested jacobian? Say functorch.jacfwd is nested twice with has_aux=True; how do I get the 1st derivative in that case?

import torch
import functorch

def foo(x):
    y = torch.cos(x)
    return y, y

def nest(fun, num):
    bar = fun
    for _ in range(num):
        bar = functorch.jacfwd(bar, has_aux=True)
    return bar

x = torch.tensor(0.0)

print(nest(foo, 1)(x))
# 1st derivative and value
# (tensor(-0.), tensor(1.000000000000e+00))

print(nest(foo, 2)(x))
# 2nd derivative and value, no 1st derivative
# (tensor(-1.000000000000e+00), tensor(1.000000000000e+00))
zou3519 commented 1 year ago

With jacfwd nested twice, we want the first derivative computation to return (first_derivative, (first_derivative, value)). Then, applying jacfwd with has_aux=True to that function, we get the second_derivative as well as the aux output (the (first_derivative, value) tuple).

import torch
import functorch

def foo(x):
    y = torch.cos(x)
    return y, y

def first(x):
    # inner jacfwd returns (dx, value); return dx as the primary output
    # and carry (dx, value) along as aux for the outer transform
    dx, value = functorch.jacfwd(foo, has_aux=True)(x)
    return dx, (dx, value)

x = torch.tensor(0.0)
ddx, (dx, value) = functorch.jacfwd(first, has_aux=True)(x)
print(ddx)    # second derivative: -cos(0) = -1
print(dx)     # first derivative: -sin(0) = -0
print(value)  # value: cos(0) = 1
i-a-morozov commented 1 year ago

@zou3519, great, thank you!

Here is a sloppy version for higher orders:

import torch
import functorch

x = torch.tensor(0.0)

def foo(x):
    y = x + x**2 + x**3 + x**4 + x**5
    return y, y

num = 5
bar = foo
for _ in range(num):
    def bar(x, bar=bar):  # bind the current bar via a default argument so each wrapper closes over the previous one
        y, ys = functorch.jacfwd(bar, has_aux=True)(x)
        return y, (y, ys)

_, y = bar(x)
print(y)
# (d5y,          (d4y,         (d3y,        (d2y,        (dy,         y)))))
# (tensor(120.), (tensor(24.), (tensor(6.), (tensor(2.), (tensor(1.), tensor(0.))))))
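
On newer PyTorch (>= 2.0), where functorch's transforms live in torch.func, the same pattern should work unchanged. Here is a minimal sketch of the loop above wrapped into a helper (nest_with_aux is just an illustrative name, not a library function):

import torch
from torch.func import jacfwd  # torch.func replaces the standalone functorch package on PyTorch >= 2.0

def foo(x):
    y = x + x**2 + x**3 + x**4 + x**5
    return y, y

def nest_with_aux(fun, num):
    # wrap `fun` num times; each wrapper returns its jacobian as the primary
    # output and threads (jacobian, previous aux) through as aux, so every
    # intermediate derivative survives the outer transforms
    bar = fun
    for _ in range(num):
        def bar(x, bar=bar):  # bind the current bar via a default argument
            y, ys = jacfwd(bar, has_aux=True)(x)
            return y, (y, ys)
    return bar

x = torch.tensor(0.0)
_, derivatives = nest_with_aux(foo, 5)(x)
print(derivatives)
# (tensor(120.), (tensor(24.), (tensor(6.), (tensor(2.), (tensor(1.), tensor(0.))))))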