Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

self.log when model is compiled does not accumulate the metrics #17929

Open TommasoBendinelli opened 1 year ago

TommasoBendinelli commented 1 year ago

Bug description

Hi,

self.log('val_loss', loss, prog_bar=True, on_epoch=True) does not accumulate the values when the model is compiled via:

if cfg.compile:
    model = torch.compile(model)

For instance, if I have three validation steps with losses [1, 2, 0], the val_loss at the end of the epoch will be 0 and not the mean of 1, 2, 0 (i.e. 1).

If the model is not compiled, everything works fine.
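
To make the numbers concrete, here is a plain-Python sketch of the arithmetic only. It does not call Lightning or torch.compile. It assumes the default mean reduction for values logged with on_epoch=True, and it assumes the reported 0 is simply the last step's value being kept (0 is the last loss in the example, but that interpretation is a guess). The variable names are illustrative.

```python
from statistics import mean

# Per-step validation losses from the example above.
step_losses = [1.0, 2.0, 0.0]

# Expected epoch-level value: the mean over all logged steps
# (assumes the default mean reduction for on_epoch=True).
expected_val_loss = mean(step_losses)

# Value seen with the compiled model. Assumption: only the last
# logged value survives; this is a guess, not a confirmed cause.
reported_val_loss = step_losses[-1]

print(f"expected val_loss: {expected_val_loss}")
print(f"reported val_loss: {reported_val_loss}")
```

The gap between the two printed values is the discrepancy described in this report.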

I am on Python 3.10.9

What version are you seeing the problem on?

v2.0

How to reproduce the bug

No response

Error messages and logs

# Error messages and logs here please

Environment

2023-06-26 15:51:58 (59.8 MB/s) - ‘collect_env_details.py’ saved [2759/2759]

Current environment

* CUDA:
    - GPU:
        - NVIDIA GeForce RTX 4090
        - NVIDIA GeForce RTX 4090
    - available: True
    - version: 11.8
* Lightning:
    - efficientnet-pytorch: 0.7.1
    - lightning: 2.0.4
    - lightning-cloud: 0.5.37
    - lightning-utilities: 0.8.0
    - pytorch-lightning: 2.0.4
    - torch: 2.0.1+cu118
    - torchaudio: 2.0.2+cu118
    - torchmetrics: 0.11.4
    - torchvision: 0.15.2+cu118
* Packages:
    - absl-py: 1.4.0
    - addict: 2.4.0
    - aioboto3: 11.2.0
    - aiobotocore: 2.5.0
    - aiohttp: 3.8.4
    - aioitertools: 0.11.0
    - aiosignal: 1.3.1
    - albumentations: 1.3.1
    - altair: 5.0.1
    - ansi2html: 1.8.0
    - antlr4-python3-runtime: 4.9.3
    - anyio: 3.7.0
    - arrow: 1.2.3
    - asttokens: 2.2.1
    - async-timeout: 4.0.2
    - attrs: 23.1.0
    - autobahn: 23.6.2
    - automat: 22.10.0
    - backcall: 0.2.0
    - beautifulsoup4: 4.12.2
    - blessed: 1.20.0
    - blinker: 1.6.2
    - boto3: 1.26.76
    - botocore: 1.29.76
    - cachetools: 5.3.1
    - certifi: 2022.12.7
    - cffi: 1.15.1
    - charset-normalizer: 2.1.1
    - click: 8.1.3
    - cmake: 3.25.0
    - comm: 0.1.3
    - configargparse: 1.5.3
    - constantly: 15.1.0
    - contourpy: 1.1.0
    - croniter: 1.3.15
    - cryptography: 41.0.1
    - cycler: 0.11.0
    - dash: 2.11.0
    - dash-core-components: 2.0.0
    - dash-html-components: 2.0.0
    - dash-table: 5.0.0
    - dateutils: 0.6.12
    - ddd-moco: 0.1.0
    - debugpy: 1.6.7
    - decorator: 5.1.1
    - deepdiff: 6.3.0
    - deeplake: 3.6.6
    - dill: 0.3.6
    - efficientnet-pytorch: 0.7.1
    - entrypoints: 0.4
    - exceptiongroup: 1.1.1
    - executing: 1.2.0
    - fastapi: 0.98.0
    - fastjsonschema: 2.17.1
    - filelock: 3.9.0
    - flask: 2.2.5
    - fonttools: 4.40.0
    - frozenlist: 1.3.3
    - fsspec: 2023.6.0
    - gitdb: 4.0.10
    - gitpython: 3.1.31
    - google-auth: 2.20.0
    - google-auth-oauthlib: 1.0.0
    - grpcio: 1.56.0
    - h11: 0.14.0
    - hkdf: 0.0.3
    - humanize: 4.6.0
    - humbug: 0.3.1
    - hydra-core: 1.3.2
    - hyperlink: 21.0.0
    - idna: 3.4
    - imageio: 2.31.1
    - importlib-metadata: 6.7.0
    - incremental: 22.10.0
    - inquirer: 3.1.3
    - ipykernel: 6.23.3
    - ipython: 8.14.0
    - ipywidgets: 8.0.6
    - itsdangerous: 2.1.2
    - jedi: 0.18.2
    - jinja2: 3.1.2
    - jmespath: 1.0.1
    - joblib: 1.2.0
    - jsonschema: 4.17.3
    - jupyter-client: 8.3.0
    - jupyter-core: 5.3.1
    - jupyterlab-widgets: 3.0.7
    - kiwisolver: 1.4.4
    - lazy-loader: 0.2
    - libdeeplake: 0.0.60
    - lightning: 2.0.4
    - lightning-cloud: 0.5.37
    - lightning-utilities: 0.8.0
    - line-profiler: 4.0.3
    - lit: 15.0.7
    - magic-wormhole: 0.12.0
    - markdown: 3.4.3
    - markdown-it-py: 3.0.0
    - markupsafe: 2.1.2
    - matplotlib: 3.7.1
    - matplotlib-inline: 0.1.6
    - mdurl: 0.1.2
    - mpmath: 1.2.1
    - multidict: 6.0.4
    - multiprocess: 0.70.14
    - nbformat: 5.7.0
    - nest-asyncio: 1.5.6
    - networkx: 3.0
    - numcodecs: 0.11.0
    - numpy: 1.24.1
    - oauthlib: 3.2.2
    - omegaconf: 2.3.0
    - open3d: 0.17.0
    - opencv-python: 4.7.0.72
    - opencv-python-headless: 4.7.0.72
    - ordered-set: 4.1.0
    - packaging: 23.1
    - pandas: 2.0.2
    - parso: 0.8.3
    - pathos: 0.3.0
    - pexpect: 4.8.0
    - pickleshare: 0.7.5
    - pillow: 9.3.0
    - pip: 22.3.1
    - platformdirs: 3.8.0
    - plotly: 5.15.0
    - pox: 0.3.2
    - ppft: 1.7.6.6
    - prompt-toolkit: 3.0.38
    - protobuf: 4.23.3
    - psutil: 5.9.5
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.2
    - pyarrow: 12.0.1
    - pyasn1: 0.5.0
    - pyasn1-modules: 0.3.0
    - pycparser: 2.21
    - pydantic: 1.10.9
    - pydeck: 0.8.1b0
    - pygments: 2.15.1
    - pyjwt: 2.7.0
    - pympler: 1.0.1
    - pynacl: 1.5.0
    - pyopenssl: 23.2.0
    - pyparsing: 3.1.0
    - pyquaternion: 0.9.9
    - pyrsistent: 0.19.3
    - python-dateutil: 2.8.2
    - python-editor: 1.0.4
    - python-multipart: 0.0.6
    - pytorch-lightning: 2.0.4
    - pytz: 2023.3
    - pytz-deprecation-shim: 0.1.0.post0
    - pywavelets: 1.4.1
    - pyyaml: 6.0
    - pyzmq: 25.1.0
    - qudida: 0.0.4
    - readchar: 4.0.5
    - requests: 2.28.1
    - requests-oauthlib: 1.3.1
    - retrying: 1.3.4
    - rich: 13.4.2
    - rsa: 4.9
    - s3transfer: 0.6.1
    - scikit-image: 0.21.0
    - scikit-learn: 1.2.2
    - scipy: 1.11.0
    - service-identity: 23.1.0
    - setuptools: 65.5.0
    - six: 1.16.0
    - smmap: 5.0.0
    - sniffio: 1.3.0
    - soupsieve: 2.4.1
    - spake2: 0.8
    - stack-data: 0.6.2
    - starlette: 0.27.0
    - starsessions: 1.3.0
    - streamlit: 1.23.1
    - sympy: 1.11.1
    - tenacity: 8.2.2
    - tensorboard: 2.13.0
    - tensorboard-data-server: 0.7.1
    - threadpoolctl: 3.1.0
    - tifffile: 2023.4.12
    - timeout-decorator: 0.5.0
    - toml: 0.10.2
    - toolz: 0.12.0
    - torch: 2.0.1+cu118
    - torchaudio: 2.0.2+cu118
    - torchmetrics: 0.11.4
    - torchvision: 0.15.2+cu118
    - tornado: 6.3.2
    - tqdm: 4.65.0
    - traitlets: 5.9.0
    - triton: 2.0.0
    - twisted: 22.10.0
    - txaio: 23.1.1
    - txtorcon: 23.5.0
    - typing-extensions: 4.4.0
    - tzdata: 2023.3
    - tzlocal: 4.3.1
    - urllib3: 1.26.13
    - uvicorn: 0.22.0
    - validators: 0.20.0
    - watchdog: 3.0.0
    - wcwidth: 0.2.6
    - websocket-client: 1.6.1
    - websockets: 11.0.3
    - werkzeug: 2.2.3
    - wheel: 0.40.0
    - widgetsnbextension: 4.0.7
    - wrapt: 1.15.0
    - yarl: 1.9.2
    - zipp: 3.15.0
    - zope.interface: 6.0
* System:
    - OS: Linux
    - architecture:
        - 64bit
        - ELF
    - processor: x86_64
    - python: 3.10.9
    - release: 5.15.0-73-generic
    - version: #80-Ubuntu SMP Mon May 15 15:18:26 UTC 2023

More info

No response

cc @carmocca @Blaizzy

TommasoBendinelli commented 1 year ago

The same buggy behaviour occurs for self.log in the training step.

TommasoBendinelli commented 1 year ago

Any news about this?