Open carlos-havier opened 5 months ago
@carlos-havier As mentioned in the PyTorch Lightning documentation, the downside of using ddp_notebook is:
"GPU operations such as moving tensors to the GPU or calling torch.cuda functions before invoking Trainer.fit is not allowed."
This means that no CUDA tensors may exist before Trainer.fit is called. By default, when training on a GPU, PyTorch Lightning saves the tensors in the checkpoint's state_dict as CUDA tensors, so loading from the checkpoint initializes CUDA. You can verify this with a simple check:
print(torch.cuda.is_initialized())
This can be placed before and after calling:
pl_model = LT_timm_model.load_from_checkpoint(def_log_chkpt)
You'll observe that CUDA is initialized by the call to load_from_checkpoint. Once CUDA is initialized in the notebook process, it cannot be re-initialized in the new processes that ddp_notebook creates.
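The before/after check can be wrapped in a small helper. This is only a sketch: `cuda_init_before_after` is a hypothetical name, not part of Lightning or PyTorch, and the only library call it relies on is `torch.cuda.is_initialized()`.

```python
import torch


def cuda_init_before_after(step):
    """Run `step()` and report whether CUDA was initialized before and after it.

    Hypothetical helper for debugging; returns a (before, after) tuple of booleans.
    """
    before = torch.cuda.is_initialized()
    step()
    after = torch.cuda.is_initialized()
    print(f"CUDA initialized before: {before}, after: {after}")
    return before, after


# Demo with a step that only creates a CPU tensor, so the state should not change.
before, after = cuda_init_before_after(lambda: torch.zeros(2, device="cpu"))
```

In the notebook from this issue, the step would be something like `lambda: LT_timm_model.load_from_checkpoint(def_log_chkpt)`. If `after` is True while `before` is False, that call is what initialized CUDA.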
The Fix:
Pass the CPU device as map_location when calling load_from_checkpoint:
pl_model = LT_timm_model.load_from_checkpoint(def_log_chkpt, map_location=torch.device('cpu'))
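To illustrate what map_location does, here is a minimal sketch that uses plain `torch.save` and `torch.load` as a stand-in for a Lightning checkpoint. The tiny `nn.Linear` model and the file name `weights.pt` are made up for the example. The weights are saved from the CPU so the snippet runs on any machine. For a checkpoint saved from a GPU, the same argument should keep the loaded tensors on the CPU.

```python
import os
import tempfile

import torch
from torch import nn

# Stand-in for a Lightning checkpoint: a plain state_dict saved with torch.save.
model = nn.Linear(4, 2)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "weights.pt")  # hypothetical file name
    torch.save(model.state_dict(), path)
    # map_location remaps every stored tensor to the given device while loading.
    state = torch.load(path, map_location=torch.device("cpu"))

# Collect the device types of all loaded tensors and check the CUDA state.
devices = {tensor.device.type for tensor in state.values()}
cuda_ready = torch.cuda.is_initialized()
print(f"tensor devices: {devices}, CUDA initialized: {cuda_ready}")
```

Because everything stays on the CPU, CUDA is left uninitialized and ddp_notebook can still create its processes afterwards.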
For your reference, I have added a few lines to debug based on your notebook here.
Bug description
Lightning throws an error when using a saved model to run inference, while using the ddp_notebook strategy.
In this case, it throws the error: "RuntimeError: Lightning can't create new processes if CUDA is already initialized. Did you manually call torch.cuda.* functions, have moved the model to the device, or allocated memory on the GPU any other way? Please remove any such calls, or change the selected strategy. You will have to restart the Python kernel."
I have provided a minimal working example to reproduce the error.
What version are you seeing the problem on?
v2.1
How to reproduce the bug
Error messages and logs
Environment
Current environment
* CUDA:
  - GPU:
    - NVIDIA GeForce RTX 3080 Laptop GPU
  - available: True
  - version: 11.8
* Lightning:
  - lightning: 2.2.1
  - lightning-utilities: 0.10.1
  - pytorch-lightning: 2.1.3
  - torch: 2.2.1
  - torchaudio: 2.2.1
  - torchmetrics: 1.2.1
  - torchvision: 0.17.1
* Packages:
  - aiohttp: 3.9.3
  - aiosignal: 1.3.1
  - alembic: 1.13.1
  - anyio: 4.3.0
  - argon2-cffi: 23.1.0
  - argon2-cffi-bindings: 21.2.0
  - arrow: 1.3.0
  - asttokens: 2.4.1
  - async-lru: 2.0.4
  - attrs: 23.2.0
  - babel: 2.14.0
  - beautifulsoup4: 4.12.3
  - bleach: 6.1.0
  - brotli: 1.0.9
  - cached-property: 1.5.2
  - certifi: 2024.2.2
  - cffi: 1.16.0
  - charset-normalizer: 2.0.4
  - colorama: 0.4.6
  - colorlog: 6.8.2
  - comm: 0.2.1
  - contourpy: 1.2.0
  - cycler: 0.12.1
  - datasets: 2.18.0
  - debugpy: 1.8.1
  - decorator: 5.1.1
  - defusedxml: 0.7.1
  - dill: 0.3.8
  - entrypoints: 0.4
  - exceptiongroup: 1.2.0
  - executing: 2.0.1
  - fastjsonschema: 2.19.1
  - filelock: 3.13.1
  - fonttools: 4.49.0
  - fqdn: 1.5.1
  - frozenlist: 1.4.1
  - fsspec: 2024.2.0
  - ftfy: 6.1.3
  - gmpy2: 2.1.2
  - greenlet: 3.0.3
  - h11: 0.14.0
  - h2: 4.1.0
  - hpack: 4.0.0
  - httpcore: 1.0.4
  - httpx: 0.27.0
  - huggingface-hub: 0.21.3
  - hyperframe: 6.0.1
  - idna: 3.4
  - importlib-metadata: 7.0.1
  - importlib-resources: 6.1.2
  - ipykernel: 6.29.3
  - ipython: 8.22.2
  - ipywidgets: 8.1.2
  - isoduration: 20.11.0
  - jedi: 0.19.1
  - jinja2: 3.1.3
  - joblib: 1.3.2
  - json5: 0.9.22
  - jsonpointer: 2.4
  - jsonschema: 4.21.1
  - jsonschema-specifications: 2023.12.1
  - jupyter-client: 8.6.0
  - jupyter-core: 5.7.1
  - jupyter-events: 0.9.0
  - jupyter-lsp: 2.2.4
  - jupyter-server: 2.13.0
  - jupyter-server-terminals: 0.5.2
  - jupyterlab: 4.1.4
  - jupyterlab-pygments: 0.3.0
  - jupyterlab-server: 2.25.4
  - jupyterlab-widgets: 3.0.10
  - kiwisolver: 1.4.5
  - lightning: 2.2.1
  - lightning-utilities: 0.10.1
  - mako: 1.3.2
  - markupsafe: 2.1.3
  - matplotlib: 3.8.3
  - matplotlib-inline: 0.1.6
  - mistune: 3.0.2
  - mkl-fft: 1.3.8
  - mkl-random: 1.2.4
  - mkl-service: 2.4.0
  - mpmath: 1.3.0
  - multidict: 6.0.5
  - multiprocess: 0.70.16
  - munkres: 1.1.4
  - nbclient: 0.8.0
  - nbconvert: 7.16.2
  - nbformat: 5.9.2
  - nest-asyncio: 1.6.0
  - networkx: 3.1
  - notebook-shim: 0.2.4
  - numpy: 1.26.4
  - optuna: 3.5.0
  - overrides: 7.7.0
  - p-tqdm: 1.4.0
  - packaging: 23.2
  - pandas: 2.2.1
  - pandocfilters: 1.5.0
  - parso: 0.8.3
  - pathos: 0.3.2
  - patsy: 0.5.6
  - pexpect: 4.9.0
  - pickleshare: 0.7.5
  - pillow: 10.2.0
  - pip: 23.3.1
  - pkgutil-resolve-name: 1.3.10
  - platformdirs: 4.2.0
  - ply: 3.11
  - pox: 0.3.4
  - ppft: 1.7.6.8
  - prometheus-client: 0.20.0
  - prompt-toolkit: 3.0.42
  - psutil: 5.9.8
  - ptyprocess: 0.7.0
  - pure-eval: 0.2.2
  - pyarrow: 15.0.0
  - pyarrow-hotfix: 0.6
  - pycparser: 2.21
  - pygments: 2.17.2
  - pyparsing: 3.1.1
  - pyqt5: 5.15.10
  - pyqt5-sip: 12.13.0
  - pysocks: 1.7.1
  - python-dateutil: 2.9.0
  - python-json-logger: 2.0.7
  - pytorch-lightning: 2.1.3
  - pytz: 2024.1
  - pyyaml: 6.0.1
  - pyzmq: 25.1.2
  - referencing: 0.33.0
  - regex: 2023.12.25
  - requests: 2.31.0
  - rfc3339-validator: 0.1.4
  - rfc3986-validator: 0.1.1
  - rpds-py: 0.18.0
  - safetensors: 0.4.2
  - scikit-learn: 1.4.1.post1
  - scipy: 1.12.0
  - seaborn: 0.13.2
  - send2trash: 1.8.2
  - setuptools: 68.2.2
  - sip: 6.7.12
  - six: 1.16.0
  - sniffio: 1.3.1
  - soupsieve: 2.5
  - sqlalchemy: 2.0.28
  - stack-data: 0.6.2
  - statsmodels: 0.14.1
  - sympy: 1.12
  - terminado: 0.18.0
  - threadpoolctl: 3.3.0
  - timm: 0.9.16
  - tinycss2: 1.2.1
  - tokenizers: 0.15.2
  - tomli: 2.0.1
  - torch: 2.2.1
  - torchaudio: 2.2.1
  - torchmetrics: 1.2.1
  - torchvision: 0.17.1
  - tornado: 6.4
  - tqdm: 4.66.2
  - traitlets: 5.14.1
  - transformers: 4.38.2
  - triton: 2.2.0
  - types-python-dateutil: 2.8.19.20240311
  - typing-extensions: 4.9.0
  - typing-utils: 0.1.0
  - tzdata: 2024.1
  - uri-template: 1.3.0
  - urllib3: 2.1.0
  - wcwidth: 0.2.13
  - webcolors: 1.13
  - webencodings: 0.5.1
  - websocket-client: 1.7.0
  - wheel: 0.41.2
  - widgetsnbextension: 4.0.10
  - xxhash: 3.4.1
  - yarl: 1.9.4
  - zipp: 3.17.0
* System:
  - OS: Linux
  - architecture:
    - 64bit
    - ELF
  - processor: x86_64
  - python: 3.11.8
  - release: 5.15.0-107-generic
  - version: #117~20.04.1-Ubuntu SMP Tue Apr 30 10:35:57 UTC 2024
More info
No response