Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0
28.23k stars 3.38k forks

Fit loop and validation loop teardown does not dereference passed dataloaders #18289

Open gov-ind opened 1 year ago

gov-ind commented 1 year ago

Bug description

When a dataloader is passed to pl.Trainer's fit, the trainer does not seem to clear all references to it. As a result, pickling the trainer will pickle the whole dataloader, which isn't ideal for large datasets.

Steps to reproduce

Run the MRE below.

Expected Behaviour

The trainer's pickled size should not grow after fitting.

Actual Behaviour

The trainer's pickled size grew by 1 MB (the size of the dataset) after fitting.

What version are you seeing the problem on?

master

How to reproduce the bug

from pickle import dumps

import lightning.pytorch as pl
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
import torch

model = BoringModel()

# ~1 MB dataset: 8192 samples of 32 float32 values each
dataset = RandomDataset(32, 8192)
data_size = len(dumps(dataset))

# Measure the trainer's pickled size before and after fitting
trainer = pl.Trainer(fast_dev_run=True)
prev_size = len(dumps(trainer))

trainer.fit(model, train_dataloaders=torch.utils.data.DataLoader(dataset))
new_size = len(dumps(trainer))

# The trainer's pickle grows by the dataset's size, showing that a
# reference to the passed dataloader survives fit()
size_diff_mb = (new_size - prev_size) // (1024 ** 2)
data_size_mb = data_size // (1024 ** 2)
print(size_diff_mb == data_size_mb)  # True while the bug is present
assert size_diff_mb == 0  # fails while the bug is present

Error messages and logs

No response

Environment

Current environment * CUDA: - GPU: None - available: False - version: None * Lightning: - lightning: 2.0.6 - lightning-cloud: 0.5.37 - lightning-utilities: 0.9.0 - pytorch-lightning: 2.0.6 - torch: 2.0.1+cpu - torch-model-archiver: 0.8.1 - torch-workflow-archiver: 0.2.9 - torchmetrics: 0.11.4 - torchserve: 0.8.1 - torchvision: 0.15.2a0 * Packages: - absl-py: 1.4.0 - anyio: 3.7.1 - arrow: 1.2.3 - astunparse: 1.6.3 - attrs: 23.1.0 - backoff: 2.2.1 - backports.cached-property: 1.0.2 - backports.functools-lru-cache: 1.6.5 - beautifulsoup4: 4.12.2 - blessed: 1.19.1 - blinker: 1.4 - brotlipy: 0.7.0 - build: 0.10.0 - cachecontrol: 0.12.11 - cachetools: 4.2.2 - captum: 0.6.0 - certifi: 2023.7.22 - cffi: 1.15.0 - charset-normalizer: 3.2.0 - cleo: 2.0.1 - click: 8.1.6 - colorama: 0.4.6 - contourpy: 1.1.0 - crashtest: 0.4.1 - croniter: 1.4.1 - cryptography: 41.0.2 - cycler: 0.11.0 - dateutils: 0.6.12 - deepdiff: 5.8.1 - distlib: 0.3.7 - docker: 6.1.3 - docstring-parser: 0.15 - dulwich: 0.21.3 - enum-compat: 0.0.3 - exceptiongroup: 1.1.2 - fastapi: 0.100.1 - filelock: 3.12.2 - flatbuffers: 2.0 - fonttools: 4.42.0 - fsspec: 2023.6.0 - gast: 0.4.0 - ghp-import: 2.1.0 - google-api-core: 2.11.1 - google-auth: 1.21.3 - google-auth-oauthlib: 0.5.2 - google-cloud-core: 2.3.3 - google-cloud-storage: 2.10.0 - google-crc32c: 1.5.0 - google-pasta: 0.2.0 - google-resumable-media: 2.5.0 - googleapis-common-protos: 1.59.1 - griffe: 0.32.3 - grpcio: 1.48.2 - h11: 0.14.0 - h5py: 3.7.0 - html5lib: 1.1 - idna: 3.4 - importlib-metadata: 6.8.0 - importlib-resources: 6.0.0 - iniconfig: 2.0.0 - inquirer: 3.1.3 - installer: 0.7.0 - itsdangerous: 2.1.2 - jaraco.classes: 3.3.0 - jeepney: 0.8.0 - jinja2: 3.1.2 - joblib: 1.3.1 - jsonschema: 4.17.3 - keras: 2.12.0 - keras-pickle-wrapper: 1.0.5 - keras-preprocessing: 1.1.2 - keyring: 23.13.1 - kfp: 2.0.1 - kfp-pipeline-spec: 0.2.2 - kfp-server-api: 2.0.0 - kiwisolver: 1.4.4 - kubernetes: 26.1.0 - lightning: 2.0.6 - lightning-cloud: 0.5.37 - 
lightning-utilities: 0.9.0 - lockfile: 0.12.2 - markdown: 3.4.1 - markdown-it-py: 3.0.0 - markupsafe: 2.1.1 - material: 0.1 - materialx: 0.0.0.dev1 - matplotlib: 3.7.2 - mdurl: 0.1.0 - mergedeep: 1.3.4 - minio: 7.1.15 - mkdocs: 1.5.2 - mkdocs-autorefs: 0.5.0 - mkdocs-material: 9.1.21 - mkdocs-material-extensions: 1.1.1 - mkdocstrings: 0.22.0 - mkdocstrings-python: 1.3.0 - mlframework: 2.2.1 - more-itertools: 10.0.0 - mpmath: 1.3.0 - msgpack: 1.0.3 - networkx: 3.1 - numpy: 1.22.3 - oauthlib: 3.2.2 - opt-einsum: 3.3.0 - ordered-set: 4.1.0 - packaging: 23.1 - pandas: 2.0.3 - pathspec: 0.11.2 - pexpect: 4.8.0 - pillow: 9.4.0 - pip: 23.2.1 - pipdeptree: 2.12.0 - pkginfo: 1.9.6 - pkgutil-resolve-name: 1.3.10 - platformdirs: 3.10.0 - pluggy: 1.2.0 - poetry: 1.5.1 - poetry-core: 1.6.1 - poetry-plugin-export: 1.4.0 - protobuf: 3.20.3 - psutil: 5.9.0 - psycopg2: 2.9.6 - ptyprocess: 0.7.0 - pyasn1: 0.4.8 - pyasn1-modules: 0.2.8 - pycparser: 2.21 - pydantic: 1.10.8 - pygments: 2.15.1 - pyjwt: 2.8.0 - pymdown-extensions: 10.1 - pympler: 1.0.1 - pyopenssl: 23.2.0 - pyparsing: 3.0.9 - pyproject-hooks: 1.0.0 - pyrsistent: 0.18.0 - pysocks: 1.7.1 - pytest: 7.4.0 - python-dateutil: 2.8.2 - python-editor: 1.0.4 - python-multipart: 0.0.6 - pytorch-lightning: 2.0.6 - pytz: 2023.3 - pyyaml: 6.0.1 - pyyaml-env-tag: 0.1 - rapidfuzz: 2.13.7 - readchar: 4.0.5.dev0 - regex: 2023.8.8 - requests: 2.31.0 - requests-oauthlib: 1.3.0 - requests-toolbelt: 0.9.1 - rich: 13.5.1 - rsa: 4.7.2 - scikit-learn: 1.3.0 - scipy: 1.11.1 - secretstorage: 3.3.3 - setuptools: 68.0.0 - shellingham: 1.5.1 - six: 1.16.0 - sniffio: 1.3.0 - soupsieve: 2.3.2.post1 - starlette: 0.27.0 - starsessions: 1.3.0 - sympy: 1.12 - tabulate: 0.9.0 - tenacity: 8.2.2 - tensorboard: 2.12.1 - tensorboard-data-server: 0.7.0 - tensorboard-plugin-wit: 1.8.1 - tensorflow: 2.12.0 - tensorflow-estimator: 2.12.0 - termcolor: 2.1.0 - threadpoolctl: 3.2.0 - tomli: 2.0.1 - tomlkit: 0.12.1 - torch: 2.0.1+cpu - torch-model-archiver: 0.8.1 - 
torch-workflow-archiver: 0.2.9 - torchmetrics: 0.11.4 - torchserve: 0.8.1 - torchvision: 0.15.2a0 - tqdm: 4.65.0 - traitlets: 5.9.0 - trove-classifiers: 2023.7.6 - typeguard: 4.0.0 - typing-extensions: 4.7.1 - tzdata: 2023.3 - urllib3: 1.26.15 - uvicorn: 0.23.2 - virtualenv: 20.24.2 - watchdog: 3.0.0 - wcwidth: 0.2.6 - webencodings: 0.5.1 - websocket-client: 1.6.1 - websockets: 10.4 - werkzeug: 2.2.3 - wheel: 0.38.4 - wrapt: 1.14.1 - zipp: 3.16.2 * System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.10.12 - release: 5.4.0-1106-gcp - version: #115~18.04.1-Ubuntu SMP Mon May 22 20:46:39 UTC 2023

More info

Possible Fix

Set trainer.fit_loop._data_source.instance = None and trainer.fit_loop._combined_loader = None in the loop's teardown. These two variables seem to be holding on to the dataloader's reference. I can create a PR that applies this fix.

Workaround

  1. Set the above variables to None manually just before pickling the trainer.
  2. Since this issue doesn't happen when the dataloader is provided by overriding train_dataloader, use that hook instead of passing the dataloader to fit.
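The effect of workaround (1) can be sketched with a toy stand-in for the trainer. The class names below are illustrative, not Lightning's; only the attribute names (_data_source, instance) mirror the ones named in the fix above:

```python
import pickle

class ToyDataSource:
    """Illustrative stand-in for the fit loop's _data_source attribute."""
    def __init__(self):
        self.instance = None  # the user-passed dataloader ends up referenced here

class ToyTrainer:
    """Illustrative stand-in for a trainer that holds on to the dataloader."""
    def __init__(self):
        self._data_source = ToyDataSource()

trainer = ToyTrainer()
trainer._data_source.instance = bytes(1024 * 1024)  # ~1 MB "dataset" retained after "fit"

before = len(pickle.dumps(trainer))
trainer._data_source.instance = None  # workaround: dereference just before pickling
after = len(pickle.dumps(trainer))
print(before - after)  # roughly 1 MB: the payload no longer rides along in the pickle
```

The same pattern applied to a real trainer is simply setting the two attributes listed under "Possible Fix" to None before calling pickle.dumps.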

cc @justusschock @awaelchli @borda

awaelchli commented 1 year ago

Thanks @gov-ind for investigating. I think it makes sense to do this. You are welcome to send a PR, thanks for the help!

carmocca commented 1 year ago

I would strongly suggest that you don't pickle the Trainer object. This is a bad idea as you are pickling the precise code and imports, which might break with future changes.

Since dereferencing the dataloader after training finishes is a breaking change (the tests in your PR showed this), my suggestion is that we don't change anything, since you are still able to dereference manually. This is also your suggestion in point (1): https://github.com/Lightning-AI/lightning/pull/18293/files#r1307301711

However, you could add a test that demonstrates your use case so that it's considered in future changes.