pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

[dynamo] Recompilation for nn module guards when it should not #110048

Closed anijain2305 closed 1 year ago

anijain2305 commented 1 year ago

🐛 Describe the bug

Repro

import torch

mod = torch.nn.Linear(10, 10).cuda()

@torch.compile(backend="aot_eager")
def generate(x, c):
    return mod(x) + c

for _ in range(0, 10):
    # print(id(mod.lm_head), torch._C._dynamo.guards.nn_module_guard(mod.lm_head))
    generate(torch.randn(10, 10, device="cuda"), 0)
    generate(torch.randn(10, 10, device="cuda"), 1)

The failing guard here is __nn_module_guard. The failure reason is that the _parameters dict keeps changing its version.

 [0/0] torch._dynamo.guards.__guards: [DEBUG] __nn_module_guard_0(G['mod'], debug_msg="versions(mod=684814, _parameters=1384576, _buffers=684785, ....
 [0/1] torch._dynamo.guards.__guards: [DEBUG] __nn_module_guard_0(G['mod'], debug_msg="versions(mod=684814, _parameters=1480788, _buffers=684785, ....
 [0/2] torch._dynamo.guards.__guards: [DEBUG] __nn_module_guard_0(G['mod'], debug_msg="versions(mod=684814, _parameters=1520867, _buffers=684785, ....
 [0/3] torch._dynamo.guards.__guards: [DEBUG] __nn_module_guard_0(G['mod'], debug_msg="versions(mod=684814, _parameters=1560718, _buffers=684785, ....
 [0/4] torch._dynamo.guards.__guards: [DEBUG] __nn_module_guard_0(G['mod'], debug_msg="versions(mod=684814, _parameters=1600793, _buffers=684785, ....
 [0/5] torch._dynamo.guards.__guards: [DEBUG] __nn_module_guard_0(G['mod'], debug_msg="versions(mod=684814, _parameters=1640646, _buffers=684785, ....
 [0/6] torch._dynamo.guards.__guards: [DEBUG] __nn_module_guard_0(G['mod'], debug_msg="versions(mod=684814, _parameters=1680736, _buffers=684785, ....
 [0/7] torch._dynamo.guards.__guards: [DEBUG] __nn_module_guard_0(G['mod'], debug_msg="versions(mod=684814, _parameters=1720591, _buffers=684785, ....

However, there is nothing in the model itself that should be changing mod's _parameters dict.
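
The guard and recompilation logs above can be reproduced with Dynamo's logging controls; a minimal sketch, assuming a build where torch._logging.set_logs exposes the guards and recompiles artifacts:

import logging
import torch

# Equivalent to running with TORCH_LOGS="guards,recompiles": prints the
# installed guards and the reason for every recompilation, which is how the
# __nn_module_guard_0 debug lines above were collected.
torch._logging.set_logs(dynamo=logging.DEBUG, guards=True, recompiles=True)

Running the repro above with this enabled should show the _parameters version changing between compilations.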

Error logs

No response

Minified repro

No response

Versions

NA

cc @ezyang @msaroufim @wconstab @bdhirsh @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng

anijain2305 commented 1 year ago

cc @Chillee

anijain2305 commented 1 year ago

Also, it does not happen with the eager backend; it happens only with the aot_eager backend.

anijain2305 commented 1 year ago

The _parameters dict changes because of this line - https://github.com/pytorch/pytorch/blob/5dcee01c2b89f6bedeef9dd043fd8d6728286582/torch/nn/utils/_named_member_accessor.py#L52 - which is called from https://github.com/pytorch/pytorch/blob/d91492a7a40dc5e4531ca5e3abfe86037b6ad8a4/torch/_functorch/aot_autograd.py#L3490-L3492

AOT Autograd always mutates the _parameters dict of the nn modules. So the natural question is: why does it work in the first place? Why don't our tests or OSS benchmarks fail?
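
A minimal sketch of the mutation being described, using stateless._reparametrize_module (which, per the discussion below, is what the linked AOT Autograd lines end up calling); the restore-on-exit behavior is an assumption about the current implementation:

import torch
from torch.nn.utils import stateless

mod = torch.nn.Linear(10, 10)
params_before = dict(mod._parameters)

# Reparametrize with the module's own parameters: the values are restored on
# exit, but every swap rebinds a key in mod._parameters, so the dict's internal
# version tag (what __nn_module_guard compares) still moves forward.
with stateless._reparametrize_module(mod, dict(mod.named_parameters())):
    pass

params_after = dict(mod._parameters)
print(all(params_before[k] is params_after[k] for k in params_before))  # expected: True
# Same tensors before and after, yet the _parameters dict was written to,
# which is enough to fail the version-tag guard.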

TorchDynamo guard installation is lazy, i.e., it waits until the backend has finished its compilation. In the case where there is no module sharing between functions being torch.compile'd (as opposed to the example in the description), Dynamo lets AOT Autograd mutate the module and only then guards on the dict version tags. So by the time we install a guard, the module has already been mutated, and subsequent invocations pass the nn_module_guard check.

In the example in the description, the first compilation (with c=0) mutates the module and installs a guard against the updated nn module. But then the c=1 call compiles, AOT Autograd mutates the module again, and the first compilation's guard fails on the next invocation. This starts a chain reaction, with each compilation mutating the module and invalidating the nn module guard of the other compiled variant.
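
To make the shared-module case concrete, a hedged sketch of the same ping-pong with two separately compiled functions (same mechanism as the repro in the description):

import torch

mod = torch.nn.Linear(10, 10)

@torch.compile(backend="aot_eager")
def f(x):
    return mod(x)

@torch.compile(backend="aot_eager")
def g(x):
    return mod(x) * 2

x = torch.randn(10, 10)
# Compiling f mutates mod._parameters (via AOT Autograd) and then installs f's
# guard; compiling g mutates the dict again and invalidates f's guard, and so
# on, so alternating calls keep recompiling until this issue is fixed.
for _ in range(4):
    f(x)
    g(x)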

There are a few ways to solve this.

cc @colesbury

bdhirsh commented 1 year ago

Wow this is a saga.

On the "AOTAutograd is mutating the nn module during compilation" problem (due to the stateless._reparametrize_module call), it feels to me like there are two options:

(1) stateless._reparametrize_module should be smart enough to realize that the old params/buffers are identical to the new ones, so no mutation is necessary.

(2) We should figure out a way to create a shallow copy of the dynamo nn module that we can run stateless._reparametrize_module on (see the sketch after this comment). We shouldn't be mutating any state that Dynamo is guarding on just to reparametrize for tracing.

Actually, (1) is not valid - we are changing the params/buffers, since we're swapping them out with FunctionalTensor(FakeTensor) versions of each buffer/param.
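
For option (2), one possible shape is sketched below. The helper is hypothetical (not the fix that eventually landed): copy the module object and its parameter/buffer/submodule dicts so reparametrization writes to the copy rather than to the dict instances Dynamo guards on.

import copy
import torch

def shallow_clone_for_tracing(mod: torch.nn.Module) -> torch.nn.Module:
    # Hypothetical helper: a new Module object sharing the same tensors, but
    # with fresh _parameters/_buffers/_modules dicts, so swapping parameters on
    # the clone leaves the guarded dicts of the original module untouched.
    clone = copy.copy(mod)
    clone.__dict__["_parameters"] = dict(mod._parameters)
    clone.__dict__["_buffers"] = dict(mod._buffers)
    clone.__dict__["_modules"] = {
        name: shallow_clone_for_tracing(child) for name, child in mod._modules.items()
    }
    return clone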

ezyang commented 1 year ago

Why are we guarding on the Dynamo generated module parameter dict in the first place?

anijain2305 commented 1 year ago

We are not guarding on the Dynamo-generated module. We are guarding on the user-defined module's inbuilt nn submodules, e.g. L["mod"].linear. We do this so that if somebody mutates the submodule (adds a hook, etc.) after torch.compile has already run, we can recompile.

Since torch inbuilt modules are not traced into by Dynamo, they remain leaf submodules in the Dynamo graph. AOT Autograd is modifying the _parameters of those leaf modules.
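
For illustration, a small sketch of the behavior this guard is meant to preserve (module and class names here are made up for the example; whether a recompile actually triggers depends on the guard config):

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Inbuilt module: Dynamo keeps it as a leaf in the graph instead of tracing into it.
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)

mod = M()
opt_mod = torch.compile(mod, backend="eager")
opt_mod(torch.randn(2, 8))

# Mutating the leaf submodule after compilation is exactly what the guard is
# there to notice, so the next call recompiles instead of reusing a stale graph.
mod.linear.register_forward_hook(lambda m, inp, out: out * 2)
opt_mod(torch.randn(2, 8))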

ezyang commented 1 year ago

OK, then I am pro Brian's "make a copy of the module" approach. As only well-behaved torch.nn modules can get this inline treatment, we are guaranteed to be able to do so.

We could also just inline into all nn modules, but IIRC there were blockers for that.
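
(For later readers: a config flag in this direction did show up in newer builds; the sketch below assumes that flag exists in your version and is not the resolution of this issue.)

import torch._dynamo.config as dynamo_config

# Assumption: a newer build where this flag exists. It makes Dynamo trace into
# inbuilt nn.Module subclasses rather than keeping them as guarded leaf modules.
dynamo_config.inline_inbuilt_nn_modules = True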

ani300 commented 1 year ago

Just chiming in to confirm we've observed the same issue at IBM; it basically forces a recompilation on every single forward for our internal models. Do you know how long a fix will take? For now we're stuck either running older nightlies without the module guards, or rebuilding newer nightlies with the _parameters guard turned off in C++.

anijain2305 commented 1 year ago

@ani300 we will fix this very soon. Meanwhile, you can set this config to False (the C++ change is not needed): https://github.com/pytorch/pytorch/blob/7f5737392d637a22d555a88a8546d8fc7ab31084/torch/_dynamo/config.py#L108
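
A sketch of that workaround; the exact flag is whatever sits at the linked config.py line, and the name used below is an assumption to verify against your nightly:

import torch._dynamo.config as dynamo_config

# Temporary workaround sketch: turn off the dict-tag based nn module guard
# until the fix lands. The flag name is an assumption -- confirm it against
# the config.py line linked above.
dynamo_config.guard_nn_modules_using_dict_tags = False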

@ezyang does https://github.com/pytorch/pytorch/pull/110230 work? Since Dynamo does not trace inside torch inbuilt nn modules, we can probably just tell Dynamo not to guard on them either?

anijain2305 commented 1 year ago

@ani300 Can you check whether https://github.com/pytorch/pytorch/pull/110230 solves your recompilation problem?

ani300 commented 1 year ago

@anijain2305 I'll test it now!

ani300 commented 1 year ago

@anijain2305 sorry for the late reply, but it's been working!

Chillee commented 1 year ago

Reopening since the underlying issue has not been fixed.

Or wait... has it?