Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

`link_arguments` does not work in lightning 2.3 #20147

Open peacekurella opened 1 month ago

peacekurella commented 1 month ago

Bug description

When parser.link_arguments is used to link field a to field b with apply_on="instantiate", field b is not populated when it is accessed later. This was not a problem in lightning 2.2.5, which we currently use. Upgrading to 2.3.x causes field b to remain unpopulated.
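
For readers unfamiliar with linking on instantiation, here is a minimal conceptual sketch in plain Python of the behavior being relied on. It is not jsonargparse's implementation; `DataStub`, `ModelStub` and `instantiate_with_link` are hypothetical names used only for illustration. The idea is that the source object is built first, the linked attribute is read from the instance, and that value is passed to the target's constructor.

```python
# Conceptual sketch only: plain Python, NOT jsonargparse's implementation.
# All names below are hypothetical and exist only for this illustration.


class DataStub:
    """Stand-in for the data module (holds "field a")."""

    def __init__(self, vocab_size: int = 10):
        self.vocab_size = vocab_size


class ModelStub:
    """Stand-in for the model (receives "field b")."""

    def __init__(self, vocab_size: int):
        self.vocab_size = vocab_size


def instantiate_with_link(data_kwargs: dict) -> tuple:
    # Step 1: build the source object first.
    data = DataStub(**data_kwargs)
    # Step 2: read the linked attribute from the instance, not from the config.
    linked_value = data.vocab_size
    # Step 3: pass that value to the target's constructor.
    model = ModelStub(vocab_size=linked_value)
    return data, model


data, model = instantiate_with_link({"vocab_size": 15})
print(data.vocab_size, model.vocab_size)  # 15 15
```

With this expectation, the target should always see whatever value the instantiated source ends up holding, including the source's default when nothing is passed.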

What version are you seeing the problem on?

2.3.3

How to reproduce the bug

https://github.com/Lightning-AI/pytorch-lightning/issues/20147#issuecomment-2266215234

Error messages and logs

-

Environment

Current environment * CUDA: - GPU: None - available: False - version: 12.1 * Lightning: - lightning: 2.2.5 - lightning-utilities: 0.11.2 - pytorch-lightning: 2.3.0 - torch: 2.3.1 - torchmetrics: 1.4.0.post0 * Packages: - aiobotocore: 2.7.0 - aiohttp: 3.9.5 - aioitertools: 0.7.1 - aiosignal: 1.2.0 - alabaster: 0.7.16 - altair: 5.0.1 - anyio: 4.2.0 - appdirs: 1.4.4 - argon2-cffi: 21.3.0 - argon2-cffi-bindings: 21.2.0 - arrow: 1.2.3 - astroid: 2.14.2 - astropy: 6.1.0 - astropy-iers-data: 0.2024.6.3.0.31.14 - asttokens: 2.0.5 - async-lru: 2.0.4 - async-timeout: 4.0.3 - atomicwrites: 1.4.0 - attrs: 23.1.0 - automat: 20.2.0 - autopep8: 2.0.4 - babel: 2.11.0 - bcrypt: 3.2.0 - beautifulsoup4: 4.12.3 - binaryornot: 0.4.4 - black: 24.4.2 - bleach: 4.1.0 - blinker: 1.6.2 - bokeh: 3.4.1 - boto3: 1.34.131 - botocore: 1.34.131 - bottleneck: 1.3.7 - brotli: 1.0.9 - cachetools: 5.3.3 - cattrs: 23.2.3 - certifi: 2024.6.2 - cffi: 1.16.0 - chardet: 4.0.0 - charset-normalizer: 3.3.2 - click: 8.1.7 - cloudpickle: 2.2.1 - colorama: 0.4.6 - colorcet: 3.1.0 - comm: 0.2.1 - constantly: 23.10.4 - contourpy: 1.2.0 - cookiecutter: 2.6.0 - cryptography: 42.0.5 - cssselect: 1.2.0 - cycler: 0.11.0 - cytoolz: 0.12.2 - dask: 2024.5.0 - dask-expr: 1.1.0 - datasets: 2.14.6 - datashader: 0.16.2 - debugpy: 1.6.7 - decorator: 5.1.1 - defusedxml: 0.7.1 - diff-match-patch: 20200713 - dill: 0.3.7 - distributed: 2024.5.0 - docker: 7.1.0 - docstring-parser: 0.16 - docstring-to-markdown: 0.11 - docutils: 0.18.1 - entrypoints: 0.4 - et-xmlfile: 1.1.0 - exceptiongroup: 1.2.0 - executing: 0.8.3 - fastjsonschema: 2.16.2 - filelock: 3.13.1 - flake8: 7.0.0 - flask: 3.0.3 - fonttools: 4.51.0 - frozenlist: 1.4.0 - fsspec: 2023.10.0 - gensim: 4.3.2 - gitdb: 4.0.7 - gitpython: 3.1.37 - gmpy2: 2.1.2 - google-pasta: 0.2.0 - greenlet: 3.0.1 - h5py: 3.11.0 - heapdict: 1.0.1 - holoviews: 1.19.0 - huggingface-hub: 0.23.4 - hvplot: 0.10.0 - hyperlink: 21.0.0 - idna: 3.7 - imagecodecs: 2023.1.23 - imageio: 2.33.1 - imagesize: 
1.4.1 - imbalanced-learn: 0.12.3 - importlib-metadata: 6.11.0 - importlib-resources: 6.4.0 - incremental: 22.10.0 - inflection: 0.5.1 - iniconfig: 1.1.1 - intake: 0.7.0 - intervaltree: 3.1.0 - ipykernel: 6.28.0 - ipython: 8.25.0 - ipython-genutils: 0.2.0 - ipywidgets: 7.6.5 - isort: 5.13.2 - itemadapter: 0.3.0 - itemloaders: 1.1.0 - itsdangerous: 2.2.0 - jaraco.classes: 3.2.1 - jedi: 0.18.1 - jeepney: 0.7.1 - jellyfish: 1.0.1 - jinja2: 3.1.4 - jmespath: 1.0.1 - joblib: 1.4.2 - json5: 0.9.6 - jsonargparse: 4.30.0 - jsonschema: 4.19.2 - jsonschema-specifications: 2023.7.1 - jupyter: 1.0.0 - jupyter-client: 8.6.0 - jupyter-console: 6.6.3 - jupyter-core: 5.5.0 - jupyter-events: 0.10.0 - jupyter-lsp: 2.2.0 - jupyter-server: 2.10.0 - jupyter-server-terminals: 0.4.4 - jupyterlab: 4.0.11 - jupyterlab-pygments: 0.1.2 - jupyterlab-server: 2.25.1 - jupyterlab-widgets: 3.0.10 - keyring: 24.3.1 - kiwisolver: 1.4.4 - klon: 2.3.0 - lazy-loader: 0.4 - lazy-object-proxy: 1.10.0 - lckr-jupyterlab-variableinspector: 3.1.0 - lightning: 2.2.5 - lightning-utilities: 0.11.2 - linkify-it-py: 2.0.0 - llvmlite: 0.42.0 - lmdb: 1.4.1 - locket: 1.0.0 - lsprotocol: 2023.0.1 - lxml: 4.9.4 - lxml-stubs: 0.1.1 - lz4: 4.3.2 - markdown: 3.4.1 - markdown-it-py: 2.2.0 - markupsafe: 2.1.3 - matplotlib: 3.8.4 - matplotlib-inline: 0.1.6 - mccabe: 0.7.0 - mdit-py-plugins: 0.3.0 - mdurl: 0.1.0 - mistune: 2.0.4 - mkl-fft: 1.3.8 - mkl-random: 1.2.4 - mkl-service: 2.4.0 - more-itertools: 10.1.0 - mpmath: 1.3.0 - msgpack: 1.0.3 - multidict: 6.0.4 - multipledispatch: 0.6.0 - multiprocess: 0.70.15 - mypy: 1.10.0 - mypy-extensions: 1.0.0 - nbclient: 0.8.0 - nbconvert: 7.10.0 - nbformat: 5.9.2 - nest-asyncio: 1.6.0 - networkx: 3.2.1 - nltk: 3.8.1 - notebook: 7.0.8 - notebook-shim: 0.2.3 - numba: 0.59.1 - numexpr: 2.8.7 - numpy: 1.26.4 - numpydoc: 1.7.0 - nvidia-cublas-cu12: 12.1.3.1 - nvidia-cuda-cupti-cu12: 12.1.105 - nvidia-cuda-nvrtc-cu12: 12.1.105 - nvidia-cuda-runtime-cu12: 12.1.105 - nvidia-cudnn-cu12: 
8.9.2.26 - nvidia-cufft-cu12: 11.0.2.54 - nvidia-curand-cu12: 10.3.2.106 - nvidia-cusolver-cu12: 11.4.5.107 - nvidia-cusparse-cu12: 12.1.0.106 - nvidia-nccl-cu12: 2.20.5 - nvidia-nvjitlink-cu12: 12.5.40 - nvidia-nvtx-cu12: 12.1.105 - openpyxl: 3.1.2 - overrides: 7.4.0 - packaging: 23.2 - pandas: 2.2.2 - pandocfilters: 1.5.0 - panel: 1.4.4 - param: 2.1.0 - parsel: 1.8.1 - parso: 0.8.3 - partd: 1.4.1 - pathos: 0.3.1 - pathspec: 0.10.3 - patsy: 0.5.6 - pexpect: 4.8.0 - pickleshare: 0.7.5 - pillow: 10.3.0 - pip: 24.0 - platformdirs: 3.10.0 - plotly: 5.22.0 - pluggy: 1.5.0 - ply: 3.11 - pox: 0.3.4 - ppft: 1.7.6.8 - prometheus-client: 0.14.1 - prompt-toolkit: 3.0.43 - protego: 0.1.16 - protobuf: 3.20.3 - psutil: 5.9.0 - ptyprocess: 0.7.0 - pure-eval: 0.2.2 - py-cpuinfo: 9.0.0 - pyarrow: 14.0.2 - pyasn1: 0.4.8 - pyasn1-modules: 0.2.8 - pycodestyle: 2.11.1 - pycparser: 2.21 - pyct: 0.5.0 - pycurl: 7.45.2 - pydeck: 0.8.0 - pydispatcher: 2.0.5 - pydocstyle: 6.3.0 - pyerfa: 2.0.1.4 - pyflakes: 3.2.0 - pygls: 1.3.1 - pygments: 2.15.1 - pylint: 2.16.2 - pylint-venv: 3.0.3 - pyls-spyder: 0.4.0 - pyodbc: 5.0.1 - pyopenssl: 24.0.0 - pyparsing: 3.0.9 - pyproj: 3.6.1 - pyqt5: 5.15.10 - pyqt5-sip: 12.13.0 - pyqtwebengine: 5.15.6 - pysocks: 1.7.1 - pytest: 8.2.2 - python-dateutil: 2.9.0.post0 - python-json-logger: 2.0.7 - python-lsp-black: 2.0.0 - python-lsp-jsonrpc: 1.1.2 - python-lsp-server: 1.10.0 - python-slugify: 5.0.2 - python-snappy: 0.6.1 - pytoolconfig: 1.2.6 - pytorch-lightning: 2.3.0 - pytz: 2024.1 - pyviz-comms: 3.0.2 - pywavelets: 1.5.0 - pyxdg: 0.27 - pyyaml: 6.0.1 - pyzmq: 25.1.2 - qdarkstyle: 3.2.3 - qstylizer: 0.2.2 - qtawesome: 1.2.2 - qtconsole: 5.5.1 - qtpy: 2.4.1 - queuelib: 1.6.2 - referencing: 0.30.2 - regex: 2023.10.3 - requests: 2.32.3 - requests-file: 1.5.1 - rfc3339-validator: 0.1.4 - rfc3986-validator: 0.1.1 - rich: 13.3.5 - rope: 1.12.0 - rpds-py: 0.10.6 - rtree: 1.0.1 - ruff: 0.4.9 - ruff-lsp: 0.0.53 - s3fs: 2023.10.0 - s3transfer: 0.10.1 - sagemaker: 
2.224.0 - schema: 0.7.7 - scikit-image: 0.23.2 - scikit-learn: 1.4.2 - scipy: 1.11.4 - scrapy: 2.11.1 - seaborn: 0.13.2 - secretstorage: 3.3.1 - send2trash: 1.8.2 - service-identity: 18.1.0 - setuptools: 69.5.1 - sip: 6.7.12 - six: 1.16.0 - smart-open: 5.2.1 - smdebug-rulesconfig: 1.0.1 - smmap: 4.0.0 - sniffio: 1.3.0 - snowballstemmer: 2.2.0 - sortedcontainers: 2.4.0 - soupsieve: 2.5 - sphinx: 7.3.7 - sphinxcontrib-applehelp: 1.0.2 - sphinxcontrib-devhelp: 1.0.2 - sphinxcontrib-htmlhelp: 2.0.0 - sphinxcontrib-jsmath: 1.0.1 - sphinxcontrib-qthelp: 1.0.3 - sphinxcontrib-serializinghtml: 1.1.10 - spyder: 5.5.1 - spyder-kernels: 2.5.0 - sqlalchemy: 2.0.30 - stack-data: 0.2.0 - statsmodels: 0.14.2 - streamlit: 1.32.0 - sympy: 1.12 - tables: 3.9.2 - tabulate: 0.9.0 - tblib: 1.7.0 - tenacity: 8.2.2 - tensorboardx: 2.6.2.2 - terminado: 0.17.1 - text-unidecode: 1.3 - textdistance: 4.2.1 - threadpoolctl: 2.2.0 - three-merge: 0.1.1 - tifffile: 2023.4.12 - tinycss2: 1.2.1 - tldextract: 3.2.0 - toml: 0.10.2 - tomli: 2.0.1 - tomlkit: 0.11.1 - toolz: 0.12.0 - torch: 2.3.1 - torchmetrics: 1.4.0.post0 - tornado: 6.4.1 - tqdm: 4.66.4 - traitlets: 5.14.3 - triton: 2.3.1 - twisted: 23.10.0 - typeshed-client: 2.5.1 - typing-extensions: 4.11.0 - tzdata: 2023.3 - uc-micro-py: 1.0.1 - ujson: 5.10.0 - unicodedata2: 15.1.0 - unidecode: 1.2.0 - urllib3: 2.0.7 - w3lib: 2.1.2 - watchdog: 4.0.1 - wcwidth: 0.2.5 - webencodings: 0.5.1 - websocket-client: 1.8.0 - werkzeug: 3.0.3 - whatthepatch: 1.0.2 - wheel: 0.43.0 - widgetsnbextension: 3.5.2 - wrapt: 1.14.1 - wurlitzer: 3.0.2 - xarray: 2023.6.0 - xxhash: 3.4.1 - xyzservices: 2022.9.0 - yapf: 0.40.2 - yarl: 1.9.3 - zict: 3.0.0 - zipp: 3.17.0 - zope.interface: 5.4.0 * System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.10.13 - release: 5.10.220-188.869.amzn2int.x86_64 - version: #1 SMP Wed Jul 17 14:39:49 UTC 2024

More info

No response

cc @carmocca @mauvilsa

peacekurella commented 1 month ago

I noticed that the drop down menu does not contain 2.3.x as part of the version selection.

awaelchli commented 1 month ago

Hey @peacekurella can you please provide a code example based on https://github.com/Lightning-AI/pytorch-lightning/blob/master/examples/pytorch/bug_report/bug_report_model.py so we can verify it's not working?

peacekurella commented 1 month ago

I can do that.

peacekurella commented 1 month ago

import torch
from typing import Type, TypeVar
from lightning.pytorch import LightningModule
from torch.utils.data import DataLoader, Dataset
from lightning.pytorch.cli import LightningCLI
from lightning import LightningDataModule
from lightning.pytorch.callbacks import ModelCheckpoint

class MyLightningCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):

        parser.add_argument("data.destinationaddressid_vocab_size", default=10)
        parser.add_argument("model.destinationaddressid_vocab_size")
        parser.add_argument("--ckpt_path_ex", type=str, default=None)

        parser.link_arguments(
            "data.destinationaddressid_vocab_size",
            "model.destinationaddressid_vocab_size",
            apply_on="instantiate",
        )

    def before_instantiate_classes(self) -> None:
        if self.config.ckpt_path_ex:
            print("restoring from checkpoint")
            # we are restoring from a checkpoint
            CheckpointModuleInstantiatiorCLI.before_instantiate_classes(self)

class MyDataModule(LightningDataModule):
    def __init__(self, destinationaddressid_vocab_size: int = None):
        super().__init__()
        self.destinationaddressid_vocab_size = destinationaddressid_vocab_size
        print(f"The value of destinationaddressid_vocab_size in data module is {destinationaddressid_vocab_size}")

    def train_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def val_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def predict_dataloader(self):
        return DataLoader(RandomDataset(32, 64), batch_size=2)

class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

ModuleType = TypeVar("ModuleType")

class CheckpointModuleInstantiatiorCLI:
    def __init__(self, cli: LightningCLI):
        self.cli = cli

    def class_instantiator(self, class_type: Type[ModuleType], *args, **kwargs) -> ModuleType:
        if args:
            raise ValueError("Unexpected args")

        map_location = None if torch.cuda.is_available() else "cpu"
        defaults = self.cli.parser.get_defaults()
        if class_type == BoringModel:
            non_default_kwargs = {k: v for k, v in kwargs.items() if defaults.model.get(k) != v}
            return BoringModel.load_from_checkpoint(
                self.cli.config.ckpt_path_ex,
                map_location=map_location,
                **non_default_kwargs,
            )
        elif class_type == MyDataModule:
            non_default_kwargs = {k: v for k, v in kwargs.items() if defaults.data.get(k) != v}
            return MyDataModule.load_from_checkpoint(
                self.cli.config.ckpt_path_ex,
                map_location=map_location,
                **non_default_kwargs,
            )
        else:
            raise ValueError("Unexpected class")

    @staticmethod
    def before_instantiate_classes(cli: LightningCLI) -> None:
        instantiator = CheckpointModuleInstantiatiorCLI(cli)
        cli.parser.add_instantiator(instantiator.class_instantiator, BoringModel)
        cli.parser.add_instantiator(instantiator.class_instantiator, MyDataModule)

class BoringModel(LightningModule):
    def __init__(self, destinationaddressid_vocab_size):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.destinationaddressid_vocab_size = destinationaddressid_vocab_size
        self.save_hyperparameters()
        print(f"The value of destinationaddressid_vocab_size in model module is {self.destinationaddressid_vocab_size}")

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

def run(args):

    cli = MyLightningCLI(
        BoringModel,
        MyDataModule,
        args=args,
        trainer_defaults={"callbacks": [ModelCheckpoint(dirpath="ckpts")]},
        run=False,
    )

    cli.trainer.fit(
        model=cli.model,
        datamodule=cli.datamodule,
        ckpt_path=cli.config.ckpt_path_ex if cli.config.ckpt_path_ex else None,
    )

if __name__ == "__main__":
    run(args=None)

Running with lightning 2.2.5

  1. Generate checkpoints: `python bug_report.py --data.destinationaddressid_vocab_size 15 --trainer.max_epoch=1`. This prints:

    The value of destinationaddressid_vocab_size in data module is 15
    The value of destinationaddressid_vocab_size in model module is 15
  2. Load the model from checkpoints: `python bug_report.py --trainer.max_epoch=2 --ckpt_path_ex ckpts/epoch=0-step=32.ckpt`. This prints:

    The value of destinationaddressid_vocab_size in data module is None
    The value of destinationaddressid_vocab_size in model module is 15

Running with lightning 2.3.3

  1. Generate checkpoints: `python bug_report.py --data.destinationaddressid_vocab_size 15 --trainer.max_epoch=1`. This prints:

    The value of destinationaddressid_vocab_size in data module is 15
    The value of destinationaddressid_vocab_size in model module is 15
  2. Load the model from checkpoints: `python bug_report.py --trainer.max_epoch=2 --ckpt_path_ex ckpts/epoch=0-step=32.ckpt`. This prints:

    The value of destinationaddressid_vocab_size in data module is 10
    The value of destinationaddressid_vocab_size in model module is 10

peacekurella commented 1 month ago

@awaelchli added the repro code and scenarios with outputs.

awaelchli commented 1 month ago

Ok thanks @peacekurella. But the default value is 10, and in the second command you don't pass --data.destinationaddressid_vocab_size 15. When you resume training, you certainly would need to pass the same configuration. We can't expect that the output is 15 in the second example, if data.destinationaddressid_vocab_size is not passed.

peacekurella commented 1 month ago

@awaelchli The way I understand it, save_hyperparameters() is not storing the values for parameters that were populated through a link. This was not the case in lightning 2.2.5. This is a problem when restoring from ckpt files for inference, because we typically try to get all the required hyperparameters for inference from the ckpt file itself.
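
To make the restore path in the repro easier to reason about, here is a small sketch using plain dicts. `resolve_init_kwargs` is a hypothetical helper, not a Lightning or jsonargparse function. It mirrors the `non_default_kwargs` filter from the repro and assumes only the documented rule that keyword arguments passed to `load_from_checkpoint` override the saved hyperparameters. It illustrates the override rule and does not establish which value the parser actually passes in either Lightning version.

```python
# Sketch of the restore logic from the repro, using plain dicts.
# `resolve_init_kwargs` is a hypothetical helper for illustration only.


def resolve_init_kwargs(saved_hparams: dict, parser_defaults: dict, cli_kwargs: dict) -> dict:
    # Keep only values that differ from the parser defaults
    # (same idea as `non_default_kwargs` in the repro).
    overrides = {k: v for k, v in cli_kwargs.items() if parser_defaults.get(k) != v}
    # Explicitly passed kwargs take precedence over checkpoint hyperparameters.
    return {**saved_hparams, **overrides}


saved = {"destinationaddressid_vocab_size": 15}  # stored by save_hyperparameters()
defaults = {"destinationaddressid_vocab_size": None}  # assumed: model argument has no default

# Case A: the incoming value equals the default, so the checkpoint value is used.
case_a = resolve_init_kwargs(saved, defaults, {"destinationaddressid_vocab_size": None})
print(case_a)  # {'destinationaddressid_vocab_size': 15}

# Case B: the incoming value differs from the default, so it overrides the checkpoint.
case_b = resolve_init_kwargs(saved, defaults, {"destinationaddressid_vocab_size": 10})
print(case_b)  # {'destinationaddressid_vocab_size': 10}
```

Case A matches the 2.2.5 output above (model prints 15) and case B matches the 2.3.3 output (model prints 10), which suggests checking what value reaches the custom instantiator in each version.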