pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

functorch doesn't work with saved variable hooks #1027

Open zou3519 opened 1 year ago

zou3519 commented 1 year ago

I'm not sure why this happens. cc @soulitzer @albanD

Repro:

import torch
from torch.autograd.graph import save_on_cpu
from functorch import grad

x = torch.randn([], device='cuda', requires_grad=True)

def f(x):
    return x.sin().sin()

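# Baseline: plain autograd under save_on_cpu. With create_graph=True the
# returned gradient is itself differentiable, so this assertion passes.
with save_on_cpu():
    y = f(x)
    gx, = torch.autograd.grad(y, x, create_graph=True)
assert gx.requires_grad

# Same computation through functorch's grad under save_on_cpu.
with save_on_cpu():
    gx = grad(f)(x)

# Fails: gx.requires_grad is False here
assert gx.requires_grad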
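The title refers to saved tensor hooks, the mechanism `save_on_cpu` is built on (it subclasses `torch.autograd.graph.saved_tensors_hooks` and moves saved tensors to CPU on pack, back to their device on unpack). A minimal CPU-only sketch of that mechanism, assuming a recent PyTorch: identity pack/unpack hooks that only count how often autograd calls them, applied to the same `f` with plain autograd. It does not reproduce the functorch failure; the hook functions and the `calls` counter are illustrative names, not part of any API.

```python
import torch
from torch.autograd.graph import saved_tensors_hooks

# Illustrative counters: how often autograd routes a saved tensor through the hooks.
calls = {"pack": 0, "unpack": 0}

def pack_hook(t):
    # Called when autograd saves a tensor for backward.
    # save_on_cpu would copy t to CPU here; this sketch keeps it as-is.
    calls["pack"] += 1
    return t

def unpack_hook(t):
    # Called when backward needs a saved tensor again.
    calls["unpack"] += 1
    return t

def f(x):
    return x.sin().sin()

x = torch.randn([], requires_grad=True)

with saved_tensors_hooks(pack_hook, unpack_hook):
    y = f(x)  # each sin saves its input, so pack_hook runs here
    gx, = torch.autograd.grad(y, x, create_graph=True)  # unpack_hook runs here

print(calls, gx.requires_grad)
```

Replacing `torch.autograd.grad(...)` with functorch's `grad(f)(x)` inside the same context could be a CPU-only starting point for checking whether the failure depends on `save_on_cpu` specifically or on saved tensor hooks in general.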
zou3519 commented 1 year ago

This might be the same issue as the one described in https://docs.google.com/document/d/1xVRFtItMkIqs9eqMj2jqv-SQNoHZh8m-V71oSPsdKmc/edit