TimSchneider42 opened this issue 3 years ago.
Hi @TimSchneider42, yes, that is concerning. Thank you for the report! We'll take a look. If it helps unblock you, using the updated API (subclassing torch.nn.Module instead of ScriptModule) appears to make this problem go away:
```python
from typing import List, Tuple
import torch
from torch.autograd import Variable

class Mod(torch.nn.Module):
    def __init__(self):
        super(Mod, self).__init__()
        # Variable(...) is legacy API; it returns a plain tensor with requires_grad=True.
        self.var = Variable(torch.randn(2), requires_grad=True)

    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        var_lst = [self.var]
        var = torch.cat(var_lst)
        output = var + input
        print("forward: var_lst[0].requires_grad = {}".format(var_lst[0].requires_grad))
        return output, var_lst
```
Hey @Lilyjjo, thanks a lot for this recommendation. Do you see any way I can use the JIT compiler on a class with multiple functions (not just the single forward function of torch.nn.Module) without running into the problem I described? Best, Tim
@TimSchneider42, are you trying to JIT a non-module type? For torch.nn.Module subclasses, you can decorate methods with @torch.jit.export so that they are compiled and can be called directly on the scripted module: https://pytorch.org/docs/stable/jit.html?highlight=export#torch.jit.export
@Lilyjjo, that is correct: I am trying to JIT compile a class with a couple of functions but no forward function. As you advised, I made the class a torch.nn.Module, left the forward function blank, and decorated all functions with @torch.jit.export. However, in my setup the same problem remains: the forward pass is fine, but the backward pass is missing parts of the graph.
This example illustrates my approach:
```python
from typing import Tuple, List

import torch
from torch.autograd import Variable


class Mod(torch.nn.Module):
    def __init__(self):
        super(Mod, self).__init__()
        self.var = Variable(torch.randn(2), requires_grad=True)

    @torch.jit.export
    def some_func(self, input: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        var_lst = [self.var]
        var = torch.cat(var_lst)
        output = var + input
        print("forward: var_lst[0].requires_grad = {}".format(var_lst[0].requires_grad))
        return output, var_lst


mod = torch.jit.script(Mod())
for i in range(2):
    print("\nIteration {}".format(i))
    input = torch.randn((2,))
    output, var_lst = mod.some_func(input)
    print("main: var_lst[0].requires_grad = {}".format(var_lst[0].requires_grad))
    loss = torch.cat(var_lst).sum()
    loss.backward()
```
cc: @Krovatkin could this be related to the autodiff changes you were working on?
@Lilyjjo, yeah, this looks like a duplicate of an internal issue.
@TimSchneider42, we backed out the offending change for now. The test case should work on master now.
I'm still seeing this issue on commit e3900d2ba5c9f91a24a9ce34520794c8366d5c54
I assume the offending commit is the one with requires_grad pruning, f88a3fff65b35cb6d4968fc54a9a0a1314a9a3b7.
On closer inspection, I don't think that commit is causing the issue here. The issue looks specific to a TensorList being passed to a DifferentiableGraph. Here is the graph I printed from the example above (what's inside the DifferentiableGraph can be ignored; it isn't relevant here):
```
graph(%self : __torch__.Mod,
      %input.1 : Float(2, strides=[1], requires_grad=0, device=cpu)):
  %3 : str = prim::Constant[value="forward: var_lst[0].requires_grad = {}"]() # rg4.py:19:14
  %2 : str = prim::Constant[value="forward: var.requires_grad = {}"]() # rg4.py:20:14
  %6 : Tensor = prim::GetAttr[name="var"](%self)
  %var_lst.1 : Tensor[] = prim::ListConstruct(%6)
  %31 : Float(2, strides=[1], requires_grad=0, device=cpu), %32 : bool = prim::RequiresGradCheck[types=[Tensor(requires_grad=0)]](%input.1)
  %33 : Tensor, %34 : Tensor = prim::If(%32)
    block0():
      %output.6 : Tensor, %25 : Tensor = prim::DifferentiableGraph_0(%31, %var_lst.1)
      -> (%output.6, %25)
    block1():
      %43 : Function = prim::Constant[name="fallback_function", fallback=1]()
      %44 : (Tensor, Tensor) = prim::CallFunction(%43, %input.1, %var_lst.1)
      %45 : Tensor, %46 : Tensor = prim::TupleUnpack(%44)
      -> (%45, %46)
  %13 : bool = prim::requires_grad(%6)
  %14 : str = aten::format(%3, %13) # rg4.py:19:14
   = prim::Print(%14) # rg4.py:19:8
  %16 : bool = prim::requires_grad(%34)
  %17 : str = aten::format(%2, %16) # rg4.py:20:14
   = prim::Print(%17) # rg4.py:20:8
  %19 : (Tensor, Tensor[]) = prim::TupleConstruct(%33, %var_lst.1)
  return (%19)
```
%var_lst.1 is fed to the DifferentiableGraph and is later used to construct the output of the graph.
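In the graph above, prim::RequiresGradCheck compares the runtime requires_grad flag of %input.1 against the profiled value (requires_grad=0). prim::If then takes either the optimized DifferentiableGraph_0 (block0) or the fallback_function (block1). Only %input.1 is checked: %var_lst.1 goes straight into the DifferentiableGraph and is reused in the final TupleConstruct. Below is a minimal pure-Python sketch of this guard-and-fallback dispatch. It is an analogy only. FakeTensor and guarded_call are hypothetical names, not PyTorch APIs.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor; it only tracks requires_grad."""
    requires_grad: bool


def guarded_call(
    profiled_requires_grad: bool,
    optimized: Callable[..., Any],
    fallback: Callable[..., Any],
    x: FakeTensor,
    *rest: Any,
) -> Any:
    # Analogue of prim::RequiresGradCheck: compare the runtime flag of the
    # guarded input with the value recorded during profiling.
    if x.requires_grad == profiled_requires_grad:
        # Analogue of block0: the optimized DifferentiableGraph path.
        return optimized(x, *rest)
    # Analogue of block1: the unoptimized fallback_function path.
    return fallback(x, *rest)


# Only x is guarded; extra arguments in *rest (like %var_lst.1) are
# forwarded without any check.
print(guarded_call(False, lambda x, *r: "optimized", lambda x, *r: "fallback", FakeTensor(False)))
print(guarded_call(False, lambda x, *r: "optimized", lambda x, *r: "fallback", FakeTensor(True)))
```

In the real graph executor, the fallback path recompiles or interprets the unoptimized graph. This sketch only models the branch selection.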
When we run the differentiable graph, we detach all of its inputs. I think c10::List shares storage, so when the current code detaches a TensorList, it unintentionally mutates the input IValue: https://github.com/pytorch/pytorch/blob/0ea4eb745b4764e14d10f541887a12beb9949227/torch/csrc/jit/runtime/graph_executor.cpp#L439-L443
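The suspected aliasing mechanism can be illustrated with Python lists, which, like the c10::List described above, are shared by reference. This is an analogy, not the C++ code in graph_executor.cpp. FakeTensor, detach_inputs_in_place, and detach_inputs_copy are hypothetical names. Overwriting the elements of a shared list with detached copies also changes the caller's list. That is consistent with var_lst[0].requires_grad being False after the call. Building a new list leaves the caller's elements untouched. This copy-based version is one possible way to avoid the aliasing and is not necessarily what the workaround commit does.

```python
from typing import List


class FakeTensor:
    """Hypothetical stand-in for a tensor; it only tracks requires_grad."""

    def __init__(self, requires_grad: bool) -> None:
        self.requires_grad = requires_grad

    def detach(self) -> "FakeTensor":
        # Like Tensor.detach(): returns a new object that does not require grad.
        return FakeTensor(requires_grad=False)


def detach_inputs_in_place(inputs: List[FakeTensor]) -> List[FakeTensor]:
    # Buggy pattern: the list object is shared with the caller, so replacing
    # its elements also changes what the caller sees.
    for i, t in enumerate(inputs):
        inputs[i] = t.detach()
    return inputs


def detach_inputs_copy(inputs: List[FakeTensor]) -> List[FakeTensor]:
    # Safer pattern: build a fresh list; the caller's list keeps the originals.
    return [t.detach() for t in inputs]


caller_list = [FakeTensor(requires_grad=True)]
detach_inputs_in_place(caller_list)
print(caller_list[0].requires_grad)  # False: the caller's element was replaced

caller_list = [FakeTensor(requires_grad=True)]
detached = detach_inputs_copy(caller_list)
print(caller_list[0].requires_grad, detached[0].requires_grad)  # True False
```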
I have verified a quick-and-dirty workaround (WAR) on my machine: https://github.com/pytorch/pytorch/pull/56466/commits/dfcf4fed2c62403deab4e9df6ca502075d12f2e0
🐛 Bug
I encountered really strange behavior when using the JIT compiler to speed up my code. In particular, I noticed that with the JIT compiler the backward computation graph was missing some branches compared to the non-JIT-compiled version. In the code below, I isolated a minimal example in which the problem appears to occur:
To Reproduce
Steps to reproduce the behavior:
In this example, there is one variable, var, encapsulated in the ScriptModule Mod, that has requires_grad=True set. The variable is used to compute two outputs: output, and var_lst, which is just the variable wrapped in a Python list. Hence, both outputs depend on the variable and should both have requires_grad set to True. However, if I run the above code, I get the following output:

followed by an error, because loss.backward() cannot be computed when loss.requires_grad is False.

Interestingly, it seems that changing anything in this example makes the error disappear. For example, removing the computation of output, or returning output, [self.var] instead of output, var_lst, causes var_lst[0].requires_grad to be True again.

Expected behavior
If the code is run with PYTORCH_JIT=0, everything works as expected and no error occurs:

Environment
```
Collecting environment information...
PyTorch version: 1.8.1
Is debug build: False
CUDA used to build PyTorch: 11.2
ROCM used to build PyTorch: N/A

OS: Arch Linux (x86_64)
GCC version: (GCC) 10.2.0
Clang version: 11.1.0
CMake version: version 3.20.0

Python version: 3.9 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.1.243
GPU models and configuration: GPU 0: GeForce GTX 1070
Nvidia driver version: 460.67
cuDNN version: Probably one of the following:
/usr/lib/libcudnn.so.8.1.0
/usr/lib/libcudnn_adv_infer.so.8.1.0
/usr/lib/libcudnn_adv_train.so.8.1.0
/usr/lib/libcudnn_cnn_infer.so.8.1.0
/usr/lib/libcudnn_cnn_train.so.8.1.0
/usr/lib/libcudnn_ops_infer.so.8.1.0
/usr/lib/libcudnn_ops_train.so.8.1.0
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.20.1
[pip3] torch==1.8.1
[pip3] torch-scatter==2.0.6
[conda] Could not collect
```
Thanks a lot in advance, Tim
cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @Lezcano @Varal7 @gmagogsfm