Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

MLFlow logger with remote tracking fails with CLI #16310

Open Benjamin-Etheredge opened 1 year ago

Benjamin-Etheredge commented 1 year ago

Bug description

Running with the LightningCLI, the MLflow logger, and the MLFLOW_TRACKING_URI environment variable set causes an assertion failure when logging. I think using a remote tracking server means no local log files are created, which the CLI doesn't like.

I suspect it's a similar issue to #12748.

How to reproduce the bug

# main.py (the script invoked as `python main.py fit` below)
from pytorch_lightning.cli import LightningCLI
from helpers import BoringModel, BoringDataModule

cli = LightningCLI(
    BoringModel, 
    BoringDataModule, 
    trainer_defaults=dict(
        max_epochs=1,
        logger="pytorch_lightning.loggers.MLFlowLogger"
    )
)

$ mlflow server
...
$ MLFLOW_TRACKING_URI=http://localhost:5000 python main.py fit

Error messages and logs

/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/configuration_validator.py:106: UserWarning: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.
  rank_zero_warn("You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.")
Traceback (most recent call last):
  File "/workspaces/mlflow_log_error/main.py", line 4, in <module>
    cli = LightningCLI(
  File "/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/cli.py", line 354, in __init__
    self._run_subcommand(self.subcommand)
  File "/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/cli.py", line 665, in _run_subcommand
    fn(**fn_kwargs)
  File "/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 603, in fit
    call._call_and_handle_interrupt(
  File "/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 645, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1037, in _run
    self._call_setup_hook()  # allow user to setup lightning_module in accelerator environment
  File "/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1285, in _call_setup_hook
    self._call_callback_hooks("setup", stage=fn)
  File "/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1380, in _call_callback_hooks
    fn(self, self.lightning_module, *args, **kwargs)
  File "/home/vscode/.local/lib/python3.10/site-packages/pytorch_lightning/cli.py", line 216, in setup
    assert log_dir is not None
AssertionError

Environment

* CUDA:
        - GPU:               None
        - available:         False
        - version:           11.7
* Lightning:
        - lightning:         1.9.0rc0
        - lightning-cloud:   0.5.16
        - lightning-utilities: 0.5.0
        - pytorch-lightning: 1.8.6
        - torch:             1.13.1
        - torchmetrics:      0.11.0
        - torchvision:       0.14.1
* Packages:
        - aiohttp:           3.8.3
        - aiosignal:         1.3.1
        - alembic:           1.9.1
        - antlr4-python3-runtime: 4.9.3
        - anyio:             3.6.2
        - argon2-cffi:       21.3.0
        - argon2-cffi-bindings: 21.2.0
        - arrow:             1.2.3
        - asttokens:         2.2.1
        - async-timeout:     4.0.2
        - attrs:             22.2.0
        - babel:             2.11.0
        - backcall:          0.2.0
        - beautifulsoup4:    4.11.1
        - bleach:            5.0.1
        - blessed:           1.19.1
        - build:             0.9.0
        - certifi:           2022.12.7
        - cffi:              1.15.1
        - charset-normalizer: 2.1.1
        - click:             8.1.3
        - cloudpickle:       2.2.0
        - comm:              0.1.2
        - commonmark:        0.9.1
        - contourpy:         1.0.6
        - croniter:          1.3.8
        - cycler:            0.11.0
        - databricks-cli:    0.17.4
        - dateutils:         0.6.12
        - debugpy:           1.6.5
        - decorator:         5.1.1
        - deepdiff:          6.2.3
        - defusedxml:        0.7.1
        - dnspython:         2.2.1
        - docker:            6.0.1
        - docstring-parser:  0.15
        - email-validator:   1.3.0
        - entrypoints:       0.4
        - executing:         1.2.0
        - fastapi:           0.88.0
        - fastjsonschema:    2.16.2
        - flask:             2.2.2
        - fonttools:         4.38.0
        - fqdn:              1.5.1
        - frozenlist:        1.3.3
        - fsspec:            2022.11.0
        - gitdb:             4.0.10
        - gitpython:         3.1.30
        - greenlet:          2.0.1
        - gunicorn:          20.1.0
        - h11:               0.14.0
        - httpcore:          0.16.3
        - httptools:         0.5.0
        - httpx:             0.23.3
        - hydra-core:        1.3.1
        - idna:              3.4
        - importlib-metadata: 5.2.0
        - importlib-resources: 5.10.2
        - inquirer:          3.1.2
        - ipykernel:         6.20.1
        - ipython:           8.8.0
        - ipython-genutils:  0.2.0
        - isoduration:       20.11.0
        - itsdangerous:      2.1.2
        - jedi:              0.18.2
        - jinja2:            3.1.2
        - joblib:            1.2.0
        - json5:             0.9.11
        - jsonargparse:      4.19.0
        - jsonpointer:       2.3
        - jsonschema:        4.17.3
        - jupyter-client:    7.4.8
        - jupyter-core:      5.1.3
        - jupyter-events:    0.6.0
        - jupyter-server:    2.0.6
        - jupyter-server-terminals: 0.4.4
        - jupyterlab:        3.5.2
        - jupyterlab-pygments: 0.2.2
        - jupyterlab-server: 2.18.0
        - kiwisolver:        1.4.4
        - lightning:         1.9.0rc0
        - lightning-cloud:   0.5.16
        - lightning-utilities: 0.5.0
        - llvmlite:          0.39.1
        - mako:              1.2.4
        - markdown:          3.4.1
        - markupsafe:        2.1.1
        - matplotlib:        3.6.2
        - matplotlib-inline: 0.1.6
        - mistune:           2.0.4
        - mlflow:            2.1.1
        - multidict:         6.0.4
        - nbclassic:         0.4.8
        - nbclient:          0.7.2
        - nbconvert:         7.2.7
        - nbformat:          5.7.1
        - nest-asyncio:      1.5.6
        - notebook:          6.5.2
        - notebook-shim:     0.2.2
        - numba:             0.56.4
        - numpy:             1.23.5
        - nvidia-cublas-cu11: 11.10.3.66
        - nvidia-cuda-nvrtc-cu11: 11.7.99
        - nvidia-cuda-runtime-cu11: 11.7.99
        - nvidia-cudnn-cu11: 8.5.0.96
        - oauthlib:          3.2.2
        - omegaconf:         2.3.0
        - ordered-set:       4.1.0
        - orjson:            3.8.4
        - packaging:         21.3
        - pandas:            1.5.2
        - pandocfilters:     1.5.0
        - parso:             0.8.3
        - pep517:            0.13.0
        - pexpect:           4.8.0
        - pickleshare:       0.7.5
        - pillow:            9.4.0
        - pip:               22.3.1
        - pip-tools:         6.12.1
        - platformdirs:      2.6.2
        - prometheus-client: 0.15.0
        - prompt-toolkit:    3.0.36
        - protobuf:          3.20.1
        - psutil:            5.9.4
        - ptyprocess:        0.7.0
        - pure-eval:         0.2.2
        - pyarrow:           10.0.1
        - pycparser:         2.21
        - pydantic:          1.10.4
        - pygments:          2.14.0
        - pyjwt:             2.6.0
        - pyparsing:         3.0.9
        - pyrsistent:        0.19.3
        - python-dateutil:   2.8.2
        - python-dotenv:     0.21.0
        - python-editor:     1.0.4
        - python-json-logger: 2.0.4
        - python-multipart:  0.0.5
        - pytorch-lightning: 1.8.6
        - pytz:              2022.7
        - pyyaml:            6.0
        - pyzmq:             24.0.1
        - querystring-parser: 1.2.4
        - readchar:          4.0.3
        - requests:          2.28.1
        - rfc3339-validator: 0.1.4
        - rfc3986:           1.5.0
        - rfc3986-validator: 0.1.1
        - rich:              13.0.1
        - scikit-learn:      1.2.0
        - scipy:             1.10.0
        - send2trash:        1.8.0
        - setuptools:        65.5.0
        - shap:              0.41.0
        - six:               1.16.0
        - slicer:            0.0.7
        - smmap:             5.0.0
        - sniffio:           1.3.0
        - soupsieve:         2.3.2.post1
        - sqlalchemy:        1.4.46
        - sqlparse:          0.4.3
        - stack-data:        0.6.2
        - starlette:         0.22.0
        - starsessions:      1.3.0
        - tabulate:          0.9.0
        - tensorboardx:      2.5.1
        - terminado:         0.17.1
        - threadpoolctl:     3.1.0
        - tinycss2:          1.2.1
        - tomli:             2.0.1
        - torch:             1.13.1
        - torchmetrics:      0.11.0
        - torchvision:       0.14.1
        - tornado:           6.2
        - tqdm:              4.64.1
        - traitlets:         5.8.1
        - typeshed-client:   2.1.0
        - typing-extensions: 4.4.0
        - ujson:             5.7.0
        - uri-template:      1.2.0
        - urllib3:           1.26.13
        - uvicorn:           0.20.0
        - uvloop:            0.17.0
        - watchfiles:        0.18.1
        - wcwidth:           0.2.5
        - webcolors:         1.12
        - webencodings:      0.5.1
        - websocket-client:  1.4.2
        - websockets:        10.4
        - werkzeug:          2.2.2
        - wheel:             0.38.4
        - yarl:              1.8.2
        - zipp:              3.11.0
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - ELF
        - processor:         
        - python:            3.10.9
        - version:           #1 SMP Wed Mar 2 00:30:59 UTC 2022

More info

No response

cc @carmocca @mauvilsa

Benjamin-Etheredge commented 1 year ago

A temporary workaround for this issue is to declare a TensorBoard logger ahead of the MLflow one, like so:

cli = LightningCLI(
    BoringModel, 
    BoringDataModule, 
    trainer_defaults=dict(
        max_epochs=1,
        logger=[
            {
                "class_path": "pytorch_lightning.loggers.TensorBoardLogger", 
                "init_args": {
                    "save_dir": "tb_logs",
                }
            },
            "pytorch_lightning.loggers.MLFlowLogger"
        ],
    )
)
vincentwu0730 commented 1 year ago

@Benjamin-Etheredge Here is my workaround, which still leverages the goodness of the CLI module and its YAML file.

import yaml

from pytorch_lightning import Trainer
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.loggers import MLFlowLogger

# LightningToneClassifier and ToneDataModule are this commenter's own model/datamodule classes
cli = LightningCLI(
    LightningToneClassifier,
    ToneDataModule,
    run=False,  # build the CLI and parse the config without running a subcommand
)

with open("lightning/trainer_config.yaml", "r") as f:
    config = yaml.safe_load(f)

# Instantiate the MLflow logger manually instead of through the CLI config
config["trainer"]["logger"] = MLFlowLogger(
    experiment_name="xxxx",
    tracking_uri="xxxx",
    log_model=True,
)

# prepare_fit_dataloader is the commenter's own helper for building the dataloaders
train_dataloader, val_dataloader = prepare_fit_dataloader(cli)
trainer = Trainer(**config["trainer"])
trainer.logger.log_hyperparams(config)
trainer.fit(cli.model, train_dataloader, val_dataloader)
stale[bot] commented 1 year ago

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions - the Lightning Team!

goncalomcorreia commented 1 year ago

Hi @vincentwu0730 ,

Thank you for your workaround! Can you share what prepare_fit_dataloader is?

awaelchli commented 1 year ago

The issue surfaces through LightningCLI because it accesses the log dir, but the origin of the problem, as @Benjamin-Etheredge suspected, is that MLFlowLogger's save_dir returns None when tracking is not done locally:

https://github.com/Lightning-AI/lightning/blob/41f0425a8dbd54030c5b711f92340dc8dc41c173/src/lightning/pytorch/loggers/mlflow.py#L299-L301
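For context, a rough paraphrase of the logic behind that permalink (a sketch only; the standalone helper name is mine, not the logger's API):

from typing import Optional

LOCAL_FILE_URI_PREFIX = "file:"  # prefix the logger checks for a local tracking URI

def mlflow_save_dir(tracking_uri: str) -> Optional[str]:
    # Only a file: tracking URI maps to a local directory; anything else,
    # e.g. http://localhost:5000, yields None, which later trips the CLI's assert.
    if tracking_uri.startswith(LOCAL_FILE_URI_PREFIX):
        return tracking_uri[len(LOCAL_FILE_URI_PREFIX):]
    return None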

Two possible solutions that come to my mind to address this:

  1. Return a default local directory instead of None so LightningCLI can save the config
  2. In the LightningCLI, if the value returned by the log dir is None, save the config to a different place (as if there is no logger); a rough sketch follows below.
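A minimal sketch of what option 2 could look like, assuming the attributes SaveConfigCallback already holds; the FallbackSaveConfigCallback name and the default_root_dir fallback are illustrative, not the actual fix:

import os

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.cli import SaveConfigCallback


class FallbackSaveConfigCallback(SaveConfigCallback):
    # Illustrative only: if the logger exposes no local log dir (e.g. MLFlowLogger
    # with a remote tracking URI), fall back to the trainer's default_root_dir
    # instead of asserting.
    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        log_dir = trainer.log_dir or trainer.default_root_dir
        config_path = os.path.join(log_dir, self.config_filename)
        if trainer.is_global_zero:
            os.makedirs(log_dir, exist_ok=True)
            self.parser.save(
                self.config, config_path, skip_none=False,
                overwrite=self.overwrite, multifile=self.multifile,
            )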
mauvilsa commented 1 year ago

Two possible solutions that come to my mind to address this:

I can suggest another solution: implement a custom save-config class that saves the config to MLflow as an artifact instead of saving it locally. If logging remotely, it makes sense to also save the config in the same place.

terbed commented 6 months ago

A realization of @mauvilsa's idea:

import torch

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.cli import LightningCLI, SaveConfigCallback


class MLFlowSaveConfigCallback(SaveConfigCallback):
    def __init__(self, parser, config, config_filename='config.yaml', overwrite=False, multifile=False):
        # save_to_log_dir=False skips writing config.yaml to the (possibly missing) local log dir
        super().__init__(parser, config, config_filename, overwrite, multifile, save_to_log_dir=False)

    def save_config(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        # Convert Namespace to dict
        config_dict = vars(self.config)

        # Log the full config to MLflow as run parameters
        pl_module.logger.log_hyperparams(config_dict)


def cli_compile_main():
    # PRDataModule is the commenter's own datamodule
    cli = LightningCLI(datamodule_class=PRDataModule, run=False, save_config_callback=MLFlowSaveConfigCallback)
    compiled_model = torch.compile(cli.model)
    cli.trainer.fit(compiled_model, datamodule=cli.datamodule)
    cli.trainer.test(datamodule=cli.datamodule)
adrianomartinelli commented 2 months ago

Slight modification of @terbed's version if you want to save the file as YAML:

from pathlib import Path
import tempfile

from lightning import Trainer, LightningModule
from lightning.pytorch.cli import SaveConfigCallback


class MLFlowSaveConfigCallback(SaveConfigCallback):
    def __init__(self, parser, config, config_filename='config.yaml', overwrite=False, multifile=False):
        super().__init__(parser, config, config_filename, overwrite, multifile, save_to_log_dir=False)

    def save_config(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        if trainer.is_global_zero:
            # Write the parsed config to a temporary YAML file and upload it to MLflow as a run artifact
            with tempfile.TemporaryDirectory() as tmp_dir:
                config_path = Path(tmp_dir) / 'config.yaml'
                self.parser.save(
                    self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
                )
                trainer.logger.experiment.log_artifact(local_path=config_path,
                                                       run_id=trainer.logger.run_id)
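For completeness, a hedged usage sketch wiring the callback above into LightningCLI; MyModel and MyDataModule are placeholders, and the MLflow logger could equally be configured through a YAML config file:

from lightning.pytorch.cli import LightningCLI

cli = LightningCLI(
    MyModel,          # placeholder LightningModule
    MyDataModule,     # placeholder LightningDataModule
    save_config_callback=MLFlowSaveConfigCallback,
    trainer_defaults={
        "logger": {
            "class_path": "lightning.pytorch.loggers.MLFlowLogger",
            "init_args": {"tracking_uri": "http://localhost:5000"},
        }
    },
)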