Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
Apache License 2.0

Cannot call self.log in evaluation_hooks after using trainer.predict, even if using a new trainer object. #19101

Open bw4sz opened 11 months ago

bw4sz commented 11 months ago

Bug description

There has been a lot of discussion around logging, trainer.predict, evaluation hooks and callbacks. I think I can boil this down to a reproducible example that will be useful for the community.

What has been discussed so far. (where I started the example below)

From these links, there is no clear guidance on why logging works after calling predict_step() directly but not after calling trainer.predict. This is flirting with being a bug, but appears to be intended behavior from the comment below.

We are not inside a predict hook; we are inside an evaluation hook. We only used trainer.predict, with all of its great functionality, to generate a set of predictions.

Expected behavior

I understand from the above issues, as stated by @carmocca, that we cannot overwrite the trainer state. But why doesn't this work with a new trainer?

What version are you seeing the problem on?


How to reproduce the bug

import os
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel

class MyModel(BoringModel):
    def on_validation_epoch_end(self):
        if self.trainer.sanity_checking:  # optional skip
            return

        print("Start predicting!")
        # Call predict_step directly on each batch, bypassing trainer.predict.
        for i, batch in enumerate(self.predict_dataloader()):
            batch = self.transfer_batch_to_device(batch, self.device, dataloader_idx=0)
            out = self.predict_step(batch, i)
            print(i, out)

        # Logging works in this variant.
        self.log("metric", 1.0)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

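def run_predict_step():
    model = MyModel()
    # The original Trainer arguments and the call that followed were lost
    # from the post; everything below this comment is assumed.
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        max_epochs=1,
    )
    trainer.fit(model)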

class MyModel2(BoringModel):
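    def on_validation_epoch_end(self):
        if self.trainer.sanity_checking:  # optional skip
            return

        print("Start predicting!")
        dataloader = self.predict_dataloader()

        # The original arguments were lost from the post; defaults are assumed.
        new_trainer = Trainer()

        new_trainer.predict(self, dataloaders=dataloader)

        # This call raises the error shown under "Error messages and logs".
        self.log("metric", 1.0)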

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

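def run_trainer_predict():
    model = MyModel2()
    # The original Trainer arguments and the call that followed were lost
    # from the post; everything below this comment is assumed.
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        max_epochs=1,
    )
    trainer.fit(model)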

if __name__ == "__main__":
    # This works
    run_predict_step()

    # This does not work
    run_trainer_predict()
Error messages and logs

You are trying to `self.log()` but the loop's result collection is not registered yet. This is most likely because you are trying to log in a `predict` hook, but it doesn't support logging


Current environment

* CUDA:
        - GPU:               None
        - available:         False
        - version:           None
* Lightning:
        - lightning-lite:    1.8.0.post1
        - lightning-utilities: 0.8.0
        - pytorch-lightning: 2.1.2
        - torch:             1.12.1
        - torchmetrics:      1.2.0
        - torchvision:       0.13.1
* Packages:
        - absl-py:           0.13.0
        - accessible-pygments: 0.0.4
        - affine:            2.3.0
        - aiohttp:           3.7.4.post0
        - alabaster:         0.7.12
        - albumentations:    1.1.0
        - async-timeout:     3.0.1
        - attrs:             21.2.0
        - babel:             2.9.1
        - beautifulsoup4:    4.12.2
        - bleach:            4.0.0
        - brotlipy:          0.7.0
        - bumpversion:       0.5.3
        - cached-property:   1.5.2
        - cachetools:        4.2.2
        - certifi:           2021.5.30
        - cffi:              1.14.6
        - chardet:           4.0.0
        - click:             7.1.2
        - click-plugins:     1.1.1
        - cligj:             0.7.2
        - cmarkgfm:          0.4.2
        - colorama:          0.4.4
        - commonmark:        0.9.1
        - cryptography:      3.4.7
        - cycler:            0.10.0
        - docutils:          0.18.1
        - execnet:           2.0.2
        - fiona:             1.8.20
        - fire:              0.4.0
        - fonttools:         4.25.0
        - fsspec:            2021.7.0
        - furo:              2023.9.10
        - future:            0.18.2
        - gdal:              3.3.1
        - geopandas:         0.9.0
        - google-auth:       1.34.0
        - google-auth-oauthlib: 0.4.5
        - gprof2dot:         2022.7.29
        - grpcio:            1.39.0
        - h5py:              3.3.0
        - idna:              2.10
        - imagecodecs:       2021.7.30
        - imageio:           2.9.0
        - imagesize:         1.4.1
        - importlib-metadata: 6.8.0
        - iniconfig:         1.1.1
        - jinja2:            3.0.1
        - joblib:            1.0.1
        - keyring:           23.0.1
        - kiwisolver:        1.3.1
        - lightning-lite:    1.8.0.post1
        - lightning-utilities: 0.8.0
        - mapclassify:       2.4.3
        - markdown:          3.3.4
        - markupsafe:        2.0.1
        - matplotlib:        3.4.2
        - more-itertools:    8.8.0
        - multidict:         5.1.0
        - munch:             2.5.0
        - munkres:           1.1.4
        - networkx:          2.6.2
        - numpy:             1.21.1
        - numpydoc:          1.1.0
        - oauthlib:          3.1.1
        - olefile:           0.46
        - opencv-python:
        - packaging:         21.0
        - pandas:            1.3.1
        - pillow:            9.2.0
        - pip:               21.2.2
        - pkginfo:           1.7.1
        - platformdirs:      3.11.0
        - pluggy:            0.13.1
        - progressbar2:      4.2.0
        - protobuf:          3.17.3
        - psutil:            5.8.0
        - py:                1.10.0
        - pyasn1:            0.4.8
        - pyasn1-modules:    0.2.8
        - pycocotools:       2.0.7
        - pycparser:         2.20
        - pydata-sphinx-theme: 0.14.1
        - pydeprecate:       0.3.1
        - pygments:          2.16.1
        - pyopenssl:         20.0.1
        - pyparsing:         2.4.7
        - pyproj:            3.1.0
        - pysocks:           1.7.1
        - pytest:            6.2.4
        - pytest-profiling:  1.7.0
        - pytest-xdist:      3.3.1
        - python-dateutil:   2.8.2
        - python-utils:      3.4.5
        - pytorch-lightning: 2.1.2
        - pytz:              2021.1
        - pywavelets:        1.1.1
        - pyyaml:            5.4.1
        - qudida:            0.0.4
        - rasterio:          1.2.6
        - readme-renderer:   24.0
        - recommonmark:      0.7.1
        - requests:          2.25.1
        - requests-oauthlib: 1.3.0
        - requests-toolbelt: 0.9.1
        - rfc3986:           1.4.0
        - rsa:               4.7.2
        - rtree:             0.9.7
        - scikit-image:      0.18.2
        - scikit-learn:      0.24.2
        - scipy:             1.7.0
        - setuptools:        59.5.0
        - shapely:           1.7.1
        - six:               1.16.0
        - slidingwindow:     0.0.14
        - snakeviz:          2.1.1
        - snowballstemmer:   2.1.0
        - snuggs:            1.4.7
        - soupsieve:         2.5
        - sphinx:            7.2.6
        - sphinx-basic-ng:   1.0.0b2
        - sphinx-markdown-tables: 0.0.15
        - sphinx-rtd-theme:  1.3.0
        - sphinxcontrib-applehelp: 1.0.2
        - sphinxcontrib-devhelp: 1.0.2
        - sphinxcontrib-htmlhelp: 2.0.0
        - sphinxcontrib-jquery: 4.1
        - sphinxcontrib-jsmath: 1.0.1
        - sphinxcontrib-qthelp: 1.0.3
        - sphinxcontrib-serializinghtml: 1.1.9
        - tensorboard:       2.10.0
        - tensorboard-data-server: 0.6.1
        - tensorboard-plugin-wit: 1.8.0
        - termcolor:         2.1.0
        - threadpoolctl:     2.2.0
        - tifffile:          2021.7.30
        - toml:              0.10.2
        - tomli:             2.0.1
        - torch:             1.12.1
        - torchmetrics:      1.2.0
        - torchvision:       0.13.1
        - tornado:           6.1
        - tqdm:              4.62.0
        - twine:             0.0.0
        - typing-extensions: 4.3.0
        - urllib3:           1.26.6
        - webencodings:      0.5.1
        - werkzeug:          2.0.1
        - wheel:             0.36.2
        - xmltodict:         0.12.0
        - yapf:              0.40.2
        - yarl:              1.6.3
        - zipp:              3.5.0
* System:
        - OS:                Darwin
        - architecture:
                - 64bit
        - processor:         i386
        - python:            3.9.6
        - release:           23.1.0
        - version:           Darwin Kernel Version 23.1.0: Mon Oct  9 21:27:27 PDT 2023; root:xnu-10002.41.9~6/RELEASE_X86_64


More info

No response

tshu-w commented 1 month ago

Restoring _current_fx_name might work:

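    # Requires: from copy import deepcopy; import lightning as L
    def on_validation_epoch_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        # Save the state that trainer.predict overwrites.
        trainer_state = deepcopy(trainer.state)
        current_fx_name = pl_module._current_fx_name
        # `dataloader` must be defined by the caller, e.g. pl_module.predict_dataloader().
        results = trainer.predict(pl_module, dataloader, return_predictions=True)
        # Restore it so that logging works again inside this hook.
        trainer.state = trainer_state
        pl_module._current_fx_name = current_fx_name
        pl_module.log("val/ex", 0, prog_bar=True)

bw4sz commented 1 month ago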

Thanks, can you give any insight into why that works, what's happening here that allows the trainer to be used inside the hook?

tshu-w commented 1 month ago

> Thanks, can you give any insight into why that works, what's happening here that allows the trainer to be used inside the hook?

The on_validation_epoch_end hook of Callback accepts trainer and pl_module as parameters. In a LightningModule, you can access the same objects through self.trainer and self._current_fx_name, for example:

    def on_validation_epoch_end(self):
        if self.trainer.sanity_checking:  # optional skip
            return

        # Save the state that trainer.predict overwrites (needs `from copy import deepcopy`).
        trainer_state = deepcopy(self.trainer.state)
        current_fx_name = self._current_fx_name
        print("Start predicting!")
        dataloader = self.predict_dataloader()
        self.trainer.predict(self, dataloaders=dataloader)

        # Restore it so that self.log works again.
        self.trainer.state = trainer_state
        self._current_fx_name = current_fx_name

        self.log("metric", 1.0)

As for why that works: trainer.predict changes the LightningModule's _current_fx_name, and that change is what stops self.log from working. Restoring the saved value (together with the trainer state) after prediction lets self.log work again.
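
The save-and-restore steps above could be wrapped in a small context manager so the restore also happens if prediction raises. The following is only a sketch of that idea: `preserve_attrs` is a hypothetical helper, and `FakeModule` and `FakeTrainer` are stand-ins written for this example, not Lightning classes. The values the fake `predict` writes are placeholders, since the thread does not show what Lightning actually sets.

```python
from contextlib import contextmanager
from copy import deepcopy


@contextmanager
def preserve_attrs(obj, *names):
    """Snapshot the named attributes of `obj` and restore them on exit."""
    saved = {name: deepcopy(getattr(obj, name)) for name in names}
    try:
        yield obj
    finally:
        # Runs even if the body raises, so the object is always restored.
        for name, value in saved.items():
            setattr(obj, name, value)


# Stand-ins for illustration only; these are NOT Lightning classes.
class FakeModule:
    def __init__(self):
        self._current_fx_name = "on_validation_epoch_end"


class FakeTrainer:
    def __init__(self):
        self.state = {"fn": "fit", "stage": "validate"}

    def predict(self, module):
        # Placeholder side effects: overwrite the trainer state and the
        # module's current hook name, as the thread describes in general terms.
        self.state = {"fn": "predict", "stage": "predict"}
        module._current_fx_name = None
        return [0.0]


module, trainer = FakeModule(), FakeTrainer()
with preserve_attrs(trainer, "state"), preserve_attrs(module, "_current_fx_name"):
    predictions = trainer.predict(module)

print(trainer.state["fn"], module._current_fx_name)
# prints: fit on_validation_epoch_end
```

With a real model, the same pattern would presumably wrap self.trainer.predict(...) inside on_validation_epoch_end, with self.log called after the with block. Since _current_fx_name is a private attribute, this may break between Lightning versions.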