@jyothisambolu This is because when you do

    torch.nn.modules.Module.original__get_attr__ = torch.nn.modules.Module.__getattr__
    torch.nn.modules.Module.__getattr__ = wrapped__getattr__

you are overriding __getattr__ on every module, including the FabricModule. In your custom __getattr__, when you do self._custom_attr = "Custom", it calls __setattr__ on the FabricModule, which in turn calls __getattr__ again, creating an infinite loop.
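For illustration, an override of roughly this shape is enough to trigger the recursion (a minimal sketch; the actual body of wrapped__getattr__ from the report isn't shown in this thread, so the details are assumed):

    import torch

    def wrapped__getattr__(self, name):
        # Hypothetical instrumentation: setting an attribute from inside
        # __getattr__ re-enters FabricModule.__setattr__, which looks up
        # attributes via __getattr__ again -> infinite recursion.
        self._custom_attr = "Custom"
        return torch.nn.modules.Module.original__get_attr__(self, name)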
My suggestion is to avoid overriding __getattr__ for all nn.Modules. Override it only on the models you care about. Example:
    import torch.nn as nn

    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.model = nn.Linear(2, 2)
            ...

    MyModel.original__get_attr__ = MyModel.__getattr__
    MyModel.__getattr__ = wrapped__getattr__
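A quick usage sketch under the same assumptions (wrapped__getattr__ as above; the attribute names are illustrative):

    model = MyModel()
    sub = model.model  # submodule lookup goes through the instrumented __getattr__
    # A FabricModule wrapping this model keeps its own __getattr__,
    # so only MyModel's lookups are instrumented and the loop is avoided.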
Hey @jyothisambolu, can you take a look at my reply?
Hi @awaelchli, thanks for the solution. It may work for model-specific customizations, but if we want to use a custom attribute across all modules (for module debugging/analysis/customization), we will still hit the issue.
I don't know a solution to this at the moment. The implementations of __getattr__ and __setattr__ on the FabricModule are quite essential. I don't know how to change them to support your use case, unfortunately.
I think the right way to go here is to check isinstance(self, nn.Module) and not isinstance(self, FabricModule) in your override, @jyothisambolu, and if so apply the instrumentation; otherwise, fall back to the standard __getattr__.
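A sketch of that guard, assuming the import path for the wrapper class (it is an internal name, lightning.fabric.wrappers._FabricModule in recent Lightning 2.x releases, and may change) and a placeholder instrumentation body:

    import torch.nn as nn
    from lightning.fabric.wrappers import _FabricModule  # internal path, assumed

    nn.Module.original__get_attr__ = nn.Module.__getattr__

    def wrapped__getattr__(self, name):
        # Instrument plain modules only; skip the Fabric wrapper so its
        # attribute forwarding stays intact and the recursion never starts.
        if isinstance(self, nn.Module) and not isinstance(self, _FabricModule):
            pass  # debug/analysis hook goes here (body assumed)
        return nn.Module.original__get_attr__(self, name)

    nn.Module.__getattr__ = wrapped__getattr__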
Bug description
The Python runtime throws an exception when Fabric wraps a torch.nn.Module whose __getattr__ is overridden.
What version are you seeing the problem on?
v2.1
How to reproduce the bug
Error messages and logs
Environment
cc @carmocca @justusschock @awaelchli