Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Init arguments upon multiple inheritance from LightningDataModule and other parent class #17478

Open TmtStss opened 1 year ago

TmtStss commented 1 year ago

Bug description

When defining a new class inheriting from both lightning.pytorch.LightningDataModule and a user-defined parent class, a TypeError is raised upon initialization of an instance of the new class.

Expected behavior: no TypeError when inheriting from lightning.pytorch.LightningDataModule together with another parent class, just as when inheriting from it alone (see Check2 in the MRE, and the docs).

What version are you seeing the problem on?

2.0+

How to reproduce the bug

from lightning.pytorch import LightningDataModule

class Base:
    def __init__(self, foo):
        print(foo)

class Other:
    def __init__(self):
        pass

class Check1(Other, Base):
    def __init__(self, config):
        Other.__init__(self)
        Base.__init__(self, **config)

class Check2(LightningDataModule):
    def __init__(self):
        LightningDataModule.__init__(self)

class CustomDataModule(LightningDataModule, Base):
    def __init__(self, config):
        LightningDataModule.__init__(self)
        Base.__init__(self, **config)

if __name__ == "__main__":
    config = {"foo": "foo"}
    Check1(config)
    Check2()
    CustomDataModule(config)

Error messages and logs

Traceback (most recent call last):
  File ".../debug.py", line 35, in <module>
    CustomDataModule(config)
  File ".../debug.py", line 27, in __init__
    LightningDataModule.__init__(self)
  File ".../.pyenv/versions/3.8.10/envs/frm_test/lib/python3.8/site-packages/lightning/pytorch/core/datamodule.py", line 64, in __init__
    super().__init__()
  File ".../.pyenv/versions/3.8.10/envs/frm_test/lib/python3.8/site-packages/lightning/pytorch/core/hooks.py", line 298, in __init__
    super().__init__()
  File ".../.pyenv/versions/3.8.10/envs/frm_test/lib/python3.8/site-packages/lightning/pytorch/core/mixins/hparams_mixin.py", line 31, in __init__
    super().__init__()
TypeError: __init__() missing 1 required positional argument: 'foo'

Environment

Current environment ``` * CUDA: - GPU: None - available: False - version: 11.7 * Lightning: - lightning: 2.0.1.post0 - lightning-cloud: 0.5.32 - lightning-utilities: 0.8.0 - pytorch-lightning: 2.0.0 - torch: 2.0.0 - torch-tb-profiler: 0.4.1 - torchmetrics: 0.11.4 - torchsummary: 1.5.1 - torchview: 0.2.6 - torchvision: 0.15.1 * Packages: - absl-py: 1.4.0 - affine: 2.4.0 - aiofiles: 22.1.0 - aiohttp: 3.8.4 - aiosignal: 1.3.1 - aiosqlite: 0.18.0 - alembic: 1.9.2 - anyio: 3.6.2 - argon2-cffi: 21.3.0 - argon2-cffi-bindings: 21.2.0 - arrow: 1.2.3 - asciitree: 0.3.3 - asttokens: 2.0.8 - astunparse: 1.6.3 - async-timeout: 4.0.2 - attrs: 22.1.0 - autopage: 0.5.1 - babel: 2.10.3 - backcall: 0.2.0 - backoff: 2.2.1 - beautifulsoup4: 4.11.1 - bleach: 5.0.1 - blessed: 1.20.0 - bokeh: 2.4.3 - bottle: 0.12.25 - bqplot: 0.12.36 - branca: 0.5.0 - cachetools: 5.2.0 - certifi: 2022.9.24 - cffi: 1.15.1 - cftime: 1.6.2 - charset-normalizer: 2.1.1 - click: 8.1.3 - click-plugins: 1.1.1 - cliff: 4.2.0 - cligj: 0.7.2 - cloudpickle: 2.2.1 - cmaes: 0.9.1 - cmake: 3.26.0 - cmd2: 2.4.3 - colorlog: 6.7.0 - colour: 0.1.5 - comm: 0.1.2 - compress-json: 1.0.8 - contourpy: 1.0.5 - croniter: 1.3.8 - cycler: 0.11.0 - dask: 2022.12.0 - dask-optuna: 0.0.2 - dateutils: 0.6.12 - debugpy: 1.6.3 - decorator: 5.1.1 - deepdiff: 6.3.0 - defusedxml: 0.7.1 - dill: 0.3.6 - distributed: 2022.12.0 - dnspython: 2.3.0 - docopt: 0.6.2 - earthengine-api: 0.1.343 - ee-extra: 0.0.14 - eerepr: 0.0.4 - email-validator: 1.3.1 - entrypoints: 0.4 - executing: 1.1.1 - fastapi: 0.88.0 - fasteners: 0.18 - fastjsonschema: 2.16.2 - ffmpeg-python: 0.2.0 - filelock: 3.8.0 - fiona: 1.8.22 - flask: 2.2.2 - flatbuffers: 23.1.21 - folium: 0.13.0 - fonttools: 4.38.0 - fqdn: 1.5.1 - frarssmap: 0.0.0 - frozenlist: 1.3.3 - fsspec: 2023.1.0 - future: 0.18.2 - gast: 0.4.0 - gdown: 4.5.3 - geeadd: 0.5.6 - geemap: 0.17.1 - geocoder: 1.38.1 - geojson: 2.5.0 - geopandas: 0.10.2 - google-api-core: 2.10.2 - google-api-python-client: 1.12.11 - 
google-auth: 2.17.2 - google-auth-httplib2: 0.1.0 - google-auth-oauthlib: 0.4.6 - google-cloud-core: 2.3.2 - google-cloud-storage: 2.5.0 - google-crc32c: 1.5.0 - google-pasta: 0.2.0 - google-resumable-media: 2.4.0 - googleapis-common-protos: 1.56.4 - greenlet: 2.0.2 - grip: 4.6.1 - grpcio: 1.51.1 - h11: 0.14.0 - h3: 3.7.4 - h3pandas: 0.2.3 - h5py: 3.8.0 - heapdict: 1.0.1 - httpcore: 0.16.3 - httplib2: 0.20.4 - httplib2shim: 0.0.3 - httptools: 0.5.0 - httpx: 0.23.3 - idna: 3.4 - importlib-metadata: 5.0.0 - importlib-resources: 5.10.0 - inquirer: 3.1.3 - ipyevents: 2.0.1 - ipyfilechooser: 0.6.0 - ipykernel: 6.16.2 - ipyleaflet: 0.17.2 - ipython: 8.5.0 - ipython-genutils: 0.2.0 - ipytree: 0.2.2 - ipywidgets: 7.6.5 - isoduration: 20.11.0 - itsdangerous: 2.1.2 - jedi: 0.18.1 - jinja2: 3.1.2 - joblib: 1.2.0 - json5: 0.9.10 - jsonpointer: 2.3 - jsonschema: 4.16.0 - jupyter: 1.0.0 - jupyter-client: 7.1.0 - jupyter-console: 6.4.4 - jupyter-contrib-core: 0.4.2 - jupyter-contrib-nbextensions: 0.5.1 - jupyter-core: 4.9.1 - jupyter-events: 0.5.0 - jupyter-highlight-selected-word: 0.2.0 - jupyter-latex-envs: 1.4.6 - jupyter-nbextensions-configurator: 0.6.1 - jupyter-server: 1.13.1 - jupyter-server-fileid: 0.6.0 - jupyter-server-terminals: 0.4.4 - jupyter-server-ydoc: 0.6.1 - jupyter-ydoc: 0.2.2 - jupyterlab: 3.2.5 - jupyterlab-pygments: 0.1.2 - jupyterlab-server: 2.10.2 - jupyterlab-sublime: 0.4.1 - jupyterlab-widgets: 1.0.2 - kaleido: 0.2.1 - keras: 2.11.0 - kiwisolver: 1.4.4 - libclang: 15.0.6.1 - lightgbm: 3.3.5 - lightning: 2.0.1.post0 - lightning-cloud: 0.5.32 - lightning-utilities: 0.8.0 - lit: 16.0.0 - locket: 1.0.0 - logzero: 1.7.0 - lxml: 4.9.2 - mako: 1.2.4 - markdown: 3.4.1 - markdown-it-py: 2.2.0 - markupsafe: 2.1.1 - matplotlib: 3.6.3 - matplotlib-inline: 0.1.6 - mdurl: 0.1.2 - mistune: 0.8.4 - mpmath: 1.3.0 - msgpack: 1.0.4 - multidict: 6.0.4 - multiprocess: 0.70.14 - munch: 2.5.0 - nbclassic: 0.3.4 - nbclient: 0.5.9 - nbconvert: 6.4.0 - nbformat: 5.1.3 - 
nest-asyncio: 1.5.6 - netcdf4: 1.6.2 - netron: 6.5.5 - networkx: 3.0 - notebook: 6.1.5 - notebook-shim: 0.2.0 - numcodecs: 0.11.0 - numpy: 1.23.4 - nvidia-cublas-cu11: 11.10.3.66 - nvidia-cuda-cupti-cu11: 11.7.101 - nvidia-cuda-nvrtc-cu11: 11.7.99 - nvidia-cuda-runtime-cu11: 11.7.99 - nvidia-cudnn-cu11: 8.5.0.96 - nvidia-cufft-cu11: 10.9.0.58 - nvidia-curand-cu11: 10.2.10.91 - nvidia-cusolver-cu11: 11.4.0.1 - nvidia-cusparse-cu11: 11.7.4.91 - nvidia-nccl-cu11: 2.14.3 - nvidia-nvtx-cu11: 11.7.91 - oauthlib: 3.2.2 - opt-einsum: 3.3.0 - optuna: 3.1.0 - optuna-dashboard: 0.9.0 - ordered-set: 4.1.0 - orjson: 3.8.8 - packaging: 21.3 - pandas: 1.5.1 - pandocfilters: 1.5.0 - parso: 0.8.3 - partd: 1.3.0 - path-and-address: 2.0.1 - pathos: 0.3.0 - pbr: 5.11.1 - pexpect: 4.8.0 - pickleshare: 0.7.5 - pillow: 9.2.0 - pip: 21.1.1 - pipreqs: 0.4.11 - pkgutil-resolve-name: 1.3.10 - platformdirs: 2.6.2 - plot-keras-history: 1.1.38 - plotly: 5.13.1 - pox: 0.3.2 - ppft: 1.7.6.6 - prettytable: 3.6.0 - prometheus-client: 0.15.0 - prompt-toolkit: 3.0.31 - protobuf: 3.19.6 - psutil: 5.9.3 - ptyprocess: 0.7.0 - pure-eval: 0.2.2 - py: 1.11.0 - pyarrow: 6.0.1 - pyasn1: 0.4.8 - pyasn1-modules: 0.2.8 - pycparser: 2.21 - pycrs: 1.0.2 - pydantic: 1.10.7 - pygments: 2.13.0 - pyjwt: 2.6.0 - pyparsing: 3.0.9 - pyperclip: 1.8.2 - pyproj: 3.4.0 - pyrsistent: 0.18.1 - pyshp: 2.3.1 - pysocks: 1.7.1 - python-box: 6.0.2 - python-dateutil: 2.8.2 - python-dotenv: 1.0.0 - python-editor: 1.0.4 - python-json-logger: 2.0.4 - python-multipart: 0.0.6 - pytorch-lightning: 2.0.0 - pytz: 2022.5 - pyyaml: 6.0 - pyzmq: 24.0.1 - qtconsole: 5.4.0 - qtpy: 2.3.0 - rasterio: 1.3.5 - ratelim: 0.1.6 - readchar: 4.0.5 - requests: 2.28.1 - requests-oauthlib: 1.3.1 - requests-toolbelt: 0.10.1 - restee: 0.0.3 - retry: 0.9.2 - retrying: 1.3.3 - rfc3339-validator: 0.1.4 - rfc3986: 1.5.0 - rfc3986-validator: 0.1.1 - rich: 13.3.3 - rsa: 4.9 - ruamel.yaml: 0.17.21 - ruamel.yaml.clib: 0.2.7 - sanitize-ml-labels: 1.0.50 - sankee: 
0.2.0 - scikeras: 0.10.0 - scikit-learn: 1.1.3 - scipy: 1.9.3 - scooby: 0.7.0 - seaborn: 0.12.2 - send2trash: 1.8.0 - setuptools: 56.0.0 - shapely: 1.8.5.post1 - six: 1.16.0 - sklearn-pandas: 2.2.0 - sniffio: 1.3.0 - snuggs: 1.4.7 - sortedcontainers: 2.4.0 - soupsieve: 2.3.2.post1 - sqlalchemy: 2.0.2 - stack-data: 0.5.1 - starlette: 0.22.0 - starsessions: 1.3.0 - stevedore: 5.0.0 - support-developer: 1.0.5 - sympy: 1.11.1 - tb-nightly: 2.13.0a20230412 - tblib: 1.7.0 - tenacity: 8.1.0 - tensorboard: 2.11.2 - tensorboard-data-server: 0.6.1 - tensorboard-plugin-wit: 1.8.1 - tensorflow: 2.11.0 - tensorflow-estimator: 2.11.0 - tensorflow-io-gcs-filesystem: 0.28.0 - termcolor: 2.0.1 - terminado: 0.17.0 - testpath: 0.6.0 - threadpoolctl: 3.1.0 - tinycss2: 1.2.1 - tomli: 2.0.1 - toolz: 0.12.0 - torch: 2.0.0 - torch-tb-profiler: 0.4.1 - torchmetrics: 0.11.4 - torchsummary: 1.5.1 - torchview: 0.2.6 - torchvision: 0.15.1 - tornado: 6.2 - tqdm: 4.64.1 - traitlets: 5.5.0 - traittypes: 0.2.1 - triton: 2.0.0 - typing-extensions: 4.4.0 - ujson: 5.7.0 - uri-template: 1.2.0 - uritemplate: 3.0.1 - urllib3: 1.26.12 - uvicorn: 0.21.1 - uvloop: 0.17.0 - watchfiles: 0.19.0 - wcwidth: 0.2.5 - webcolors: 1.12 - webencodings: 0.5.1 - websocket-client: 1.4.1 - websockets: 10.4 - werkzeug: 2.2.2 - wheel: 0.38.4 - whitebox: 2.2.0 - whiteboxgui: 2.2.0 - widgetsnbextension: 3.5.2 - wrapt: 1.14.1 - wxee: 0.3.3 - xarray: 2022.11.0 - xgboost: 1.7.5 - xyzservices: 2022.9.0 - y-py: 0.5.5 - yarg: 0.1.9 - yarl: 1.8.2 - ypy-websocket: 0.8.2 - zarr: 2.10.3 - zict: 2.2.0 - zipp: 3.10.0 * System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.8.10 - version: #1 SMP Fri Jan 27 02:56:13 UTC 2023 ```

More info

No response

cc @carmocca @awaelchli @borda

awaelchli commented 1 year ago

We could try to replace the super().__init__() call in LightningDataModule with individual super calls to the two mixins it inherits from. I'm not sure if that's the correct solution.
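A rough sketch of what that change might look like, using stand-in classes rather than the real Lightning ones (the stand-ins only mirror the super().__init__() calls visible in the traceback above; everything else is an assumption). It suggests that explicit calls in LightningDataModule alone may not be enough, because each mixin's own super().__init__() still follows the MRO of the instance and reaches Base:

```python
# Stand-ins for illustration only; these are NOT the real Lightning classes.
class HyperparametersMixin:
    def __init__(self):
        super().__init__()  # still cooperative, as in the traceback


class DataHooks:
    def __init__(self):
        super().__init__()  # still cooperative, as in the traceback


class LightningDataModule(DataHooks, HyperparametersMixin):
    def __init__(self):
        # Proposed change: call each parent explicitly instead of super().__init__()
        DataHooks.__init__(self)
        HyperparametersMixin.__init__(self)


class Base:
    def __init__(self, foo):
        self.foo = foo


class CustomDataModule(LightningDataModule, Base):
    def __init__(self, config):
        LightningDataModule.__init__(self)
        Base.__init__(self, **config)


try:
    CustomDataModule({"foo": "foo"})
    outcome = "ok"
except TypeError as err:
    # DataHooks -> HyperparametersMixin -> Base.__init__() with no arguments
    outcome = f"TypeError: {err}"

print(outcome)
```

In this sketch the TypeError persists, so the mixins themselves would probably also need to change (for example by not forwarding to super().__init__(), or by accepting **kwargs) for the idea to work.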

bkiat1123 commented 1 year ago

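If you prefer finer control over the __init__ methods of your custom class and LightningDataModule, it might make sense to inherit from your custom class before LightningDataModule.

The way you currently define it places your custom class late in the method resolution order (MRO), after all of LightningDataModule's parents, so the super().__init__() chain inside Lightning ends up calling Base.__init__() without arguments.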

from lightning.pytorch import LightningDataModule

class Base:
    def __init__(self, foo):
        print(foo)

class CustomDataModule(LightningDataModule, Base):
    def __init__(self, config):
        LightningDataModule.__init__(self)
        Base.__init__(self, **config)

CustomDataModule.__mro__

(__main__.CustomDataModule,
 lightning.pytorch.core.datamodule.LightningDataModule,
 lightning.pytorch.core.hooks.DataHooks,
 lightning.pytorch.core.mixins.hparams_mixin.HyperparametersMixin,
 __main__.Base,
 object)

To make it work with this ordering, all parents of LightningDataModule would have to accept arbitrary **kwargs in __init__. You would then pass your config via super().__init__, without explicit control over which arguments go to which __init__ method.

class HyperparametersMixin:

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

class DataHooks:
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

class LightningDataModule(DataHooks, HyperparametersMixin):
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)

class Base:
    def __init__(self, foo):
        super().__init__()

class CustomDataModule(LightningDataModule, Base):
    def __init__(self, config):
        super().__init__(**config)
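
A self-contained version of the idea above, for anyone who wants to run it. The classes are stand-ins, not the real Lightning ones (which, per the traceback, do not accept **kwargs), and storing foo on the instance is only there to make the result observable. It also hints at the loss of control: an unexpected key in config is forwarded along the whole chain and only rejected by whichever __init__ finally receives it.

```python
# Stand-ins for illustration only; these are NOT the real Lightning classes.
class HyperparametersMixin:
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)


class DataHooks:
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)


class LightningDataModule(DataHooks, HyperparametersMixin):
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)


class Base:
    def __init__(self, foo):
        super().__init__()
        self.foo = foo  # stored so the forwarded value can be inspected


class CustomDataModule(LightningDataModule, Base):
    def __init__(self, config):
        # Every key in config travels along the MRO until Base consumes it
        super().__init__(**config)


dm = CustomDataModule({"foo": "foo"})
print(dm.foo)

try:
    CustomDataModule({"foo": "foo", "bar": 1})
    extra_key_outcome = "ok"
except TypeError as err:
    # "bar" passes through the stand-in mixins and is rejected by Base.__init__
    extra_key_outcome = f"TypeError: {err}"

print(extra_key_outcome)
```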

Putting your custom class first gives you finer control over the __init__ calls.

class Base:
    def __init__(self, foo):
        print(foo)

class CustomDataModule(Base, LightningDataModule):
    def __init__(self, config):
        Base.__init__(self, **config)
        LightningDataModule.__init__(self)

CustomDataModule.__mro__
(__main__.CustomDataModule,
 __main__.Base,
 __main__.LightningDataModule,
 __main__.DataHooks,
 __main__.HyperparametersMixin,
 object)

config = {"foo": "foo"}
CustomDataModule(config)
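
To see why this ordering works with explicit calls, here is a small sketch that records the order in which each __init__ runs. The classes are again stand-ins for the Lightning ones, and the calls list is a hypothetical helper added purely for illustration:

```python
calls = []  # hypothetical helper: records which __init__ ran, in order


# Stand-ins for illustration only; these are NOT the real Lightning classes.
class HyperparametersMixin:
    def __init__(self):
        calls.append("HyperparametersMixin")
        super().__init__()


class DataHooks:
    def __init__(self):
        calls.append("DataHooks")
        super().__init__()


class LightningDataModule(DataHooks, HyperparametersMixin):
    def __init__(self):
        calls.append("LightningDataModule")
        super().__init__()


class Base:
    def __init__(self, foo):
        calls.append("Base")
        self.foo = foo  # no super().__init__() here, so the chain stops at Base


class CustomDataModule(Base, LightningDataModule):
    def __init__(self, config):
        Base.__init__(self, **config)
        LightningDataModule.__init__(self)


dm = CustomDataModule({"foo": "foo"})
mro_names = [cls.__name__ for cls in CustomDataModule.__mro__]
print(mro_names)
print(calls)
```

Because Base comes before LightningDataModule in the MRO, the cooperative chain that starts in LightningDataModule.__init__ runs through the mixins to object and never reaches Base.__init__ a second time.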

Hope it helps.