pytorch / functorch

functorch provides JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License
1.39k stars 102 forks source link

Error when applying `make_fx()` on a function that calls `optim.step()` #981

Open mostafaelhoushi opened 2 years ago

mostafaelhoushi commented 2 years ago

The following test used to pass, but it has recently started raising an error:

    def test_make_fx_model_train_with_optim(self, device):
        """Trace a full training step (forward, backward, optimizer step) with ``make_fx``.

        Builds a small ``nn.Linear(5, 5)`` + ReLU module and an SGD optimizer
        (lr=1e-4) over its parameters. It then traces ``f``, which:

        * runs the module functionally via ``stateless.functional_call``;
        * backpropagates ``out.sum()``;
        * calls ``optim.step()``.

        Finally it checks that the traced function returns something non-None.

        Per the traceback reported in this issue, ``make_fx`` currently fails
        inside ``optim.step()`` with ``RuntimeError: Expected temporary cpp type
        wrapper of type at::RecordFunction``. The error is raised while exiting
        the ``torch.autograd.profiler.record_function`` context that the
        optimizer's ``step`` wrapper opens.

        ``device`` is accepted but not used in the body shown here.
        """
        # Minimal module: a 5 -> 5 linear layer followed by ReLU.
        class Foo(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x).relu()

        model = Foo()
        optim = torch.optim.SGD(model.parameters(), lr=1e-4)

        def f(args, params, buffers):
            """One training step, written so that ``make_fx`` can trace it.

            Args:
                args: model input(s); a non-Iterable value is wrapped in a list.
                params: name -> parameter mapping handed to ``functional_call``.
                buffers: name -> buffer mapping handed to ``functional_call``.

            Returns:
                The values of ``params``, as a list.
            """
            # NOTE(review): torch.Tensor defines __iter__, so a tensor input is
            # presumably considered Iterable and is NOT wrapped here. If it ever
            # were wrapped, the model's forward would likely receive a list
            # instead of a tensor -- confirm against functional_call's handling
            # of `args`.
            if not isinstance(args, Iterable):
                args = [args]
            # Merge into one mapping; on a name clash the buffer entry wins.
            params_and_buffers = {**params, **buffers}
            out = stateless.functional_call(model, params_and_buffers, args)
            # Reduce to a scalar so backward() needs no explicit grad argument.
            out.sum().backward()
            # Per the reported traceback, make_fx fails on this call.
            # NOTE(review): `optim` was built from model.parameters(), not from
            # the `params` passed to f. Confirm the step is meant to update the
            # tensors that f returns.
            optim.step()

            # TODO: this causes graph to show an output with many incoming edges. Shall we try `return None` or simply don't return?
            return list(params.values())

        # NOTE: `input` shadows the Python builtin of the same name within this test.
        input = torch.randn(3, 5, requires_grad=True)
        params = dict(model.named_parameters())
        buffers = dict(model.named_buffers())
        # Tracing calls f with these example inputs; this line raises in the issue.
        fx_f = make_fx(f)(input, params, buffers)
        # TODO: what assert statement should we add here?
        assert(fx_f(input, params, buffers) is not None)

This is the traceback it produces:

  File "/Users/distiller/project/test/test_pythonkey.py", line 152, in test_make_fx_model_train_with_optim
    fx_f = make_fx(f)(input, params, buffers)
  File "/Users/distiller/project/env/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 407, in wrapped
    t = dispatch_trace(wrap_key(f, args), tracer=fx_tracer, concrete_args=tuple(phs))
  File "/Users/distiller/project/env/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 246, in dispatch_trace
    graph = tracer.trace(root, concrete_args)
  File "/Users/distiller/project/env/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 714, in trace
    (self.create_arg(fn(*args)),),
  File "/Users/distiller/project/env/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 549, in flatten_fn
    tree_out = root_fn(*tree_args)
  File "/Users/distiller/project/env/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 270, in wrapped
    out = f(*tree_args)
  File "/Users/distiller/project/test/test_pythonkey.py", line 144, in f
    optim.step()
  File "/Users/distiller/project/env/lib/python3.10/site-packages/torch/optim/optimizer.py", line 113, in wrapper
    with torch.autograd.profiler.record_function(profile_name):
  File "/Users/distiller/project/env/lib/python3.10/site-packages/torch/autograd/profiler.py", line 477, in __exit__
    torch.ops.profiler._record_function_exit(self.handle)
  File "/Users/distiller/project/env/lib/python3.10/site-packages/torch/_ops.py", line 164, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: Expected temporary cpp type wrapper of type at::RecordFunction