Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Recursion depth exceeded with custom `__getattr__` on `torch.nn.module` #19307

Closed: jyothisambolu closed this 8 months ago

jyothisambolu commented 9 months ago

Bug description

The Python runtime raises a RecursionError when Fabric wraps a torch.nn.Module whose `__getattr__` is overridden.

What version are you seeing the problem on?

v2.1

How to reproduce the bug

from typing import Union

import torch

from tests_fabric.helpers.models import BoringFabric


def wrapped__getattr__(self, name: str) -> Union[torch.Tensor, torch.nn.Module]:
    # Call the original (saved) __getattr__ first, then attach a custom attribute.
    result = self.original__get_attr__(name)
    try:
        print(f"Inside custom getattr on torch module for: {name}")
        self._custom_attr = "Custom"
    except Exception as e:
        print(f"Exception occurred: {e}")

    return result


# Monkeypatch __getattr__ on every torch.nn.Module, keeping a reference to the original.
torch.nn.modules.Module.original__get_attr__ = torch.nn.modules.Module.__getattr__
torch.nn.modules.Module.__getattr__ = wrapped__getattr__


def test_wrapper():
    fabric = BoringFabric()
    fabric.expected_dtype = "bf16-mixed"
    fabric.run()

Error messages and logs

strategies/test_single_device.py::test_wrapper FAILED

=============================================================================== FAILURES ===============================================================================
_____________________________________________________________________________ test_wrapper _____________________________________________________________________________

    def test_wrapper():
        fabric = BoringFabric()
        fabric.expected_dtype="bf16-mixed"
>       fabric.run()

strategies/test_single_device.py:230:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../../../../venv/lib/python3.10/site-packages/lightning/fabric/fabric.py:925: in _wrap_and_launch
    return to_run(*args, **kwargs)
../../../../venv/lib/python3.10/site-packages/lightning/fabric/fabric.py:930: in _wrap_with_setup
    return to_run(*args, **kwargs)
helpers/models.py:56: in run
    model = self.get_model()
helpers/models.py:36: in get_model
    return nn.Linear(32, 2)
../../../../venv/lib/python3.10/site-packages/torch/nn/modules/linear.py:96: in __init__
    self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))
../../../../venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1715: in __setattr__
    self.register_parameter(name, value)
../../../../venv/lib/python3.10/site-packages/torch/nn/modules/module.py:577: in register_parameter
    elif hasattr(self, name) and name not in self._parameters:
strategies/test_single_device.py:214: in wrapped__getattr__
    result = self.original__get_attr__(name)
strategies/test_single_device.py:173: in wrapped__getattr__
    result = self.original__get_attr__(name)
strategies/test_single_device.py:173: in wrapped__getattr__
    result = self.original__get_attr__(name)
E   RecursionError: maximum recursion depth exceeded
!!! Recursion detected (same locals & position)
======================================================================= short test summary info ========================================================================
FAILED strategies/test_single_device.py::test_wrapper - RecursionError: maximum recursion depth exceeded

Environment

* CUDA:
        - GPU:               None
        - available:         False
        - version:           None
* Lightning:
        - lightning:         2.1.3
        - lightning-cloud:   0.5.57
        - lightning-habana:  1.3.0
        - lightning-utilities: 0.10.0
        - pytorch-lightning: 2.1.3

cc @carmocca @justusschock @awaelchli

awaelchli commented 9 months ago

@jyothisambolu This is because when you do

torch.nn.modules.Module.original__get_attr__ = torch.nn.modules.Module.__getattr__
torch.nn.modules.Module.__getattr__ = wrapped__getattr__

you are overriding `__getattr__` on every module, including the `FabricModule` wrapper. Inside your custom `__getattr__`, the assignment `self._custom_attr = "Custom"` calls `__setattr__` on the `FabricModule`, which in turn calls `__getattr__` again, creating the infinite loop.
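
For intuition, here is a minimal standalone sketch of that pattern in plain Python (a hypothetical stand-in, not Fabric's actual implementation): a `__getattr__` that performs an assignment and a `__setattr__` that looks names up via `hasattr` keep re-entering each other until the recursion limit is hit.

class Recursive:
    def __getattr__(self, name):
        # the "instrumentation": assigning an attribute re-enters __setattr__
        self._custom_attr = "Custom"
        raise AttributeError(name)

    def __setattr__(self, name, value):
        # like a wrapper that checks whether the target already owns the attribute:
        # hasattr() calls __getattr__ for missing names, so the two re-enter each other
        if hasattr(self, name):
            object.__setattr__(self, name, value)


# Recursive().missing  # RecursionError: maximum recursion depth exceeded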

My suggestion is to avoid overriding `__getattr__` for all `nn.Module`s and override it only on the models you care about. Example:

from torch import nn


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(2, 2)
        ...


MyModel.original__get_attr__ = MyModel.__getattr__
MyModel.__getattr__ = wrapped__getattr__
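
For completeness, a hypothetical usage sketch of that approach (the `Fabric` arguments and the `model.model` access are illustrative, not taken from the issue): the instrumentation runs only for `MyModel`, while Fabric's wrapper keeps its stock `__getattr__`/`__setattr__`.

import lightning as L

fabric = L.Fabric(accelerator="cpu", devices=1)
fabric.launch()
model = fabric.setup(MyModel())  # wrap only after patching the subclass
_ = model.model                  # instrumentation fires inside MyModel, no recursion
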
awaelchli commented 9 months ago

Hey @jyothisambolu can you take a look at my reply?

jyothisambolu commented 9 months ago

> Hey @jyothisambolu can you take a look at my reply?

Hi @awaelchli, thanks for the solution. It may work for model-specific customizations, but if we want to use a custom attribute across all modules (for module debugging/analysis/customization), we will still hit the issue.

awaelchli commented 9 months ago

I don't know a solution to this at the moment. The `__getattr__` and `__setattr__` implementations on the `FabricModule` are quite essential, and unfortunately I don't know how to change them to support your use case.

lantiga commented 9 months ago

I think the right way to go here is to check `isinstance(self, nn.Module) and not isinstance(self, FabricModule)` in your override, @jyothisambolu, and if so apply the instrumentation, otherwise fall back to the standard `__getattr__`.
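
A sketch of that guard, assuming the wrapper class can be imported as `lightning.fabric.wrappers._FabricModule` (the exact import path is an assumption and may differ between versions):

from typing import Union

import torch
from torch import nn

from lightning.fabric.wrappers import _FabricModule  # assumed import path


def wrapped__getattr__(self, name: str) -> Union[torch.Tensor, nn.Module]:
    result = self.original__get_attr__(name)
    # Instrument plain nn.Modules only; skip Fabric's wrapper so its
    # __setattr__/__getattr__ interplay cannot re-enter this function.
    if isinstance(self, nn.Module) and not isinstance(self, _FabricModule):
        try:
            print(f"Inside custom getattr on torch module for: {name}")
            self._custom_attr = "Custom"
        except Exception as e:
            print(f"Exception occurred: {e}")
    return result


nn.Module.original__get_attr__ = nn.Module.__getattr__
nn.Module.__getattr__ = wrapped__getattr__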