
OptimizerLRScheduler typing does not fit examples #20106

Open · MalteEbner opened 3 months ago

MalteEbner commented 3 months ago

Bug description

The return type of LightningModule.configure_optimizers() is OptimizerLRScheduler (see the source code). However, the examples return values that do not fit this type. See e.g. the example here.

Furthermore, OptimizerLRScheduler is only used as a return type; I don't see where it is actually consumed, i.e. the other side of the typed interface. A search through the codebase does not reveal it.
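
For reference, the alias appears to expand to roughly the following. This is a paraphrase inferred from the mypy output quoted below, not the authoritative definition, which lives in lightning.pytorch.utilities.types:

```python
from typing import Optional, Sequence, Tuple, Union

from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler, ReduceLROnPlateau

from lightning.pytorch.utilities.types import (
    LRSchedulerConfig,
    OptimizerLRSchedulerConfig,
)

# Paraphrased shape of the alias, per the mypy error below. Note that a bare
# tuple of (optimizer, dict) is not among the accepted forms.
OptimizerLRScheduler = Optional[
    Union[
        Optimizer,
        Sequence[Optimizer],
        Tuple[
            Sequence[Optimizer],
            Sequence[Union[LRScheduler, ReduceLROnPlateau, LRSchedulerConfig]],
        ],
        OptimizerLRSchedulerConfig,
        Sequence[OptimizerLRSchedulerConfig],
    ]
]
```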

What version are you seeing the problem on?

2.3.1

How to reproduce the bug

Just run mypy on an example, e.g. https://github.com/Lightning-AI/pytorch-lightning/blob/e6c26d2d22fc4678b2cf47d57697ebae68b09529/examples/fabric/build_your_own_trainer/run.py#L42-L49.
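
The cited block looks roughly like this. This is a reconstruction inferred from the mypy output below, not a verbatim copy of run.py, and the optimizer hyperparameters are illustrative; the tuple-of-optimizer-and-dict return is what triggers the error:

```python
import torch
from torch import nn

from lightning.pytorch import LightningModule
from lightning.pytorch.utilities.types import OptimizerLRScheduler


class LitModel(LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(4, 1)

    def configure_optimizers(self) -> OptimizerLRScheduler:
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        # mypy flags this: tuple[Adam, dict[str, object]] is not a member
        # of OptimizerLRScheduler.
        return optimizer, {"scheduler": scheduler, "interval": "epoch"}
```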

Error messages and logs

Running mypy on this example produces the following error:

```
Error: Incompatible return value type (got "tuple[Adam, dict[str, object]]", expected "Optimizer | Sequence[Optimizer] | tuple[Sequence[Optimizer], Sequence[LRScheduler | ReduceLROnPlateau | LRSchedulerConfig]] | OptimizerLRSchedulerConfig | Sequence[OptimizerLRSchedulerConfig] | None")
```

Environment

Current environment * CUDA: - GPU: None - available: False - version: None * Lightning: - lightning-utilities: 0.11.3.post0 - pytorch-lightning: 2.3.1 - torch: 2.3.1 - torchmetrics: 0.8.0 - torchvision: 0.18.1 * Packages: - absl-py: 2.1.0 - aenum: 3.1.15 - aiohttp: 3.9.5 - aiosignal: 1.3.1 - alabaster: 0.7.16 - albumentations: 1.3.1 - antlr4-python3-runtime: 4.9.3 - arabic-reshaper: 3.0.0 - asn1crypto: 1.5.1 - async-timeout: 4.0.3 - attrs: 23.2.0 - babel: 2.15.0 - boto3: 1.34.137 - botocore: 1.34.137 - build: 1.2.1 - certifi: 2024.6.2 - cffi: 1.16.0 - charset-normalizer: 3.3.2 - click: 8.1.7 - coloredlogs: 15.0.1 - contourpy: 1.2.1 - coverage: 5.3.1 - cryptography: 42.0.8 - cssselect2: 0.7.0 - cycler: 0.12.1 - data-gradients: 0.3.2 - deprecated: 1.2.14 - docutils: 0.17.1 - einops: 0.3.2 - exceptiongroup: 1.2.1 - filelock: 3.15.4 - flatbuffers: 24.3.25 - fonttools: 4.53.0 - frozenlist: 1.4.1 - fsspec: 2024.6.1 - future: 1.0.0 - grpcio: 1.64.1 - html5lib: 1.1 - huggingface-hub: 0.23.4 - humanfriendly: 10.0 - hydra-core: 1.3.2 - idna: 3.7 - imagededup: 0.3.1 - imageio: 2.34.2 - imagesize: 1.4.1 - iniconfig: 2.0.0 - jinja2: 3.1.4 - jmespath: 1.0.1 - joblib: 1.4.2 - json-tricks: 3.16.1 - jsonschema: 4.22.0 - jsonschema-specifications: 2023.12.1 - kiwisolver: 1.4.5 - lazy-loader: 0.4 - lightly: 1.5.8 - lightly-train: 0.1.0 - lightly-utils: 0.0.2 - lightning-utilities: 0.11.3.post0 - lxml: 5.2.2 - markdown: 3.6 - markdown-it-py: 3.0.0 - markupsafe: 2.1.5 - matplotlib: 3.9.0 - mdurl: 0.1.2 - mpmath: 1.3.0 - multidict: 6.0.5 - mypy: 1.10.1 - mypy-extensions: 1.0.0 - networkx: 3.3 - numpy: 1.23.0 - omegaconf: 2.3.0 - onnx: 1.15.0 - onnxruntime: 1.15.0 - onnxsim: 0.4.36 - opencv-python: 4.10.0.84 - opencv-python-headless: 4.10.0.84 - oscrypto: 1.3.0 - packaging: 24.1 - pandas: 2.2.2 - pillow: 10.4.0 - pip: 24.1.1 - pip-tools: 7.4.1 - platformdirs: 4.2.2 - pluggy: 1.5.0 - protobuf: 3.20.3 - psutil: 6.0.0 - pycparser: 2.22 - pydantic: 1.10.17 - pydeprecate: 0.3.2 - pygments: 2.18.0 - pyhanko: 0.25.0 - pyhanko-certvalidator: 0.26.3 - pyparsing: 3.1.2 - pypdf: 4.2.0 - pypng: 0.20220715.0 - pyproject-hooks: 1.1.0 - pytest: 8.2.2 - pytest-mock: 3.14.0 - python-bidi: 0.4.2 - python-dateutil: 2.9.0.post0 - pytorch-lightning: 2.3.1 - pytz: 2024.1 - pywavelets: 1.6.0 - pyyaml: 6.0.1 - qrcode: 7.4.2 - qudida: 0.0.4 - rapidfuzz: 3.9.3 - referencing: 0.35.1 - reportlab: 3.6.13 - requests: 2.32.3 - rich: 13.7.1 - rpds-py: 0.18.1 - ruff: 0.5.0 - s3transfer: 0.10.2 - safetensors: 0.4.3 - scikit-image: 0.24.0 - scikit-learn: 1.5.0 - scipy: 1.13.1 - seaborn: 0.13.2 - selftrain: 0.1.0 - setuptools: 70.2.0 - six: 1.16.0 - snowballstemmer: 2.2.0 - sphinx: 4.0.3 - sphinx-rtd-theme: 1.3.0 - sphinxcontrib-applehelp: 1.0.8 - sphinxcontrib-devhelp: 1.0.6 - sphinxcontrib-htmlhelp: 2.0.5 - sphinxcontrib-jquery: 4.1 - sphinxcontrib-jsmath: 1.0.1 - sphinxcontrib-qthelp: 1.0.7 - sphinxcontrib-serializinghtml: 1.1.10 - stringcase: 1.2.0 - super-gradients: 3.7.1 - svglib: 1.5.1 - sympy: 1.12.1 - tensorboard: 2.17.0 - tensorboard-data-server: 0.7.2 - termcolor: 1.1.0 - threadpoolctl: 3.5.0 - tifffile: 2024.6.18 - timm: 1.0.7 - tinycss2: 1.3.0 - tomli: 2.0.1 - torch: 2.3.1 - torchmetrics: 0.8.0 - torchvision: 0.18.1 - tqdm: 4.66.4 - treelib: 1.6.1 - typing-extensions: 4.12.2 - tzdata: 2024.1 - tzlocal: 5.2 - uritools: 4.0.3 - urllib3: 2.2.2 - webencodings: 0.5.1 - werkzeug: 3.0.3 - wheel: 0.43.0 - wrapt: 1.16.0 - xhtml2pdf: 0.2.11 - yarl: 1.9.4 * System: - OS: Darwin - architecture: - 64bit - - processor: arm - python: 3.10.8 - 
release: 23.5.0 - version: Darwin Kernel Version 23.5.0: Wed May 1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000

More info

No response
awaelchli commented 3 months ago

Hey @MalteEbner The optimizer should be inside the dict in that example. Thank you for noticing it. Would you like to send a PR with this quick fix? Would be much appreciated 😃
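
A sketch of that quick fix, assuming the reconstruction quoted above (values remain illustrative): move the optimizer inside an OptimizerLRSchedulerConfig-style dict so the return value matches the annotation.

```python
import torch

from lightning.pytorch.utilities.types import OptimizerLRScheduler


def configure_optimizers(self) -> OptimizerLRScheduler:
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
    # Both the optimizer and the scheduler config live inside one dict,
    # which is one of the forms OptimizerLRScheduler accepts.
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "interval": "epoch"},
    }
```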

MalteEbner commented 3 months ago

> Hey @MalteEbner The optimizer should be inside the dict in that example. Thank you for noticing it. Would you like to send a PR with this quick fix? Would be much appreciated 😃

Unfortunately, the problem is not just that the example is wrong; that is only a symptom of an underlying problem. The real problem is that the return type of LightningModule.configure_optimizers(), namely OptimizerLRScheduler, and the type at its usage site don't match.

To see why, have a look at where configure_optimizers() is used (see the source code here): its output is stored in the variable optim_conf and then passed to the function

```python
def _configure_optimizers(
    optim_conf: Union[Dict[str, Any], List, Optimizer, Tuple],
)
```

However, the definition type OptimizerLRScheduler and the usage type Union[Dict[str, Any], List, Optimizer, Tuple] don't align. Fixing this requires more than correcting the example:

  1. Change the parameter to optim_conf: OptimizerLRScheduler, so that the usage of configure_optimizers() has the same type as its definition (see the sketch after this list).
  2. Redefine OptimizerLRScheduler so that it matches the types actually supported by _configure_optimizers.
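
A hypothetical sketch of point 1 (not the current Lightning code; the body is elided):

```python
from lightning.pytorch.utilities.types import OptimizerLRScheduler


def _configure_optimizers(optim_conf: OptimizerLRScheduler):
    # Same normalization logic as today, but the parameter now carries the
    # same type that LightningModule.configure_optimizers() declares.
    ...
```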
awaelchli commented 3 months ago

I'm not seeing any mypy errors regarding this. Which version did you use, and what command did you run? You can see the version we test with here: https://github.com/Lightning-AI/pytorch-lightning/blob/master/requirements/typing.txt

We typically bump it when a new torch version comes out (which will happen again soon). Maybe your issue will show up then, but I'm not seeing it locally.

Yes sure, the internal _configure_optimizers uses somewhat more generic typing. Feel free to update it to be more specific 👍. The return type of LightningModule.configure_optimizers() should not be changed; that still looks all good to me.