Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

How to unit test LightningCLI without UserWarning? #18545

Open adamjstewart opened 1 year ago

adamjstewart commented 1 year ago

Bug description

I have a main.py file like:

from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import DemoModel, BoringDataModule

def cli_main(args = None):
    cli = LightningCLI(DemoModel, BoringDataModule, args=args)

if __name__ == "__main__":
    cli_main()

This is the same as the tutorial script but adds an args parameter to support non-interactive use. I want to unit test this file, so I create a test.py file like:

from main import cli_main

def test_cli_main():
    args = ["fit", "--trainer.fast_dev_run", "true"]
    cli_main(args)

However, every time I test my code, I see a ton of warnings:

$ pytest test.py
...
test.py::test_cli_main
  /Users/Adam/spack/var/spack/environments/system/.spack-env/view/lib/python3.10/site-packages/lightning/pytorch/cli.py:484: UserWarning: LightningCLI's args parameter is intended to run from within Python like if it were from the command line. To prevent mistakes it is not recommended to provide both args and command line arguments, got: sys.argv[1:]=['test.py'], args=['fit', '--trainer.fast_dev_run', 'true'].
    rank_zero_warn(

test.py::test_cli_main
  /Users/Adam/spack/var/spack/environments/system/.spack-env/view/lib/python3.10/site-packages/lightning/fabric/utilities/seed.py:39: UserWarning: No seed found, seed set to 3925998095
    rank_zero_warn(f"No seed found, seed set to {seed}")

test.py::test_cli_main
  /Users/Adam/spack/var/spack/environments/system/.spack-env/view/lib/python3.10/site-packages/lightning/pytorch/trainer/configuration_validator.py:68: UserWarning: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.
    rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")

test.py::test_cli_main
  /Users/Adam/spack/var/spack/environments/system/.spack-env/view/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:432: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 10 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
    rank_zero_warn(

test.py::test_cli_main
  /Users/Adam/spack/var/spack/environments/system/.spack-env/view/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:280: PossibleUserWarning: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
    rank_zero_warn(

I can ignore these warnings, but I'm curious if they are all necessary. For example, the first warning will pretty much always be raised, won't it? Is it even possible to call cli_main in code without any sys.argv[1:]? The rest also seem worth ignoring. Would it be worth adding a parameter to Trainer that ignores all internal (created by Lightning) warnings? Or can we at least turn some of these off?

What version are you seeing the problem on?

v2.0

How to reproduce the bug

No response

Error messages and logs

No response

Environment

Current environment * CUDA: - GPU: None - available: False - version: None * Lightning: - efficientnet-pytorch: 0.7.1 - lightning: 2.0.3 - lightning-cloud: 0.5.36 - lightning-utilities: 0.8.0 - pytorch-lightning: 2.0.0 - pytorch-sphinx-theme: 0.0.24 - segmentation-models-pytorch: 0.3.3 - torch: 2.0.1 - torchmetrics: 0.11.4 - torchvision: 0.15.2 * Packages: - absl-py: 1.4.0 - aenum: 3.1.12 - affine: 2.1.0 - aiohttp: 3.8.1 - aiosignal: 1.2.0 - alabaster: 0.7.13 - antlr4-python3-runtime: 4.9.3 - anyio: 3.6.2 - appdirs: 1.4.4 - appnope: 0.1.3 - argon2-cffi: 21.3.0 - argon2-cffi-bindings: 21.2.0 - arrow: 1.2.3 - asttokens: 2.2.1 - astunparse: 1.6.3 - async-lru: 1.0.3 - async-timeout: 4.0.2 - attrs: 23.1.0 - babel: 2.12.1 - backcall: 0.2.0 - beautifulsoup4: 4.12.2 - beniget: 0.4.1 - black: 23.3.0 - bleach: 6.0.0 - blessed: 1.19.0 - bottleneck: 1.3.7 - build: 0.10.0 - cachetools: 5.2.0 - calver: 2022.6.26 - cartopy: 0.21.1 - certifi: 2023.5.7 - cffi: 1.15.1 - cftime: 1.0.3.4 - charset-normalizer: 2.0.12 - click: 8.1.3 - click-plugins: 1.1.1 - cligj: 0.7.2 - cmocean: 2.0 - colorama: 0.4.5 - comm: 0.1.3 - commonmark: 0.9.1 - contourpy: 1.0.7 - coverage: 7.2.6 - cppy: 1.2.1 - croniter: 1.3.8 - cycler: 0.11.0 - cython: 0.29.33 - dateutils: 0.6.12 - debugpy: 1.6.7 - decorator: 5.1.1 - deepdiff: 6.3.0 - defusedxml: 0.7.1 - docstring-parser: 0.15 - docutils: 0.19 - editables: 0.3 - efficientnet-pytorch: 0.7.1 - einops: 0.6.1 - et-xmlfile: 1.0.1 - exceptiongroup: 1.1.1 - executing: 1.2.0 - fastapi: 0.88.0 - fastjsonschema: 2.16.3 - filelock: 3.12.0 - fiona: 1.9.4 - flake8: 6.0.0 - flit-core: 3.7.1 - fonttools: 4.39.4 - fqdn: 1.5.1 - frozenlist: 1.3.1 - fsspec: 2023.1.0 - gast: 0.5.3 - gdal: 3.7.0 - geocube: 0.3.2 - geopandas: 0.11.1 - gevent: 1.5.0 - google-auth: 1.6.3 - google-auth-oauthlib: 0.5.2 - greenlet: 2.0.2 - grpcio: 1.52.0 - h11: 0.13.0 - h5py: 3.8.0 - hatch-fancy-pypi-readme: 23.1.0 - hatch-jupyter-builder: 0.8.3 - hatch-vcs: 0.3.0 - hatchling: 1.17.0 - 
huggingface-hub: 0.14.1 - hydra-core: 1.3.1 - idna: 3.4 - imagesize: 1.4.1 - importlib-metadata: 6.6.0 - importlib-resources: 5.9.0 - iniconfig: 2.0.0 - inquirer: 3.1.3 - installer: 0.6.0 - ipykernel: 6.23.1 - ipython: 8.14.0 - ipywidgets: 8.0.2 - isoduration: 20.11.0 - isort: 5.10.1 - itsdangerous: 2.1.2 - jaraco.classes: 3.2.3 - jdcal: 1.3 - jedi: 0.18.1 - jinja2: 3.0.3 - joblib: 1.2.0 - json5: 0.9.14 - jsonargparse: 4.19.0 - jsonpointer: 2.0 - jsonschema: 4.17.3 - jupyter-client: 8.2.0 - jupyter-core: 5.3.0 - jupyter-events: 0.6.3 - jupyter-lsp: 2.2.0 - jupyter-server: 2.6.0 - jupyter-server-terminals: 0.4.4 - jupyterlab: 4.0.1 - jupyterlab-pygments: 0.2.2 - jupyterlab-server: 2.22.1 - jupyterlab-widgets: 3.0.3 - keyring: 23.13.1 - kiwisolver: 1.4.4 - kornia: 0.6.12 - laspy: 2.2.0 - lightly: 1.4.18 - lightly-utils: 0.0.2 - lightning: 2.0.3 - lightning-cloud: 0.5.36 - lightning-utilities: 0.8.0 - lxml: 4.9.1 - markdown: 3.4.1 - markupsafe: 2.1.1 - matplotlib: 3.7.1 - matplotlib-inline: 0.1.6 - mccabe: 0.7.0 - meson: 1.1.0 - meson-python: 0.12.0 - mistune: 2.0.4 - more-itertools: 8.14.0 - mpmath: 1.2.1 - multidict: 6.0.2 - munch: 2.5.0 - mypy: 1.2.0 - mypy-extensions: 1.0.0 - nbclient: 0.6.7 - nbconvert: 7.0.0 - nbformat: 5.8.0 - nbmake: 1.4.1 - nbsphinx: 0.8.8 - nest-asyncio: 1.5.6 - netcdf4: 1.6.2 - networkx: 2.8.6 - notebook-shim: 0.2.2 - numexpr: 2.8.3 - numpy: 1.24.3 - oauthlib: 3.2.1 - odc-geo: 0.1.2 - omegaconf: 2.3.0 - openpyxl: 3.0.3 - ordered-set: 4.0.2 - overrides: 7.3.1 - packaging: 23.0 - pandas: 1.5.3 - pandocfilters: 1.5.0 - parso: 0.8.3 - pathspec: 0.11.1 - pexpect: 4.8.0 - pickleshare: 0.7.5 - pillow: 9.5.0 - pip: 23.0 - pkginfo: 1.8.3 - planetary-computer: 0.4.9 - platformdirs: 3.5.0 - pluggy: 1.0.0 - ply: 3.11 - poetry-core: 1.2.0 - pretrainedmodels: 0.7.4 - prometheus-client: 0.14.1 - prompt-toolkit: 3.0.31 - protobuf: 3.20.1 - psutil: 5.9.4 - ptyprocess: 0.7.0 - pure-eval: 0.2.2 - pyasn1: 0.4.8 - pyasn1-modules: 0.2.8 - pybind11: 2.10.1 - 
pycocotools: 2.0.6 - pycodestyle: 2.10.0 - pycparser: 2.21 - pydantic: 1.10.9 - pydocstyle: 6.2.1 - pyflakes: 3.0.1 - pygeos: 0.10 - pygments: 2.13.0 - pyjwt: 2.4.0 - pyparsing: 3.0.9 - pyproj: 3.5.0 - pyproject-hooks: 1.0.0 - pyproject-metadata: 0.7.1 - pyrsistent: 0.19.3 - pyshp: 2.1.0 - pystac: 1.4.0 - pystac-client: 0.5.1 - pytest: 7.2.1 - pytest-cov: 4.0.0 - python-dateutil: 2.8.2 - python-dotenv: 0.19.2 - python-editor: 1.0.4 - python-json-logger: 2.0.7 - python-multipart: 0.0.5 - pythran: 0.12.2 - pytorch-lightning: 2.0.0 - pytorch-sphinx-theme: 0.0.24 - pytz: 2022.2.1 - pyupgrade: 3.3.1 - pyyaml: 6.0 - pyzmq: 25.0.2 - radiant-mlhub: 0.5.1 - rarfile: 4.0 - rasterio: 1.3.7 - readchar: 4.0.5 - readme-renderer: 37.3 - requests: 2.28.2 - requests-oauthlib: 1.3.1 - requests-toolbelt: 0.9.1 - rfc3339-validator: 0.1.4 - rfc3986: 1.4.0 - rfc3986-validator: 0.1.1 - rich: 12.5.1 - rioxarray: 0.4.1.post0 - rsa: 4.9 - rtree: 1.0.1 - safetensors: 0.3.1 - scikit-learn: 1.2.2 - scipy: 1.10.1 - segmentation-models-pytorch: 0.3.3 - send2trash: 1.8.0 - setuptools: 63.0.0 - setuptools-scm: 7.0.5 - shapely: 1.8.4 - six: 1.16.0 - sniffio: 1.3.0 - snowballstemmer: 2.2.0 - snuggs: 1.4.1 - soupsieve: 2.3.2.post1 - sphinx: 5.3.0 - sphinx-design: 0.4.1 - sphinx-rtd-theme: 0.5.1 - sphinxcontrib-applehelp: 1.0.2 - sphinxcontrib-devhelp: 1.0.2 - sphinxcontrib-htmlhelp: 2.0.0 - sphinxcontrib-jsmath: 1.0.1 - sphinxcontrib-programoutput: 0.15 - sphinxcontrib-qthelp: 1.0.3 - sphinxcontrib-serializinghtml: 1.1.5 - stack-data: 0.5.0 - starlette: 0.22.0 - starsessions: 1.3.0 - sympy: 1.11.1 - tensorboard: 2.13.0 - tensorboard-data-server: 0.7.0 - tensorboard-plugin-wit: 1.8.1 - terminado: 0.15.0 - threadpoolctl: 3.1.0 - timm: 0.9.2 - tinycss2: 1.1.1 - tokenize-rt: 4.2.1 - tomli: 2.0.1 - torch: 2.0.1 - torchmetrics: 0.11.4 - torchvision: 0.15.2 - tornado: 6.2 - tqdm: 4.65.0 - traitlets: 5.9.0 - trove-classifiers: 2023.3.9 - twine: 4.0.1 - typeshed-client: 2.1.0 - typing-extensions: 4.5.0 - 
uri-template: 1.2.0 - urllib3: 1.26.12 - uvicorn: 0.20.0 - vermin: 1.5.1 - versioneer: 0.28 - wcwidth: 0.2.5 - webcolors: 1.11.1 - webencodings: 0.5.1 - websocket-client: 1.5.1 - websockets: 10.4 - werkzeug: 2.3.4 - wheel: 0.37.1 - widgetsnbextension: 4.0.3 - xarray: 2022.3.0 - yarl: 1.8.1 - zipfile-deflate64: 0.2.0 - zipp: 3.8.1 * System: - OS: Darwin - architecture: - 64bit - - processor: arm - python: 3.10.10 - release: 22.6.0 - version: Darwin Kernel Version 22.6.0: Wed Jul 5 22:21:53 PDT 2023; root:xnu-8796.141.3~6/RELEASE_ARM64_T6020

More info

No response

cc @carmocca @mauvilsa

awaelchli commented 1 year ago

You can filter the warnings, for example in a single unit test, like so: https://docs.pytest.org/en/7.1.x/how-to/capture-warnings.html#pytest-mark-filterwarnings or assert that they are raised like so: https://docs.pytest.org/en/7.1.x/how-to/capture-warnings.html#warns

It's ok to ignore these in your tests, but in general the warnings in Lightning are important notices for the user and set the expectations.
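To make the filtering idea concrete, here is a minimal sketch that uses only Python's standard warnings module. The function noisy_call is a hypothetical stand-in for a Lightning call that warns, and run_with_filter is an illustrative helper, not part of Lightning or pytest. In a pytest suite, the filterwarnings mark linked above would play the role of the warnings.filterwarnings call shown here.

```python
import warnings


def noisy_call():
    # Hypothetical stand-in for library code that emits a warning.
    warnings.warn("No seed found, seed set to 1234", UserWarning)
    return "done"


def run_with_filter(ignore):
    # Record every warning that is not filtered out inside this block.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        if ignore:
            # The message is a regex matched against the start of the warning text.
            warnings.filterwarnings(
                "ignore", message="No seed found", category=UserWarning
            )
        result = noisy_call()
    return result, caught
```

With ignore=True the returned list of recorded warnings is empty; with ignore=False it contains the single UserWarning. The filter is scoped to the with block, so other tests are unaffected.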

Here is what I noticed from your code and output:

adamjstewart commented 1 year ago

Args + sys.argv warning

This is the main one I think doesn't make sense. Happy to open a PR to remove this warning if you want.

No seed warning

The warning goes away if I set seed_everything_default to a specific integer. However, I don't think the default behavior (choose on the fly) should be a warning.

Val data loader warning

Agreed, this warning makes sense since fit attempts to run validate and can't.

num workers warning

We actually explicitly use 0 in our tests because it's faster on macOS/Windows where multiprocessing is very expensive. But I agree that some kind of warning suggesting a better # workers is prob a good default.

log every n steps warning

Agreed, this should be ignored when fast_dev_run. This is very similar to #13262.

awaelchli commented 1 year ago

@adamjstewart I think if you are testing your application this way, you should mock the sys.argv attribute to not contain arguments to avoid the warning because I think the warning is legit:

from unittest import mock

from main import cli_main

def test_cli_main():
    args = ["fit", "--trainer.fast_dev_run", "true"]
    with mock.patch("sys.argv", ["any.py"]):  # pretend no args were passed
        cli_main(args)

Or go the other way: pass no args to the parser and instead mock the contents of sys.argv directly from your test:

from unittest import mock

from main import cli_main

def test_cli_main():
    args = ["fit", "--trainer.fast_dev_run", "true"]
    with mock.patch("sys.argv", ["any.py"] + args):
        cli_main()  # pass no args!!

I recommend this because we do it like that in our tests but maybe @carmocca should confirm my answer.

mauvilsa commented 1 year ago

you should mock the sys.argv attribute to not contain arguments to avoid the warning because I think the warning is legit:

Yes.

Another tip. At some point the args parameter was added to the LightningCLI class. If you are using that in your tests, then instead of adding a with mock.patch to each test, in pytest you can do it globally in conftest.py:

from unittest import mock

import pytest

@pytest.fixture(scope="session", autouse=True)
def clear_sys_argv():
    # Replace sys.argv for the whole test session so no pytest arguments leak through.
    with mock.patch("sys.argv", ["pytest"]):
        yield

adamjstewart commented 1 year ago

Mocking sys.argv does get around the warning, but I'm still not convinced that the warning is legit. In what circumstances would the args parameter be used when sys.argv[1:] is empty? It seems like it will almost always have at least one argument (the name of the script or test to run).

awaelchli commented 1 year ago

@adamjstewart My understanding is that all this warning says is that the args way of passing parameters and the sys.argv way of passing parameters are meant to be mutually exclusive. The warning is only raised when the condition if args is not None and len(sys.argv) > 1 is true. See here: https://github.com/Lightning-AI/lightning/blob/eb3b96d8bd3a0e42da81d4ed30c258af51571aac/src/lightning/pytorch/cli.py#L517
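As a rough illustration of that condition, here is a small sketch. The helper would_warn is hypothetical and only mirrors the check quoted above; it is not Lightning's actual code, and it takes argv as a parameter so it can be tried without touching the real sys.argv.

```python
from typing import Optional, Sequence


def would_warn(args: Optional[Sequence[str]], argv: Sequence[str]) -> bool:
    # Hypothetical helper mirroring the quoted check:
    # explicit args were given AND the process also received command line arguments.
    return args is not None and len(argv) > 1


cli_args = ["fit", "--trainer.fast_dev_run", "true"]

# Run as "pytest test.py": argv carries the test file name, so the check fires.
under_pytest = would_warn(cli_args, ["pytest", "test.py"])

# Run with sys.argv mocked to a single entry: the check does not fire.
mocked = would_warn(cli_args, ["any.py"])

# Run from the command line without passing args in code: the check does not fire.
plain_cli = would_warn(None, ["main.py", "fit"])
```

This is why both mocking approaches work: either argv is reduced to one entry, or args stays None.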

When running under pytest this is naturally a problem, because we pass arguments to pytest itself, such as the name of the test folder and other options. That's why we recommend mocking it out. For these reasons I don't know whether the warning can be removed; I'd like to request comment from @carmocca and/or @mauvilsa.

adamjstewart commented 1 year ago

Alright, I guess this makes sense for a script that is intended to be invoked like:

$ python main.py

where sys.argv[1:] will be empty. I would still be interested in a general toggle of all warnings for unit testing purposes, but I'll keep adding ignores to pyproject.toml for now.
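The difference between the two kinds of invocation can be checked with a short sketch that spawns a child interpreter and prints what it sees in sys.argv[1:]. The helper argv_tail is hypothetical, and python -c is used here only as a stand-in for running a script such as main.py or a runner such as pytest.

```python
import subprocess
import sys

# One-liner executed in the child process: print everything after the program name.
PRINT_ARGV = "import sys; print(sys.argv[1:])"


def argv_tail(*extra: str) -> str:
    # Launch a fresh interpreter with optional extra command line arguments
    # and return what it printed for sys.argv[1:].
    completed = subprocess.run(
        [sys.executable, "-c", PRINT_ARGV, *extra],
        capture_output=True,
        text=True,
        check=True,
    )
    return completed.stdout.strip()
```

Calling argv_tail() returns "[]", matching the plain script case, while argv_tail("test.py") returns "['test.py']", matching the value shown in the warning when the test file name is passed to the runner.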

mauvilsa commented 1 year ago

I would still be interested in a general toggle of all warnings for unit testing purposes, but I'll keep adding ignores to pyproject.toml for now.

I guess you could mock _warn in lightning_utilities so that it doesn't do anything: