pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

requires_grad does not get propagated properly when using the JIT compiler #55609

Open TimSchneider42 opened 3 years ago

TimSchneider42 commented 3 years ago

šŸ› Bug

I encountered some really strange behavior when using the JIT compiler to speed up my code. In particular, I noticed that with the JIT compiler the backward computation graph was missing some branches compared to the non-JIT-compiled version. The code below is a minimal example in which the problem occurs:

To Reproduce

Steps to reproduce the behavior:

from typing import Tuple, List

import torch
from torch.autograd import Variable
from torch.jit import ScriptModule

class Mod(ScriptModule):
    def __init__(self):
        super(Mod, self).__init__()
        self.var = Variable(torch.randn(2), requires_grad=True)

    @torch.jit.script_method
    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        var_lst = [self.var]
        var = torch.cat(var_lst)
        output = var + input
        print("forward: var_lst[0].requires_grad = {}".format(var_lst[0].requires_grad))
        return output, var_lst

mod = Mod()

for i in range(2):
    print("\nIteration {}".format(i))
    input = torch.randn((2,))
    output, var_lst = mod.forward(input)
    print("main: var_lst[0].requires_grad = {}".format(var_lst[0].requires_grad))
    loss = torch.cat(var_lst).sum()
    loss.backward()

In this example there is a single variable var, encapsulated in the ScriptModule Mod, with requires_grad=True. The variable is used to compute two outputs, output and var_lst, where the latter is just the variable wrapped in a Python list. Both outputs therefore depend on the variable and should have requires_grad set to True. However, if I run the above code, I get the following output:

Iteration 0
forward: var_lst[0].requires_grad = True
main: var_lst[0].requires_grad = True

Iteration 1
forward: var_lst[0].requires_grad = True
main: var_lst[0].requires_grad = False

followed by an error, because loss.backward() cannot be called when loss.requires_grad is False.

Interestingly, it seems that almost any change to this example makes the error disappear. For example, removing the computation of output, or returning output, [self.var] instead of output, var_lst, causes var_lst[0].requires_grad to be True again.
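
For reference, a sketch of the second variant (returning a freshly constructed list); this only illustrates the change described above, it is not an additional verified reproduction:

    @torch.jit.script_method
    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        var_lst = [self.var]
        var = torch.cat(var_lst)
        output = var + input
        # Returning a freshly constructed list instead of var_lst reportedly
        # keeps requires_grad intact in this example.
        return output, [self.var]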

Expected behavior

If the code is run with PYTORCH_JIT=0, everything works as expected and no error occurs:

Iteration 0
forward: var_lst[0].requires_grad = True
main: var_lst[0].requires_grad = True

Iteration 1
forward: var_lst[0].requires_grad = True
main: var_lst[0].requires_grad = True

Environment

Collecting environment information...
PyTorch version: 1.8.1
Is debug build: False
CUDA used to build PyTorch: 11.2
ROCM used to build PyTorch: N/A

OS: Arch Linux (x86_64)
GCC version: (GCC) 10.2.0
Clang version: 11.1.0
CMake version: version 3.20.0

Python version: 3.9 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.1.243
GPU models and configuration: GPU 0: GeForce GTX 1070
Nvidia driver version: 460.67
cuDNN version: Probably one of the following:
/usr/lib/libcudnn.so.8.1.0
/usr/lib/libcudnn_adv_infer.so.8.1.0
/usr/lib/libcudnn_adv_train.so.8.1.0
/usr/lib/libcudnn_cnn_infer.so.8.1.0
/usr/lib/libcudnn_cnn_train.so.8.1.0
/usr/lib/libcudnn_ops_infer.so.8.1.0
/usr/lib/libcudnn_ops_train.so.8.1.0
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.20.1
[pip3] torch==1.8.1
[pip3] torch-scatter==2.0.6
[conda] Could not collect

Thanks a lot in advance, Tim

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @Lezcano @Varal7 @gmagogsfm

Lilyjjo commented 3 years ago

Hi @TimSchneider42, yes that is concerning, thank you for the report! We'll take a look. If it helps unblock you, using the updated API of subclassing torch.nn.Module instead of ScriptModule appears to make this problem go away:

class Mod(torch.nn.Module):
    def __init__(self):
        super(Mod, self).__init__()
        self.var = Variable(torch.randn(2), requires_grad=True)

    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        var_lst = [self.var]
        var = torch.cat(var_lst)
        output = var + input
        print("forward: var_lst[0].requires_grad = {}".format(var_lst[0].requires_grad))
        return output, var_lst

TimSchneider42 commented 3 years ago

Hey @Lilyjjo, thanks a lot for this recommendation. Do you see any way I can use the JIT compiler on a class with multiple functions (not just the single forward function of torch.nn.Module) without running into the problem I described? Best, Tim

Lilyjjo commented 3 years ago

@TimSchneider42, are you trying to JIT a non-module type? For torch.nn.Module subclasses you can decorate additional methods with @torch.jit.export to be able to call them directly on the scripted module: https://pytorch.org/docs/stable/jit.html?highlight=export#torch.jit.export
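
A minimal sketch of that pattern (the class and method names here are purely illustrative, not taken from your code): methods decorated with @torch.jit.export are compiled along with forward and remain callable on the result of torch.jit.script:

import torch

class Helper(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2

    @torch.jit.export
    def shifted(self, x: torch.Tensor, shift: float) -> torch.Tensor:
        # Exported methods are compiled along with forward and remain
        # callable on the scripted module.
        return x * 2 + shift

scripted = torch.jit.script(Helper())
out = scripted.shifted(torch.ones(3), 1.0)  # called directly, without going through forward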

TimSchneider42 commented 3 years ago

@Lilyjjo, that is correct, I am trying to JIT compile a class with a couple of functions but no forward function. I tried, as you advised, making the class a torch.nn.Module, leaving the forward function blank and simply decorating all functions with @torch.jit.export. However, in my setup the same problem remains: the forward pass is fine, but the backward pass is missing parts of the graph.

This example illustrates my approach:

from typing import Tuple, List

import torch
from torch.autograd import Variable

class Mod(torch.nn.Module):
    def __init__(self):
        super(Mod, self).__init__()
        self.var = Variable(torch.randn(2), requires_grad=True)

    @torch.jit.export
    def some_func(self, input: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        var_lst = [self.var]
        var = torch.cat(var_lst)
        output = var + input
        print("forward: var_lst[0].requires_grad = {}".format(var_lst[0].requires_grad))
        return output, var_lst

mod = torch.jit.script(Mod())

for i in range(2):
    print("\nIteration {}".format(i))
    input = torch.randn((2,))
    output, var_lst = mod.some_func(input)
    print("main: var_lst[0].requires_grad = {}".format(var_lst[0].requires_grad))
    loss = torch.cat(var_lst).sum()
    loss.backward()

Lilyjjo commented 3 years ago

cc: @Krovatkin could this be related to the autodiff changes you were working on?

Krovatkin commented 3 years ago

@Lilyjjo yeah this looks like a dup of an internal issue.

Krovatkin commented 3 years ago

@TimSchneider42 we backed out the offending change for now. The test case should now work on master.

jjsjann123 commented 3 years ago

I'm still seeing this issue on commit e3900d2ba5c9f91a24a9ce34520794c8366d5c54

I assume the offending commit we are referring to is the one with the requires_grad pruning, f88a3fff65b35cb6d4968fc54a9a0a1314a9a3b7.

On closer inspection, I don't think that is what's causing the issue here. The issue appears to be specific to a TensorList being passed to a DifferentiableGraph. Here is the graph I printed from the example above (we can ignore what's inside the DifferentiableGraph; it's not relevant here):

graph(%self : __torch__.Mod,
      %input.1 : Float(2, strides=[1], requires_grad=0, device=cpu)):
  %3 : str = prim::Constant[value="forward: var_lst[0].requires_grad = {}"]() # rg4.py:19:14
  %2 : str = prim::Constant[value="forward: var.requires_grad = {}"]() # rg4.py:20:14
  %6 : Tensor = prim::GetAttr[name="var"](%self)
  %var_lst.1 : Tensor[] = prim::ListConstruct(%6)
  %31 : Float(2, strides=[1], requires_grad=0, device=cpu), %32 : bool = prim::RequiresGradCheck[types=[Tensor(requires_grad=0)]](%input.1)
  %33 : Tensor, %34 : Tensor = prim::If(%32)
    block0():
      %output.6 : Tensor, %25 : Tensor = prim::DifferentiableGraph_0(%31, %var_lst.1)
      -> (%output.6, %25)
    block1():
      %43 : Function = prim::Constant[name="fallback_function", fallback=1]()
      %44 : (Tensor, Tensor) = prim::CallFunction(%43, %input.1, %var_lst.1)
      %45 : Tensor, %46 : Tensor = prim::TupleUnpack(%44)
      -> (%45, %46)
  %13 : bool = prim::requires_grad(%6)
  %14 : str = aten::format(%3, %13) # rg4.py:19:14
   = prim::Print(%14) # rg4.py:19:8
  %16 : bool = prim::requires_grad(%34)
  %17 : str = aten::format(%2, %16) # rg4.py:20:14
   = prim::Print(%17) # rg4.py:20:8
  %19 : (Tensor, Tensor[]) = prim::TupleConstruct(%33, %var_lst.1)
  return (%19)

%var_lst.1 is fed to DifferentiableGraph, and later used to construct the output of the graph.

When we run the differentiable graph, we detach all inputs. I think c10::List shares storage, so in the current code, when we detach the TensorList, we unintentionally mutate the input IValue: https://github.com/pytorch/pytorch/blob/0ea4eb745b4764e14d10f541887a12beb9949227/torch/csrc/jit/runtime/graph_executor.cpp#L439-L443
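
A rough Python analogy of that aliasing (illustrative only; the actual code path is C++ and operates on c10::List/IValue): detaching the elements of a shared list in place is visible to the caller, whereas detaching a copy of the list is not:

import torch

var = torch.randn(2, requires_grad=True)
var_lst = [var]  # the caller keeps a reference to this list

def detach_inputs(tensor_list):
    # stand-in for the detach pass over the differentiable graph's inputs
    for i in range(len(tensor_list)):
        tensor_list[i] = tensor_list[i].detach()

detach_inputs(var_lst)
print(var_lst[0].requires_grad)  # False -- the caller's list was mutated

var_lst = [var]
detach_inputs(list(var_lst))     # detach a shallow copy instead
print(var_lst[0].requires_grad)  # True -- the caller's list is untouched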

I have verified a quick and dirty workaround on my machine: https://github.com/pytorch/pytorch/pull/56466/commits/dfcf4fed2c62403deab4e9df6ca502075d12f2e0