**Open** · zou3519 opened this issue 1 year ago
`save_on_cpu` doesn't compose with functorch's `grad`: the gradient computed under the functorch transform loses `requires_grad`, even though the equivalent `torch.autograd.grad` call with `create_graph=True` keeps it. I'm not sure why, cc @soulitzer @albanD.
Repro:
```python
import torch
from torch.autograd.graph import save_on_cpu
from functorch import grad

x = torch.randn([], device='cuda', requires_grad=True)

def f(x):
    return x.sin().sin()

# Works: plain autograd under save_on_cpu keeps the graph
with save_on_cpu():
    y = f(x)
gx, = torch.autograd.grad(y, x, create_graph=True)
assert gx.requires_grad

# Fails: the functorch grad transform under save_on_cpu
with save_on_cpu():
    gx = grad(f)(x)
assert gx.requires_grad  # AssertionError
```
This might be the same thing as https://docs.google.com/document/d/1xVRFtItMkIqs9eqMj2jqv-SQNoHZh8m-V71oSPsdKmc/edit
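For context, `save_on_cpu` is built on `torch.autograd.graph.saved_tensors_hooks`, which lets you intercept the tensors autograd saves for backward. A minimal CPU-only sketch (no CUDA needed) showing the pack/unpack hooks firing under plain autograd; the suspicion in this issue is that these hooks don't interact correctly with functorch's transformed tensors, but that interaction is not demonstrated here:

```python
import torch

packed = []

def pack(t):
    # Called when autograd saves a tensor for backward;
    # save_on_cpu's version would move `t` to CPU here.
    packed.append(t.shape)
    return t

def unpack(t):
    # Called when backward retrieves the saved tensor.
    return t

x = torch.randn(3, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = x.sin().sin()  # each sin saves a tensor for its backward
y.sum().backward()

# pack fired once per tensor autograd saved during the forward
assert len(packed) >= 1
```

Under plain autograd this runs cleanly; the bug is specifically about running the forward under such a hook context while inside `functorch.grad`.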