TmtStss opened this issue 1 year ago.
We could try replacing the `super().__init__()` call in `LightningDataModule` with individual calls to the `__init__` of each of the two mixins it inherits from. I'm not sure whether that's the correct solution.
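One reading of that idea, sketched with simplified pure-Python stand-ins for Lightning's classes. The class bodies are assumptions: the real classes do more, and each mixin's `__init__` is modelled here as calling `super().__init__()`, which is what lets the error reach `Base` in the first place. The sketch suggests that explicit calls in `LightningDataModule` alone may not be enough, because `super()` inside the mixins still follows the subclass's MRO:

```python
# Simplified stand-ins for Lightning's classes (assumption: the real ones do
# more, but each mixin's __init__ is modelled as calling super().__init__()).
class DataHooks:
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = True


class HyperparametersMixin:
    def __init__(self):
        super().__init__()
        self._log_hyperparams = True


class LightningDataModule(DataHooks, HyperparametersMixin):
    def __init__(self):
        # The proposed change: call each parent explicitly instead of super().
        DataHooks.__init__(self)
        HyperparametersMixin.__init__(self)


class Base:
    def __init__(self, foo):
        self.foo = foo


class CustomDataModule(LightningDataModule, Base):
    def __init__(self, config):
        LightningDataModule.__init__(self)
        Base.__init__(self, **config)


error = None
try:
    CustomDataModule({"foo": "foo"})
except TypeError as exc:
    # super() inside DataHooks still resolves along CustomDataModule's MRO
    # (DataHooks -> HyperparametersMixin -> Base), so Base.__init__ runs without foo.
    error = exc
print(repr(error))
```

In this model the fix would also have to change how the mixins call their parents, not only `LightningDataModule`.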
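If you want finer control over how your custom class and `LightningDataModule` are initialized, it might make sense to inherit from your custom class before `LightningDataModule`. The way you currently define it places your custom class late in the method resolution order (MRO), after `LightningDataModule` and its mixins, so the `super().__init__()` chain inside Lightning ends up calling `Base.__init__()` without arguments.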
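```python
from lightning.pytorch import LightningDataModule


class Base:
    def __init__(self, foo):
        print(foo)


class CustomDataModule(LightningDataModule, Base):
    def __init__(self, config):
        # Raises TypeError: LightningDataModule.__init__ follows the MRO via
        # super().__init__() and ends up calling Base.__init__() without `foo`.
        LightningDataModule.__init__(self)
        Base.__init__(self, **config)
```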
```python
>>> CustomDataModule.__mro__
(__main__.CustomDataModule,
 lightning.pytorch.core.datamodule.LightningDataModule,
 lightning.pytorch.core.hooks.DataHooks,
 lightning.pytorch.core.mixins.hparams_mixin.HyperparametersMixin,
 __main__.Base,
 object)
```
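To see exactly where the chain breaks, here is a pure-Python sketch with simplified stand-ins for the three Lightning classes (an assumption: the real classes do more, but each is modelled as just forwarding to `super().__init__()`). It records which `__init__` methods run and which class `super()` inside `HyperparametersMixin` resolves to for this MRO:

```python
# Simplified stand-ins (assumption): each Lightning class only forwards to super().
calls = []


class DataHooks:
    def __init__(self):
        calls.append("DataHooks")
        super().__init__()


class HyperparametersMixin:
    def __init__(self):
        calls.append("HyperparametersMixin")
        super().__init__()


class LightningDataModule(DataHooks, HyperparametersMixin):
    def __init__(self):
        calls.append("LightningDataModule")
        super().__init__()


class Base:
    def __init__(self, foo):
        calls.append("Base")
        print(foo)


class CustomDataModule(LightningDataModule, Base):
    def __init__(self, config):
        LightningDataModule.__init__(self)
        Base.__init__(self, **config)


mro = CustomDataModule.__mro__
# The class that super() inside HyperparametersMixin resolves to for this MRO.
next_after_mixin = mro[mro.index(HyperparametersMixin) + 1]
print(next_after_mixin.__name__)

error = None
try:
    CustomDataModule({"foo": "foo"})
except TypeError as exc:
    error = exc
print(calls)  # __init__ methods entered before the error
print(error)
```

`Base` never gets appended to `calls`: the missing `foo` argument makes the call fail before its body runs, and the explicit `Base.__init__(self, **config)` line is never reached.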
To make it work with the current order, every parent class in `LightningDataModule`'s hierarchy would have to accept arbitrary `**kwargs` in `__init__` and forward them. You would then need to pass your config via `super().__init__`, without explicit control over which arguments go to which `__init__` method.
```python
# Hypothetical change: every class in LightningDataModule's hierarchy
# accepts **kwargs and forwards them along the MRO.
class HyperparametersMixin:
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)


class DataHooks:
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)


class LightningDataModule(DataHooks, HyperparametersMixin):
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)


class Base:
    def __init__(self, foo):
        super().__init__()


class CustomDataModule(LightningDataModule, Base):
    def __init__(self, config):
        # All of config travels down the MRO until Base consumes `foo`.
        super().__init__(**config)
```
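A quick way to check the cooperative version, and to see its main drawback, is to run it once with a valid config and once with a key that no class consumes. This sketch repeats the hypothetical classes above so it runs on its own, adds `self.foo = foo` to `Base` so there is something to inspect, and uses a made-up `batch_size` key purely for illustration:

```python
class HyperparametersMixin:
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)


class DataHooks:
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)


class LightningDataModule(DataHooks, HyperparametersMixin):
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)


class Base:
    def __init__(self, foo):
        super().__init__()
        self.foo = foo  # added so the result can be inspected


class CustomDataModule(LightningDataModule, Base):
    def __init__(self, config):
        super().__init__(**config)


# "foo" is forwarded through every Lightning __init__ until Base accepts it.
dm = CustomDataModule({"foo": "foo"})
print(dm.foo)

# A hypothetical extra key is also forwarded blindly and only fails at Base.
bad_config_error = None
try:
    CustomDataModule({"foo": "foo", "batch_size": 32})
except TypeError as exc:
    bad_config_error = exc
print(bad_config_error)
```

Every key in `config` has to be accepted by some class further down the MRO, which is the loss of per-`__init__` control mentioned above.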
Putting your custom class first gives you finer control over each `__init__` call:
```python
class Base:
    def __init__(self, foo):
        print(foo)


class CustomDataModule(Base, LightningDataModule):
    def __init__(self, config):
        # Each parent is initialized explicitly, with only the arguments it needs.
        Base.__init__(self, **config)
        LightningDataModule.__init__(self)
```
```python
>>> CustomDataModule.__mro__
(__main__.CustomDataModule,
 __main__.Base,
 __main__.LightningDataModule,
 __main__.DataHooks,
 __main__.HyperparametersMixin,
 object)
>>> config = {"foo": "foo"}
>>> CustomDataModule(config)  # prints "foo", no TypeError
```
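The same ordering can be checked without installing Lightning. This is a sketch with simplified stand-ins (an assumption: the real classes do more; the attribute names are only illustrative) that verifies the MRO and that both sides of the hierarchy get initialized:

```python
# Simplified stand-ins for Lightning's classes (assumption, not the real code).
class DataHooks:
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = True


class HyperparametersMixin:
    def __init__(self):
        super().__init__()
        self._log_hyperparams = True


class LightningDataModule(DataHooks, HyperparametersMixin):
    def __init__(self):
        super().__init__()


class Base:
    def __init__(self, foo):
        # Does not call super().__init__(), so it never re-enters the Lightning chain.
        self.foo = foo


class CustomDataModule(Base, LightningDataModule):
    def __init__(self, config):
        Base.__init__(self, **config)
        LightningDataModule.__init__(self)


dm = CustomDataModule({"foo": "foo"})
mro_names = [cls.__name__ for cls in type(dm).__mro__]
print(mro_names)
print(dm.foo, dm.prepare_data_per_node, dm._log_hyperparams)
```

With `Base` ahead of `LightningDataModule`, the `super()` chain inside the Lightning stand-ins ends at `object`, so `Base.__init__` is only called once, explicitly. If `Base` itself called `super().__init__()`, that call would resolve to `LightningDataModule.__init__`, so the Lightning initializers would also run from inside `Base`.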
Hope it helps.
Bug description

When defining a new class that inherits from both `lightning.pytorch.LightningDataModule` and a user-defined parent class, a `TypeError` is raised when an instance of the new class is initialized.

Expected behavior: no `TypeError` when inheriting from `lightning.pytorch.LightningDataModule` (see Check2 in the MRE and the docs).

What version are you seeing the problem on?

2.0+
How to reproduce the bug
Error messages and logs
Environment
Current environment
```
* CUDA:
    - GPU: None
    - available: False
    - version: 11.7
* Lightning:
    - lightning: 2.0.1.post0
    - lightning-cloud: 0.5.32
    - lightning-utilities: 0.8.0
    - pytorch-lightning: 2.0.0
    - torch: 2.0.0
    - torch-tb-profiler: 0.4.1
    - torchmetrics: 0.11.4
    - torchsummary: 1.5.1
    - torchview: 0.2.6
    - torchvision: 0.15.1
* Packages:
    - absl-py: 1.4.0
    - affine: 2.4.0
    - aiofiles: 22.1.0
    - aiohttp: 3.8.4
    - aiosignal: 1.3.1
    - aiosqlite: 0.18.0
    - alembic: 1.9.2
    - anyio: 3.6.2
    - argon2-cffi: 21.3.0
    - argon2-cffi-bindings: 21.2.0
    - arrow: 1.2.3
    - asciitree: 0.3.3
    - asttokens: 2.0.8
    - astunparse: 1.6.3
    - async-timeout: 4.0.2
    - attrs: 22.1.0
    - autopage: 0.5.1
    - babel: 2.10.3
    - backcall: 0.2.0
    - backoff: 2.2.1
    - beautifulsoup4: 4.11.1
    - bleach: 5.0.1
    - blessed: 1.20.0
    - bokeh: 2.4.3
    - bottle: 0.12.25
    - bqplot: 0.12.36
    - branca: 0.5.0
    - cachetools: 5.2.0
    - certifi: 2022.9.24
    - cffi: 1.15.1
    - cftime: 1.6.2
    - charset-normalizer: 2.1.1
    - click: 8.1.3
    - click-plugins: 1.1.1
    - cliff: 4.2.0
    - cligj: 0.7.2
    - cloudpickle: 2.2.1
    - cmaes: 0.9.1
    - cmake: 3.26.0
    - cmd2: 2.4.3
    - colorlog: 6.7.0
    - colour: 0.1.5
    - comm: 0.1.2
    - compress-json: 1.0.8
    - contourpy: 1.0.5
    - croniter: 1.3.8
    - cycler: 0.11.0
    - dask: 2022.12.0
    - dask-optuna: 0.0.2
    - dateutils: 0.6.12
    - debugpy: 1.6.3
    - decorator: 5.1.1
    - deepdiff: 6.3.0
    - defusedxml: 0.7.1
    - dill: 0.3.6
    - distributed: 2022.12.0
    - dnspython: 2.3.0
    - docopt: 0.6.2
    - earthengine-api: 0.1.343
    - ee-extra: 0.0.14
    - eerepr: 0.0.4
    - email-validator: 1.3.1
    - entrypoints: 0.4
    - executing: 1.1.1
    - fastapi: 0.88.0
    - fasteners: 0.18
    - fastjsonschema: 2.16.2
    - ffmpeg-python: 0.2.0
    - filelock: 3.8.0
    - fiona: 1.8.22
    - flask: 2.2.2
    - flatbuffers: 23.1.21
    - folium: 0.13.0
    - fonttools: 4.38.0
    - fqdn: 1.5.1
    - frarssmap: 0.0.0
    - frozenlist: 1.3.3
    - fsspec: 2023.1.0
    - future: 0.18.2
    - gast: 0.4.0
    - gdown: 4.5.3
    - geeadd: 0.5.6
    - geemap: 0.17.1
    - geocoder: 1.38.1
    - geojson: 2.5.0
    - geopandas: 0.10.2
    - google-api-core: 2.10.2
    - google-api-python-client: 1.12.11
    - google-auth: 2.17.2
    - google-auth-httplib2: 0.1.0
    - google-auth-oauthlib: 0.4.6
    - google-cloud-core: 2.3.2
    - google-cloud-storage: 2.5.0
    - google-crc32c: 1.5.0
    - google-pasta: 0.2.0
    - google-resumable-media: 2.4.0
    - googleapis-common-protos: 1.56.4
    - greenlet: 2.0.2
    - grip: 4.6.1
    - grpcio: 1.51.1
    - h11: 0.14.0
    - h3: 3.7.4
    - h3pandas: 0.2.3
    - h5py: 3.8.0
    - heapdict: 1.0.1
    - httpcore: 0.16.3
    - httplib2: 0.20.4
    - httplib2shim: 0.0.3
    - httptools: 0.5.0
    - httpx: 0.23.3
    - idna: 3.4
    - importlib-metadata: 5.0.0
    - importlib-resources: 5.10.0
    - inquirer: 3.1.3
    - ipyevents: 2.0.1
    - ipyfilechooser: 0.6.0
    - ipykernel: 6.16.2
    - ipyleaflet: 0.17.2
    - ipython: 8.5.0
    - ipython-genutils: 0.2.0
    - ipytree: 0.2.2
    - ipywidgets: 7.6.5
    - isoduration: 20.11.0
    - itsdangerous: 2.1.2
    - jedi: 0.18.1
    - jinja2: 3.1.2
    - joblib: 1.2.0
    - json5: 0.9.10
    - jsonpointer: 2.3
    - jsonschema: 4.16.0
    - jupyter: 1.0.0
    - jupyter-client: 7.1.0
    - jupyter-console: 6.4.4
    - jupyter-contrib-core: 0.4.2
    - jupyter-contrib-nbextensions: 0.5.1
    - jupyter-core: 4.9.1
    - jupyter-events: 0.5.0
    - jupyter-highlight-selected-word: 0.2.0
    - jupyter-latex-envs: 1.4.6
    - jupyter-nbextensions-configurator: 0.6.1
    - jupyter-server: 1.13.1
    - jupyter-server-fileid: 0.6.0
    - jupyter-server-terminals: 0.4.4
    - jupyter-server-ydoc: 0.6.1
    - jupyter-ydoc: 0.2.2
    - jupyterlab: 3.2.5
    - jupyterlab-pygments: 0.1.2
    - jupyterlab-server: 2.10.2
    - jupyterlab-sublime: 0.4.1
    - jupyterlab-widgets: 1.0.2
    - kaleido: 0.2.1
    - keras: 2.11.0
    - kiwisolver: 1.4.4
    - libclang: 15.0.6.1
    - lightgbm: 3.3.5
    - lightning: 2.0.1.post0
    - lightning-cloud: 0.5.32
    - lightning-utilities: 0.8.0
    - lit: 16.0.0
    - locket: 1.0.0
    - logzero: 1.7.0
    - lxml: 4.9.2
    - mako: 1.2.4
    - markdown: 3.4.1
    - markdown-it-py: 2.2.0
    - markupsafe: 2.1.1
    - matplotlib: 3.6.3
    - matplotlib-inline: 0.1.6
    - mdurl: 0.1.2
    - mistune: 0.8.4
    - mpmath: 1.3.0
    - msgpack: 1.0.4
    - multidict: 6.0.4
    - multiprocess: 0.70.14
    - munch: 2.5.0
    - nbclassic: 0.3.4
    - nbclient: 0.5.9
    - nbconvert: 6.4.0
    - nbformat: 5.1.3
    - nest-asyncio: 1.5.6
    - netcdf4: 1.6.2
    - netron: 6.5.5
    - networkx: 3.0
    - notebook: 6.1.5
    - notebook-shim: 0.2.0
    - numcodecs: 0.11.0
    - numpy: 1.23.4
    - nvidia-cublas-cu11: 11.10.3.66
    - nvidia-cuda-cupti-cu11: 11.7.101
    - nvidia-cuda-nvrtc-cu11: 11.7.99
    - nvidia-cuda-runtime-cu11: 11.7.99
    - nvidia-cudnn-cu11: 8.5.0.96
    - nvidia-cufft-cu11: 10.9.0.58
    - nvidia-curand-cu11: 10.2.10.91
    - nvidia-cusolver-cu11: 11.4.0.1
    - nvidia-cusparse-cu11: 11.7.4.91
    - nvidia-nccl-cu11: 2.14.3
    - nvidia-nvtx-cu11: 11.7.91
    - oauthlib: 3.2.2
    - opt-einsum: 3.3.0
    - optuna: 3.1.0
    - optuna-dashboard: 0.9.0
    - ordered-set: 4.1.0
    - orjson: 3.8.8
    - packaging: 21.3
    - pandas: 1.5.1
    - pandocfilters: 1.5.0
    - parso: 0.8.3
    - partd: 1.3.0
    - path-and-address: 2.0.1
    - pathos: 0.3.0
    - pbr: 5.11.1
    - pexpect: 4.8.0
    - pickleshare: 0.7.5
    - pillow: 9.2.0
    - pip: 21.1.1
    - pipreqs: 0.4.11
    - pkgutil-resolve-name: 1.3.10
    - platformdirs: 2.6.2
    - plot-keras-history: 1.1.38
    - plotly: 5.13.1
    - pox: 0.3.2
    - ppft: 1.7.6.6
    - prettytable: 3.6.0
    - prometheus-client: 0.15.0
    - prompt-toolkit: 3.0.31
    - protobuf: 3.19.6
    - psutil: 5.9.3
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.2
    - py: 1.11.0
    - pyarrow: 6.0.1
    - pyasn1: 0.4.8
    - pyasn1-modules: 0.2.8
    - pycparser: 2.21
    - pycrs: 1.0.2
    - pydantic: 1.10.7
    - pygments: 2.13.0
    - pyjwt: 2.6.0
    - pyparsing: 3.0.9
    - pyperclip: 1.8.2
    - pyproj: 3.4.0
    - pyrsistent: 0.18.1
    - pyshp: 2.3.1
    - pysocks: 1.7.1
    - python-box: 6.0.2
    - python-dateutil: 2.8.2
    - python-dotenv: 1.0.0
    - python-editor: 1.0.4
    - python-json-logger: 2.0.4
    - python-multipart: 0.0.6
    - pytorch-lightning: 2.0.0
    - pytz: 2022.5
    - pyyaml: 6.0
    - pyzmq: 24.0.1
    - qtconsole: 5.4.0
    - qtpy: 2.3.0
    - rasterio: 1.3.5
    - ratelim: 0.1.6
    - readchar: 4.0.5
    - requests: 2.28.1
    - requests-oauthlib: 1.3.1
    - requests-toolbelt: 0.10.1
    - restee: 0.0.3
    - retry: 0.9.2
    - retrying: 1.3.3
    - rfc3339-validator: 0.1.4
    - rfc3986: 1.5.0
    - rfc3986-validator: 0.1.1
    - rich: 13.3.3
    - rsa: 4.9
    - ruamel.yaml: 0.17.21
    - ruamel.yaml.clib: 0.2.7
    - sanitize-ml-labels: 1.0.50
    - sankee: 0.2.0
    - scikeras: 0.10.0
    - scikit-learn: 1.1.3
    - scipy: 1.9.3
    - scooby: 0.7.0
    - seaborn: 0.12.2
    - send2trash: 1.8.0
    - setuptools: 56.0.0
    - shapely: 1.8.5.post1
    - six: 1.16.0
    - sklearn-pandas: 2.2.0
    - sniffio: 1.3.0
    - snuggs: 1.4.7
    - sortedcontainers: 2.4.0
    - soupsieve: 2.3.2.post1
    - sqlalchemy: 2.0.2
    - stack-data: 0.5.1
    - starlette: 0.22.0
    - starsessions: 1.3.0
    - stevedore: 5.0.0
    - support-developer: 1.0.5
    - sympy: 1.11.1
    - tb-nightly: 2.13.0a20230412
    - tblib: 1.7.0
    - tenacity: 8.1.0
    - tensorboard: 2.11.2
    - tensorboard-data-server: 0.6.1
    - tensorboard-plugin-wit: 1.8.1
    - tensorflow: 2.11.0
    - tensorflow-estimator: 2.11.0
    - tensorflow-io-gcs-filesystem: 0.28.0
    - termcolor: 2.0.1
    - terminado: 0.17.0
    - testpath: 0.6.0
    - threadpoolctl: 3.1.0
    - tinycss2: 1.2.1
    - tomli: 2.0.1
    - toolz: 0.12.0
    - torch: 2.0.0
    - torch-tb-profiler: 0.4.1
    - torchmetrics: 0.11.4
    - torchsummary: 1.5.1
    - torchview: 0.2.6
    - torchvision: 0.15.1
    - tornado: 6.2
    - tqdm: 4.64.1
    - traitlets: 5.5.0
    - traittypes: 0.2.1
    - triton: 2.0.0
    - typing-extensions: 4.4.0
    - ujson: 5.7.0
    - uri-template: 1.2.0
    - uritemplate: 3.0.1
    - urllib3: 1.26.12
    - uvicorn: 0.21.1
    - uvloop: 0.17.0
    - watchfiles: 0.19.0
    - wcwidth: 0.2.5
    - webcolors: 1.12
    - webencodings: 0.5.1
    - websocket-client: 1.4.1
    - websockets: 10.4
    - werkzeug: 2.2.2
    - wheel: 0.38.4
    - whitebox: 2.2.0
    - whiteboxgui: 2.2.0
    - widgetsnbextension: 3.5.2
    - wrapt: 1.14.1
    - wxee: 0.3.3
    - xarray: 2022.11.0
    - xgboost: 1.7.5
    - xyzservices: 2022.9.0
    - y-py: 0.5.5
    - yarg: 0.1.9
    - yarl: 1.8.2
    - ypy-websocket: 0.8.2
    - zarr: 2.10.3
    - zict: 2.2.0
    - zipp: 3.10.0
* System:
    - OS: Linux
    - architecture:
        - 64bit
        - ELF
    - processor: x86_64
    - python: 3.8.10
    - version: #1 SMP Fri Jan 27 02:56:13 UTC 2023
```

More info
No response
cc @carmocca @awaelchli @borda