Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Cannot call self.log in evaluation_hooks after using trainer.predict, even if using a new trainer object. #19101

Open bw4sz opened 11 months ago

bw4sz commented 11 months ago

Bug description

There has been a lot of discussion around logging, trainer.predict, evaluation hooks, and callbacks. I think I can boil this down to a reproducible example that will be useful for the community.

What has been discussed so far.

https://github.com/Lightning-AI/pytorch-lightning/issues/10365
https://github.com/Lightning-AI/pytorch-lightning/discussions/16258 (where I started the example below)
https://github.com/Lightning-AI/pytorch-lightning/issues/16822
https://github.com/Lightning-AI/pytorch-lightning/issues/7333

From these links, there is no clear guidance on why logging works after calling predict_step() directly but fails after calling trainer.predict(). This is flirting with being a bug, but appears to be intended behavior according to the comment linked below.

We are not inside a predict hook; we are inside an evaluation hook. We only used trainer.predict, with all of its great functionality, to generate a set of predictions.

Expected behavior

I understand from the above issues as stated by @carmocca (https://github.com/Lightning-AI/pytorch-lightning/issues/7333#issuecomment-1027107255) that we cannot overwrite the trainer state. Why doesn't this work with a new trainer?

What version are you seeing the problem on?

v2.1

How to reproduce the bug

import os
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel

class MyModel(BoringModel):
    def on_validation_epoch_end(self):
        if self.trainer.sanity_checking:  # optional skip
            return
        print("Start predicting!")
        for i, batch in enumerate(self.predict_dataloader()):
            batch = self.transfer_batch_to_device(batch, self.device, dataloader_idx=0)
            out = self.predict_step(batch, i)
            print(i, out)

        self.log("metric", 1.0)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

def run_predict_step():
    model = MyModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        accelerator="auto",
        limit_train_batches=1,
        limit_val_batches=1,
        fast_dev_run=True,
        max_epochs=1,
        enable_model_summary=False,
    )
    trainer.fit(model)

class MyModel2(BoringModel):
    def on_validation_epoch_end(self):
        if self.trainer.sanity_checking:  # optional skip
            return
        print("Start predicting!")
        dataloader = self.predict_dataloader()

        new_trainer = Trainer(
            default_root_dir=os.getcwd(),
            accelerator="auto",
            limit_train_batches=1,
            limit_val_batches=1,
            fast_dev_run=True,
            max_epochs=1,
            enable_model_summary=False,
        )

        new_trainer.predict(self, dataloaders=dataloader)

        self.log("metric", 1.0)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

def run_trainer_predict():
    model = MyModel2()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        accelerator="auto",
        limit_train_batches=1,
        limit_val_batches=1,
        max_epochs=1,
        fast_dev_run=True,
        enable_model_summary=False,
    )
    trainer.fit(model)

if __name__ == "__main__":
    # This works
    run_predict_step()

    # This does not work
    run_trainer_predict()

Error messages and logs

You are trying to `self.log()` but the loop's result collection is not registered yet. This is most likely because you are trying to log in a `predict` hook, but it doesn't support logging

Environment

(DeepForest) benweinstein@Bens-MacBook-Pro Downloads % python collect_env_details.py          
<details>
  <summary>Current environment</summary>

* CUDA:
        - GPU:               None
        - available:         False
        - version:           None
* Lightning:
        - lightning-lite:    1.8.0.post1
        - lightning-utilities: 0.8.0
        - pytorch-lightning: 2.1.2
        - torch:             1.12.1
        - torchmetrics:      1.2.0
        - torchvision:       0.13.1
* Packages:
        - absl-py:           0.13.0
        - accessible-pygments: 0.0.4
        - affine:            2.3.0
        - aiohttp:           3.7.4.post0
        - alabaster:         0.7.12
        - albumentations:    1.1.0
        - async-timeout:     3.0.1
        - attrs:             21.2.0
        - babel:             2.9.1
        - beautifulsoup4:    4.12.2
        - bleach:            4.0.0
        - brotlipy:          0.7.0
        - bumpversion:       0.5.3
        - cached-property:   1.5.2
        - cachetools:        4.2.2
        - certifi:           2021.5.30
        - cffi:              1.14.6
        - chardet:           4.0.0
        - click:             7.1.2
        - click-plugins:     1.1.1
        - cligj:             0.7.2
        - cmarkgfm:          0.4.2
        - colorama:          0.4.4
        - commonmark:        0.9.1
        - cryptography:      3.4.7
        - cycler:            0.10.0
        - docutils:          0.18.1
        - execnet:           2.0.2
        - fiona:             1.8.20
        - fire:              0.4.0
        - fonttools:         4.25.0
        - fsspec:            2021.7.0
        - furo:              2023.9.10
        - future:            0.18.2
        - gdal:              3.3.1
        - geopandas:         0.9.0
        - google-auth:       1.34.0
        - google-auth-oauthlib: 0.4.5
        - gprof2dot:         2022.7.29
        - grpcio:            1.39.0
        - h5py:              3.3.0
        - idna:              2.10
        - imagecodecs:       2021.7.30
        - imageio:           2.9.0
        - imagesize:         1.4.1
        - importlib-metadata: 6.8.0
        - iniconfig:         1.1.1
        - jinja2:            3.0.1
        - joblib:            1.0.1
        - keyring:           23.0.1
        - kiwisolver:        1.3.1
        - lightning-lite:    1.8.0.post1
        - lightning-utilities: 0.8.0
        - mapclassify:       2.4.3
        - markdown:          3.3.4
        - markupsafe:        2.0.1
        - matplotlib:        3.4.2
        - more-itertools:    8.8.0
        - multidict:         5.1.0
        - munch:             2.5.0
        - munkres:           1.1.4
        - networkx:          2.6.2
        - numpy:             1.21.1
        - numpydoc:          1.1.0
        - oauthlib:          3.1.1
        - olefile:           0.46
        - opencv-python:     4.6.0.66
        - packaging:         21.0
        - pandas:            1.3.1
        - pillow:            9.2.0
        - pip:               21.2.2
        - pkginfo:           1.7.1
        - platformdirs:      3.11.0
        - pluggy:            0.13.1
        - progressbar2:      4.2.0
        - protobuf:          3.17.3
        - psutil:            5.8.0
        - py:                1.10.0
        - pyasn1:            0.4.8
        - pyasn1-modules:    0.2.8
        - pycocotools:       2.0.7
        - pycparser:         2.20
        - pydata-sphinx-theme: 0.14.1
        - pydeprecate:       0.3.1
        - pygments:          2.16.1
        - pyopenssl:         20.0.1
        - pyparsing:         2.4.7
        - pyproj:            3.1.0
        - pysocks:           1.7.1
        - pytest:            6.2.4
        - pytest-profiling:  1.7.0
        - pytest-xdist:      3.3.1
        - python-dateutil:   2.8.2
        - python-utils:      3.4.5
        - pytorch-lightning: 2.1.2
        - pytz:              2021.1
        - pywavelets:        1.1.1
        - pyyaml:            5.4.1
        - qudida:            0.0.4
        - rasterio:          1.2.6
        - readme-renderer:   24.0
        - recommonmark:      0.7.1
        - requests:          2.25.1
        - requests-oauthlib: 1.3.0
        - requests-toolbelt: 0.9.1
        - rfc3986:           1.4.0
        - rsa:               4.7.2
        - rtree:             0.9.7
        - scikit-image:      0.18.2
        - scikit-learn:      0.24.2
        - scipy:             1.7.0
        - setuptools:        59.5.0
        - shapely:           1.7.1
        - six:               1.16.0
        - slidingwindow:     0.0.14
        - snakeviz:          2.1.1
        - snowballstemmer:   2.1.0
        - snuggs:            1.4.7
        - soupsieve:         2.5
        - sphinx:            7.2.6
        - sphinx-basic-ng:   1.0.0b2
        - sphinx-markdown-tables: 0.0.15
        - sphinx-rtd-theme:  1.3.0
        - sphinxcontrib-applehelp: 1.0.2
        - sphinxcontrib-devhelp: 1.0.2
        - sphinxcontrib-htmlhelp: 2.0.0
        - sphinxcontrib-jquery: 4.1
        - sphinxcontrib-jsmath: 1.0.1
        - sphinxcontrib-qthelp: 1.0.3
        - sphinxcontrib-serializinghtml: 1.1.9
        - tensorboard:       2.10.0
        - tensorboard-data-server: 0.6.1
        - tensorboard-plugin-wit: 1.8.0
        - termcolor:         2.1.0
        - threadpoolctl:     2.2.0
        - tifffile:          2021.7.30
        - toml:              0.10.2
        - tomli:             2.0.1
        - torch:             1.12.1
        - torchmetrics:      1.2.0
        - torchvision:       0.13.1
        - tornado:           6.1
        - tqdm:              4.62.0
        - twine:             0.0.0
        - typing-extensions: 4.3.0
        - urllib3:           1.26.6
        - webencodings:      0.5.1
        - werkzeug:          2.0.1
        - wheel:             0.36.2
        - xmltodict:         0.12.0
        - yapf:              0.40.2
        - yarl:              1.6.3
        - zipp:              3.5.0
* System:
        - OS:                Darwin
        - architecture:
                - 64bit
                - 
        - processor:         i386
        - python:            3.9.6
        - release:           23.1.0
        - version:           Darwin Kernel Version 23.1.0: Mon Oct  9 21:27:27 PDT 2023; root:xnu-10002.41.9~6/RELEASE_X86_64

</details>

More info

No response

tshu-w commented 1 month ago

Restoring _current_fx_name (and the trainer state) after the predict call might work:

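    from copy import deepcopy

    import lightning as L  # assumed from the L.Trainer annotations below

    # Hook of a Callback subclass; `dataloader` must be defined by the caller.
    def on_validation_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        trainer_state = deepcopy(trainer.state)  # save state before predict changes it
        current_fx_name = pl_module._current_fx_name
        results = trainer.predict(pl_module, dataloader, return_predictions=True)
        trainer.state = trainer_state  # restore the saved trainer state
        pl_module._current_fx_name = current_fx_name  # restore so that log() works
        pl_module.log("val/ex", 0, prog_bar=True)

bw4sz commented 1 month ago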

Thanks, can you give any insight into why that works, what's happening here that allows the trainer to be used inside the hook?

tshu-w commented 1 month ago

The on_validation_epoch_end hook of a Callback accepts trainer and pl_module as parameters. In a LightningModule, you can access the same objects through self.trainer and self._current_fx_name, for example:

    # Requires `from copy import deepcopy` at module level.
    def on_validation_epoch_end(self):
        if self.trainer.sanity_checking:  # optional skip
            return

        # Save what trainer.predict changes.
        trainer_state = deepcopy(self.trainer.state)
        current_fx_name = self._current_fx_name
        print("Start predicting!")
        dataloader = self.predict_dataloader()
        self.trainer.predict(self, dataloaders=dataloader)

        # Restore it so that self.log works again.
        self.trainer.state = trainer_state
        self._current_fx_name = current_fx_name

        self.log("metric", 1.0)

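As for why that works: trainer.predict changes the LightningModule's _current_fx_name, and that change is what makes self.log fail afterwards. Saving the value before the predict call and restoring it afterwards (together with the trainer state) lets self.log work again.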
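
The save-and-restore pattern above can be wrapped in a small context manager so the restore also runs if the prediction raises. This is only a sketch: preserve_attrs is a hypothetical helper, not part of Lightning, and the trainer and module objects below are plain stand-ins so the snippet runs without Lightning installed. Whether restoring these two attributes is sufficient in every Lightning version is an assumption taken from this thread.

```python
from contextlib import contextmanager
from copy import deepcopy
from types import SimpleNamespace


@contextmanager
def preserve_attrs(obj, *names):
    """Snapshot the named attributes of `obj` and put them back on exit.

    Hypothetical helper, not a Lightning API.
    """
    # deepcopy so that in-place mutation of the attribute cannot leak through
    saved = {name: deepcopy(getattr(obj, name)) for name in names}
    try:
        yield obj
    finally:
        for name, value in saved.items():
            setattr(obj, name, value)


# Stand-ins for the real Trainer and LightningModule (illustration only).
trainer = SimpleNamespace(state={"fn": "fit", "stage": "validate"})
module = SimpleNamespace(_current_fx_name="on_validation_epoch_end")

with preserve_attrs(trainer, "state"), preserve_attrs(module, "_current_fx_name"):
    # Simulates the side effects attributed to trainer.predict in this thread.
    trainer.state["stage"] = "predict"  # in-place mutation
    module._current_fx_name = None      # reassignment

print(trainer.state, module._current_fx_name)
```

Inside a real hook, the body of the with block would be the self.trainer.predict(self, dataloaders=dataloader) call, with self.trainer and self passed in place of the stand-ins, and self.log would be called after the block exits.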