Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Error loading a saved model to run inference (using ddp_notebook strategy) #19869

Open · carlos-havier opened this issue 5 months ago

carlos-havier commented 5 months ago

Bug description

Lightning throws an error when loading a saved model to run inference with the ddp_notebook strategy.

Specifically, it raises: "RuntimeError: Lightning can't create new processes if CUDA is already initialized. Did you manually call torch.cuda.* functions, have moved the model to the device, or allocated memory on the GPU any other way? Please remove any such calls, or change the selected strategy. You will have to restart the Python kernel."

I have included a minimal working example that reproduces the error (see the Colab link below).

What version are you seeing the problem on?

v2.1

How to reproduce the bug

https://colab.research.google.com/drive/1sxHACc95h-LcR48t3NYUfteLxaI4EmHB?usp=sharing
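
In outline, the failing pattern in the notebook is: load the checkpoint (which initializes CUDA in the main process) and then ask the Trainer to run with ddp_notebook. Below is a minimal sketch of that pattern, reusing the LT_timm_model class and def_log_chkpt checkpoint path that appear later in this thread; predict_loader is a hypothetical stand-in DataLoader, not part of the original report.

import pytorch_lightning as pl

# Loading a checkpoint saved during GPU training restores its
# tensors onto the GPU, which initializes CUDA in this process.
pl_model = LT_timm_model.load_from_checkpoint(def_log_chkpt)

# ddp_notebook must spawn its worker processes before CUDA is
# initialized, so this raises the RuntimeError quoted above.
trainer = pl.Trainer(accelerator="gpu", devices=1, strategy="ddp_notebook")
predictions = trainer.predict(pl_model, dataloaders=predict_loader)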

Error messages and logs

RuntimeError: Lightning can't create new processes if CUDA is already initialized. Did you manually call torch.cuda.* functions, have moved the model to the device, or allocated memory on the GPU any other way? Please remove any such calls, or change the selected strategy. You will have to restart the Python kernel

Environment

Current environment

* CUDA:
  - GPU: NVIDIA GeForce RTX 3080 Laptop GPU
  - available: True
  - version: 11.8
* Lightning:
  - lightning: 2.2.1
  - lightning-utilities: 0.10.1
  - pytorch-lightning: 2.1.3
  - torch: 2.2.1
  - torchaudio: 2.2.1
  - torchmetrics: 1.2.1
  - torchvision: 0.17.1
* Packages: aiohttp: 3.9.3 - aiosignal: 1.3.1 - alembic: 1.13.1 - anyio: 4.3.0 - argon2-cffi: 23.1.0 - argon2-cffi-bindings: 21.2.0 - arrow: 1.3.0 - asttokens: 2.4.1 - async-lru: 2.0.4 - attrs: 23.2.0 - babel: 2.14.0 - beautifulsoup4: 4.12.3 - bleach: 6.1.0 - brotli: 1.0.9 - cached-property: 1.5.2 - certifi: 2024.2.2 - cffi: 1.16.0 - charset-normalizer: 2.0.4 - colorama: 0.4.6 - colorlog: 6.8.2 - comm: 0.2.1 - contourpy: 1.2.0 - cycler: 0.12.1 - datasets: 2.18.0 - debugpy: 1.8.1 - decorator: 5.1.1 - defusedxml: 0.7.1 - dill: 0.3.8 - entrypoints: 0.4 - exceptiongroup: 1.2.0 - executing: 2.0.1 - fastjsonschema: 2.19.1 - filelock: 3.13.1 - fonttools: 4.49.0 - fqdn: 1.5.1 - frozenlist: 1.4.1 - fsspec: 2024.2.0 - ftfy: 6.1.3 - gmpy2: 2.1.2 - greenlet: 3.0.3 - h11: 0.14.0 - h2: 4.1.0 - hpack: 4.0.0 - httpcore: 1.0.4 - httpx: 0.27.0 - huggingface-hub: 0.21.3 - hyperframe: 6.0.1 - idna: 3.4 - importlib-metadata: 7.0.1 - importlib-resources: 6.1.2 - ipykernel: 6.29.3 - ipython: 8.22.2 - ipywidgets: 8.1.2 - isoduration: 20.11.0 - jedi: 0.19.1 - jinja2: 3.1.3 - joblib: 1.3.2 - json5: 0.9.22 - jsonpointer: 2.4 - jsonschema: 4.21.1 - jsonschema-specifications: 2023.12.1 - jupyter-client: 8.6.0 - jupyter-core: 5.7.1 - jupyter-events: 0.9.0 - jupyter-lsp: 2.2.4 - jupyter-server: 2.13.0 - jupyter-server-terminals: 0.5.2 - jupyterlab: 4.1.4 - jupyterlab-pygments: 0.3.0 - jupyterlab-server: 2.25.4 - jupyterlab-widgets: 3.0.10 - kiwisolver: 1.4.5 - lightning: 2.2.1 - lightning-utilities: 0.10.1 - mako: 1.3.2 - markupsafe: 2.1.3 - matplotlib: 3.8.3 - matplotlib-inline: 0.1.6 - mistune: 3.0.2 - mkl-fft: 1.3.8 - mkl-random: 1.2.4 - mkl-service: 2.4.0 - mpmath: 1.3.0 - multidict: 6.0.5 - multiprocess: 0.70.16 - munkres: 1.1.4 - nbclient: 0.8.0 - nbconvert: 7.16.2 - nbformat: 5.9.2 - nest-asyncio: 1.6.0 - networkx: 3.1 - notebook-shim: 0.2.4 - numpy: 1.26.4 - optuna: 3.5.0 - overrides: 7.7.0 - p-tqdm: 1.4.0 - packaging: 23.2 - pandas: 2.2.1 - pandocfilters: 1.5.0 - parso: 0.8.3 - pathos: 0.3.2 - patsy: 0.5.6 - pexpect: 4.9.0 - pickleshare: 0.7.5 - pillow: 10.2.0 - pip: 23.3.1 - pkgutil-resolve-name: 1.3.10 - platformdirs: 4.2.0 - ply: 3.11 - pox: 0.3.4 - ppft: 1.7.6.8 - prometheus-client: 0.20.0 - prompt-toolkit: 3.0.42 - psutil: 5.9.8 - ptyprocess: 0.7.0 - pure-eval: 0.2.2 - pyarrow: 15.0.0 - pyarrow-hotfix: 0.6 - pycparser: 2.21 - pygments: 2.17.2 - pyparsing: 3.1.1 - pyqt5: 5.15.10 - pyqt5-sip: 12.13.0 - pysocks: 1.7.1 - python-dateutil: 2.9.0 - python-json-logger: 2.0.7 - pytorch-lightning: 2.1.3 - pytz: 2024.1 - pyyaml: 6.0.1 - pyzmq: 25.1.2 - referencing: 0.33.0 - regex: 2023.12.25 - requests: 2.31.0 - rfc3339-validator: 0.1.4 - rfc3986-validator: 0.1.1 - rpds-py: 0.18.0 - safetensors: 0.4.2 - scikit-learn: 1.4.1.post1 - scipy: 1.12.0 - seaborn: 0.13.2 - send2trash: 1.8.2 - setuptools: 68.2.2 - sip: 6.7.12 - six: 1.16.0 - sniffio: 1.3.1 - soupsieve: 2.5 - sqlalchemy: 2.0.28 - stack-data: 0.6.2 - statsmodels: 0.14.1 - sympy: 1.12 - terminado: 0.18.0 - threadpoolctl: 3.3.0 - timm: 0.9.16 - tinycss2: 1.2.1 - tokenizers: 0.15.2 - tomli: 2.0.1 - torch: 2.2.1 - torchaudio: 2.2.1 - torchmetrics: 1.2.1 - torchvision: 0.17.1 - tornado: 6.4 - tqdm: 4.66.2 - traitlets: 5.14.1 - transformers: 4.38.2 - triton: 2.2.0 - types-python-dateutil: 2.8.19.20240311 - typing-extensions: 4.9.0 - typing-utils: 0.1.0 - tzdata: 2024.1 - uri-template: 1.3.0 - urllib3: 2.1.0 - wcwidth: 0.2.13 - webcolors: 1.13 - webencodings: 0.5.1 - websocket-client: 1.7.0 - wheel: 0.41.2 - widgetsnbextension: 4.0.10 - xxhash: 3.4.1 - yarl: 1.9.4 - zipp: 3.17.0
* System:
  - OS: Linux
  - architecture: 64bit, ELF
  - processor: x86_64
  - python: 3.11.8
  - release: 5.15.0-107-generic
  - version: #117~20.04.1-Ubuntu SMP Tue Apr 30 10:35:57 UTC 2024

More info

No response

LawJarp-A commented 4 months ago

@carlos-havier As mentioned in the PyTorch Lightning documentation, when using ddp_notebook, the downside is:

"GPU operations such as moving tensors to the GPU or calling torch.cuda functions before invoking Trainer.fit is not allowed."

This means that no CUDA tensors may exist before Trainer.fit is called. By default, when training on a GPU, PyTorch Lightning saves the checkpoint's state_dict with CUDA tensors, so loading from the checkpoint restores those tensors onto the GPU and initializes CUDA. You can verify this with a simple check:

print(torch.cuda.is_initialized())

This can be placed before and after calling:

pl_model = LT_timm_model.load_from_checkpoint(def_log_chkpt)

You'll observe that CUDA becomes initialized during the load_from_checkpoint call, and once CUDA is initialized in the main process, ddp_notebook can no longer spawn the fresh worker processes it needs.
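
Putting the two together, a small sketch of the check around the load call (reusing the names from your notebook):

import torch

print(torch.cuda.is_initialized())  # False: CUDA untouched so far

# Checkpoint tensors were saved on the GPU, so loading restores
# them there and initializes CUDA in the main process.
pl_model = LT_timm_model.load_from_checkpoint(def_log_chkpt)

print(torch.cuda.is_initialized())  # True: ddp_notebook can no longer spawn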

The Fix:

Pass map_location to load the checkpoint onto the CPU when calling load_from_checkpoint:

pl_model = LT_timm_model.load_from_checkpoint(def_log_chkpt, map_location=torch.device('cpu'))
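
This loads every tensor onto the CPU, so CUDA stays uninitialized in the main process and ddp_notebook can spawn its workers; Lightning's strategy then moves the model to the GPU inside each worker. A sketch of the full inference flow under the same assumptions (predict_loader is again a hypothetical stand-in DataLoader):

import torch
import pytorch_lightning as pl

# Load the checkpoint onto the CPU so CUDA is not initialized here.
pl_model = LT_timm_model.load_from_checkpoint(
    def_log_chkpt, map_location=torch.device("cpu")
)

# ddp_notebook can now spawn workers; each one moves the model to
# its own GPU before running inference.
trainer = pl.Trainer(accelerator="gpu", devices=1, strategy="ddp_notebook")
predictions = trainer.predict(pl_model, dataloaders=predict_loader)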

For your reference, I have added a few debugging lines based on your notebook here.