Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Downloading artifacts with WandbLogger in the DDP case fails on non-zero-rank processes #19003

Open galbraun opened 7 months ago

galbraun commented 7 months ago

Bug description

When using the WandbLogger's download_artifact function in a DDP setup with multiple GPUs, the artifact is only downloaded on the rank-0 process. The function is wrapped with the rank_zero_only decorator, so on every other rank the method body is never executed and the call returns None.
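For reference, a minimal sketch of that behaviour, assuming only the documented semantics of the rank_zero_only decorator (the function name here is illustrative):

from lightning.pytorch.utilities import rank_zero_only

@rank_zero_only
def fetch_path():
    return "/path/to/artifact"

# Under DDP, rank 0 prints the real path; every other rank prints None because
# the decorated body is skipped and the decorator's default return value is used.
print(fetch_path())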

What version are you seeing the problem on?

v2.1

How to reproduce the bug

from lightning.pytorch.loggers import WandbLogger
import lightning as L
import os
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST
from torch import utils

logger = WandbLogger()

artifact_path = logger.download_artifact(<artifact_name>)

trainer = L.Trainer(logger=logger, accelerator='gpu', devices=[0,1,2])

dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_dataloader = utils.data.DataLoader(dataset)

model = MyLightningModule()
trainer.fit(model, train_dataloader)

# On the rank-0 process this prints the real path; on every other process it prints None
print(artifact_path)

Error messages and logs

# Error messages and logs here please

Environment

Current environment

* CUDA:
  - GPU: NVIDIA A10G, NVIDIA A10G, NVIDIA A10G, NVIDIA A10G
  - available: True
  - version: 12.1
* Lightning:
  - lightning: 2.1.0
  - lightning-utilities: 0.9.0
  - pytorch-lightning: 2.1.0
  - torch: 2.1.0
  - torchmetrics: 1.2.0
  - torchvision: 0.16.0
* Packages: aiohttp: 3.8.6 - aiosignal: 1.3.1 - annotated-types: 0.6.0 - anyio: 4.0.0 - appdirs: 1.4.4 - argon2-cffi: 23.1.0 - argon2-cffi-bindings: 21.2.0 - arrow: 1.3.0 - asttokens: 2.4.1 - async-lru: 2.0.4 - async-timeout: 4.0.3 - attrs: 23.1.0 - babel: 2.13.1 - beautifulsoup4: 4.12.2 - bleach: 6.1.0 - bokeh: 3.3.0 - boto3: 1.28.79 - botocore: 1.31.79 - certifi: 2023.7.22 - cffi: 1.16.0 - charset-normalizer: 3.3.2 - click: 8.1.7 - colorcet: 3.0.1 - comm: 0.2.0 - compress-pickle: 2.1.0 - contourpy: 1.2.0 - cython: 0.29.36 - datasets: 2.14.6 - debugpy: 1.8.0 - decorator: 5.1.1 - defusedxml: 0.7.1 - dill: 0.3.7 - docker-pycreds: 0.4.0 - executing: 2.0.1 - fastjsonschema: 2.18.1 - filelock: 3.13.1 - fqdn: 1.5.1 - frozenlist: 1.4.0 - fsspec: 2023.10.0 - gitdb: 4.0.11 - gitpython: 3.1.40 - hdbscan: 0.8.33 - holoviews: 1.18.0 - huggingface-hub: 0.17.3 - idna: 3.4 - ipykernel: 6.26.0 - ipython: 8.17.2 - isoduration: 20.11.0 - jedi: 0.19.1 - jinja2: 3.1.2 - jmespath: 1.0.1 - joblib: 1.3.2 - json5: 0.9.14 - jsonpointer: 2.4 - jsonschema: 4.19.2 - jsonschema-specifications: 2023.7.1 - jupyter-client: 8.6.0 - jupyter-core: 5.5.0 - jupyter-events: 0.9.0 - jupyter-lsp: 2.2.0 - jupyter-server: 2.10.0 - jupyter-server-terminals: 0.4.4 - jupyterlab: 4.0.8 - jupyterlab-pygments: 0.2.2 - jupyterlab-server: 2.25.0 - lightning: 2.1.0 - lightning-utilities: 0.9.0 - linkify-it-py: 2.0.2 - llvmlite: 0.41.1 - lz4: 4.3.2 - markdown: 3.5.1 - markdown-it-py: 3.0.0 - markupsafe: 2.1.3 - matplotlib-inline: 0.1.6 - mdit-py-plugins: 0.4.0 - mdurl: 0.1.2 - mistune: 3.0.2 - mpmath: 1.3.0 - multidict: 6.0.4 - multiprocess: 0.70.15 - nbclient: 0.9.0 - nbconvert: 7.11.0 - nbformat: 5.9.2 - nest-asyncio: 1.5.8 - networkx: 3.2.1 - nltk: 3.8.1 - notebook: 7.0.6 - notebook-shim: 0.2.3 - numba: 0.58.1 - numpy: 1.26.1 - nvidia-cublas-cu12: 12.1.3.1 - nvidia-cuda-cupti-cu12: 12.1.105 - nvidia-cuda-nvrtc-cu12: 12.1.105 - nvidia-cuda-runtime-cu12: 12.1.105 - nvidia-cudnn-cu12: 8.9.2.26 - nvidia-cufft-cu12: 11.0.2.54 - nvidia-curand-cu12: 10.3.2.106 - nvidia-cusolver-cu12: 11.4.5.107 - nvidia-cusparse-cu12: 12.1.0.106 - nvidia-nccl-cu12: 2.18.1 - nvidia-nvjitlink-cu12: 12.3.52 - nvidia-nvtx-cu12: 12.1.105 - overrides: 7.4.0 - packaging: 23.2 - pandas: 2.1.2 - pandocfilters: 1.5.0 - panel: 1.3.1 - param: 2.0.0 - parso: 0.8.3 - pexpect: 4.8.0 - pillow: 10.1.0 - pip: 23.3 - platformdirs: 3.11.0 - prometheus-client: 0.18.0 - prompt-toolkit: 3.0.39 - protobuf: 4.25.0 - psutil: 5.9.6 - ptyprocess: 0.7.0 - pure-eval: 0.2.2 - pyarrow: 14.0.0 - pycparser: 2.21 - pyct: 0.5.0 - pydantic: 2.4.2 - pydantic-core: 2.10.1 - pydantic-numpy: 4.0.0 - pygments: 2.16.1 - pynndescent: 0.5.10 - python-dateutil: 2.8.2 - python-json-logger: 2.0.7 - pytorch-lightning: 2.1.0 - pytz: 2023.3.post1 - pyviz-comms: 3.0.0 - pyyaml: 6.0.1 - pyzmq: 25.1.1 - quiver: 0.0.2 - referencing: 0.30.2 - regex: 2023.10.3 - requests: 2.31.0 - requests-file: 1.5.1 - rfc3339-validator: 0.1.4 - rfc3986-validator: 0.1.1 - rpds-py: 0.12.0 - ruamel.yaml: 0.17.40 - ruamel.yaml.clib: 0.2.8 - s3transfer: 0.7.0 - safetensors: 0.4.0 - scikit-learn: 1.3.2 - scipy: 1.11.3 - semver: 3.0.2 - send2trash: 1.8.2 - sentence-transformers: 2.2.2 - sentencepiece: 0.1.99 - sentry-sdk: 1.34.0 - setproctitle: 1.3.3 - setuptools: 68.0.0 - six: 1.16.0 - smmap: 5.0.1 - sniffio: 1.3.0 - soupsieve: 2.5 - stack-data: 0.6.3 - sympy: 1.12 - tbb: 2021.10.0 - terminado: 0.17.1 - threadpoolctl: 3.2.0 - tinycss2: 1.2.1 - tldextract: 5.1.0 - tokenizers: 0.14.1 - torch: 2.1.0 - torchmetrics: 1.2.0 - torchvision: 0.16.0 - tornado: 6.3.3 - tqdm: 4.66.1 - traitlets: 5.13.0 - transformers: 4.35.0 - triton: 2.1.0 - types-python-dateutil: 2.8.19.14 - typing-extensions: 4.8.0 - tzdata: 2023.3 - uc-micro-py: 1.0.2 - umap-learn: 0.5.4 - uri-template: 1.3.0 - urllib3: 2.0.7 - wandb: 0.16.0 - wcwidth: 0.2.9 - webcolors: 1.13 - webencodings: 0.5.1 - websocket-client: 1.6.4 - wheel: 0.41.2 - xgboost: 2.0.1 - xxhash: 3.4.1 - xyzservices: 2023.10.1 - yarl: 1.9.2
* System:
  - OS: Linux
  - architecture: 64bit, ELF
  - processor: x86_64
  - python: 3.11.5
  - release: 6.2.0-1015-aws
  - version: #15~22.04.1-Ubuntu SMP Fri Oct 6 21:37:24 UTC 2023

More info

No response

cc @awaelchli @morganmcg1 @borisdayma @scottire @parambharat

nate-wandb commented 7 months ago

Hi @galbraun, as a workaround, would it work to set the directory where the Artifact should be stored ahead of time, so that all ranks know where the Artifact is?

For example:

art_path = "some/path"
if rank==0:
    download_artifact(<artifact_name>, save_dir=art_path)
print(art_path)
galbraun commented 7 months ago

Yes, thanks, that is the workaround I'm currently using. For now I'm on a single-node setup, but if I understand correctly it won't work in a multi-node case, will it?

In any case, I would expect some different handling here, or at least a warning or a note about it in the documentation. All the other logger methods with this decorator return None in both cases anyway, whereas this method is expected to return a string.

nzw0301 commented 1 month ago

How about using local_rank == 0 instead of the global rank, so that the artifact is downloaded once on every node?
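A rough sketch of that idea, assuming each rank runs the script (the default DDP launcher) and that the LOCAL_RANK environment variable is set by it. It goes through the wandb client directly, since the logger's download_artifact is gated on the global rank; the artifact name and path are placeholders:

import os
import wandb

art_path = "some/path"
# local rank 0 on each node downloads; the path is fixed up front so every rank knows it
if int(os.environ.get("LOCAL_RANK", "0")) == 0:
    wandb.Api().artifact("<artifact_name>").download(root=art_path)
print(art_path)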