anijain2305 closed this issue 1 year ago
cc @Chillee
Also, it does not happen with the eager backend; it happens only with the aot_eager backend.
The _parameters dict changes because of this line - https://github.com/pytorch/pytorch/blob/5dcee01c2b89f6bedeef9dd043fd8d6728286582/torch/nn/utils/_named_member_accessor.py#L52 called during the https://github.com/pytorch/pytorch/blob/d91492a7a40dc5e4531ca5e3abfe86037b6ad8a4/torch/_functorch/aot_autograd.py#L3490-L3492
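To make the mechanism concrete, here is a pure-Python model (not the actual PyTorch code) of what the `_named_member_accessor` swap does: reparametrization reassigns entries in the module's `_parameters` dict, and every such in-place dict mutation bumps CPython's internal dict version tag, even when the value is later swapped back. The `Module` and `swap_tensor` names below are stand-ins for illustration only.

```python
# Pure-Python model of the swap performed during reparametrization.
# "Module" and "swap_tensor" are illustrative stand-ins, not torch APIs.

class Module:
    def __init__(self):
        # Stand-in for nn.Module._parameters
        self._parameters = {"weight": 1.0}

def swap_tensor(mod, name, new_tensor):
    """Reassign one entry of _parameters in place (this is the dict
    mutation that bumps CPython's internal dict version tag)."""
    old = mod._parameters[name]
    mod._parameters[name] = new_tensor
    return old  # caller restores the original after tracing

m = Module()
orig = swap_tensor(m, "weight", "fake_tensor")  # swap in a fake tensor to trace
swap_tensor(m, "weight", orig)                  # swap back: dict mutated twice
```

Even though the final contents are identical to the initial ones, the dict was mutated twice, so a guard keyed on the dict version tag would no longer match a pre-swap snapshot.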
AOT Autograd always mutates the _parameters dict of the nn modules. So the natural question is: why does it work in the first place? Why don't our tests or OSS benchmarks fail?
TorchDynamo guard installation is lazy, i.e., it waits until the backend has finished its compilation. In the case where there is no module sharing between functions getting torch.compile'd (as opposed to the example in the description), Dynamo lets AOT Autograd mutate the module and then guards on the dict version tags. So by the time we install a guard, the module has already been mutated, and subsequent invocations pass the nn_module_guard check.
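The lazy-installation behavior can be modeled in a few lines of plain Python. Everything below is a model, not Dynamo code: the hand-rolled `version` counter stands in for CPython's internal dict version tag, and `model_compile` stands in for the Dynamo + AOT Autograd compilation pipeline.

```python
# Pure-Python model of lazy guard installation. The version counter stands in
# for CPython's dict version tag; model_compile stands in for Dynamo + backend.

class VersionedDict:
    """Dict wrapper that bumps a version counter on every mutation."""
    def __init__(self, data):
        self.data = dict(data)
        self.version = 0
    def __setitem__(self, key, value):
        self.data[key] = value
        self.version += 1

def model_compile(params):
    """The backend (AOT Autograd) mutates _parameters while tracing; only
    afterwards does Dynamo snapshot the version and install the guard."""
    for key in list(params.data):
        params[key] = params.data[key]   # backend swaps params -> version bumps
    snapshot = params.version            # guard installed lazily, post-mutation
    return lambda: params.version == snapshot

params = VersionedDict({"weight": 1.0})
guard = model_compile(params)
assert guard()  # later invocations pass: nothing mutated after guard install
```

Because the snapshot is taken after the backend's mutations, a single compiled function keeps passing its own guard; the trouble only starts once a second compilation mutates the same dict again.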
In the updated example in the description, the first compilation (with c = 0) mutates the module and installs a guard against the updated nn module. But then the c = 1 branch compiles, AOT Autograd mutates the module again, and the first compilation's guard fails on the next invocation. This starts a chain reaction, with each compilation mutating the module and failing the other function's nn module guard.
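The chain reaction can be reduced to a minimal pure-Python model (no PyTorch involved): a global counter stands in for the shared module's dict version tag, and each "compilation" both mutates the dict and snapshots the post-mutation version for its guard.

```python
# Pure-Python model of the chain reaction between two compiled functions that
# share one module. The counter stands in for the dict version tag.

version = 0  # version tag of the shared module's _parameters dict

def compile_branch():
    """Each compilation mutates the shared _parameters (bumping its version),
    then installs a guard against the post-mutation version."""
    global version
    version += 1                  # AOT Autograd swaps params during tracing
    snapshot = version
    return lambda: version == snapshot

guard_c0 = compile_branch()       # first compilation (c == 0)
assert guard_c0()                 # passes: guard installed after the mutation
guard_c1 = compile_branch()       # second compilation (c == 1) mutates again
assert not guard_c0()             # first guard now fails -> recompilation
assert guard_c1()                 # ...and each recompilation breaks the other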
There are a few ways to solve it
`_parameters` and `_buffers`
cc @colesbury
Wow this is a saga.
On the "AOTAutograd is mutating the nn module during compilation" issue, due to the `stateless._reparametrize_module` call, it feels to me like there are two options:
(1) `stateless._reparametrize_module` should be smart enough to realize that the old params/buffers are identical to the new ones, so no mutation is necessary.
(2) We should figure out a way to create a shallow copy of the Dynamo nn module that we can run `stateless._reparametrize_module` on. We shouldn't be mutating any state that Dynamo is guarding on just to reparametrize for tracing.
Actually, (1) is not valid - we are changing the params/buffers, since we're swapping them out with `FunctionalTensor(FakeTensor)` versions of each buffer/param.
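Option (2) could look roughly like the following pure-Python sketch. This is not the actual fix or a torch API: `Module` and `reparametrize_on_copy` are invented names, and plain strings stand in for fake/functional tensors. The point is only that swapping params on a shallow copy leaves the original module's `_parameters` dict (and its version tag) untouched.

```python
# Hypothetical sketch of option (2): reparametrize a shallow copy so the
# original module's _parameters dict, which Dynamo guards on, is never mutated.
import copy

class Module:
    def __init__(self):
        # Stand-in for nn.Module._parameters
        self._parameters = {"weight": 1.0}

def reparametrize_on_copy(mod, new_params):
    clone = copy.copy(mod)                     # shallow copy of the module
    clone._parameters = dict(mod._parameters)  # fresh dict, same tensor objects
    clone._parameters.update(new_params)       # mutate only the copy
    return clone

m = Module()
traced = reparametrize_on_copy(m, {"weight": "fake_tensor"})
assert m._parameters["weight"] == 1.0               # original untouched
assert traced._parameters["weight"] == "fake_tensor"  # copy used for tracing
```

The shallow copy shares the underlying tensor objects, so this would be cheap; only the `_parameters` container itself is duplicated before the swap.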
Why are we guarding on the Dynamo generated module parameter dict in the first place?
We are not guarding on the Dynamo-generated module. We are guarding on the user-defined module's inbuilt nn modules, something like `L["mod"].linear`. We are doing this so that if somebody mutates the submodule (e.g., adds a hook) after torch.compile has already run, we can recompile.
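The purpose of this guard can be modeled in plain Python. The sketch below is a model, not Dynamo's implementation: `Linear` and `install_guard` are invented stand-ins, and the guard fingerprints the state (hooks, parameter keys) rather than using a real dict version tag.

```python
# Pure-Python model of why Dynamo guards on inbuilt submodules: a guard that
# snapshots the module's hook/parameter state so a later mutation, such as
# adding a forward hook, is detected and triggers recompilation.

class Linear:
    def __init__(self):
        # Stand-ins for nn.Module._parameters and _forward_hooks
        self._parameters = {"weight": 1.0}
        self._forward_hooks = {}

def install_guard(mod):
    """Snapshot cheap fingerprints of the state the guard cares about."""
    snapshot = (len(mod._forward_hooks), tuple(mod._parameters))
    return lambda: (len(mod._forward_hooks), tuple(mod._parameters)) == snapshot

linear = Linear()
guard = install_guard(linear)
assert guard()                      # nothing changed: reuse compiled code
linear._forward_hooks[0] = print    # user adds a hook after compilation
assert not guard()                  # guard fails -> recompile with the hook
```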
Since torch inbuilt modules are not traced by Dynamo, they remain as leaf submodules in the Dynamo graph. AOT autograd is modifying the _parameters of those leaf modules.
OK, then I am pro Brian's "make a copy of the module" approach. Since only well-behaved torch.nn modules get this inline treatment, we are guaranteed to be able to do so.
We can also just inline into all nn modules, but IIRC there were blockers for that.
Just chiming in to confirm we've observed the same issue at IBM; it basically forces recompilation on every single forward for our internal models. Do you know how long a fix for this will take? For now, we're stuck running older nightlies without the module guards, or recompiling the newer nightlies with the `_parameters` guard turned off in C++.
@ani300 we will fix this very soon. Meanwhile, you can set this config to False (no need for that C++ change): https://github.com/pytorch/pytorch/blob/7f5737392d637a22d555a88a8546d8fc7ab31084/torch/_dynamo/config.py#L108
@ezyang does this https://github.com/pytorch/pytorch/pull/110230 work? Since Dynamo is not tracing inside torch inbuilt nn modules, we can probably just tell Dynamo not to guard on them as well?
@ani300 Can you check whether https://github.com/pytorch/pytorch/pull/110230 solves your recompilation problem?
@anijain2305 I'll test it now!
@anijain2305 sorry for the late reply, but it's been working!
Reopening since the underlying issue has not been fixed.
Or wait... has it?
🐛 Describe the bug
Repro
The failing guard here is `__nn_module_guard`. The failure reason is that the `_parameters` dict keeps changing its version. However, there is nothing in the model that should change `mod`'s `_parameters`.

Error logs
No response
Minified repro
No response
Versions
NA
cc @ezyang @msaroufim @wconstab @bdhirsh @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @chenyang78 @aakhundov @kadeng