Prepare data on the GPU device in DataModule/Dataset #18124

stonelazy commented 11 months ago

Bug description

tl;dr: Create a tensor, move it to the GPU during data loading, do some operation on it, and a "Cannot re-initialize CUDA" error is thrown.

I want to do some CUDA operations in my DataLoader. I need to execute these CUDA operations to prepare batched data, i.e. even before Lightning invokes transfer_batch_to_device. But I am facing an error when I invoke torch.cuda.*:

  RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method

Note: my training config is strategy=DDP, with a DataLoader using num_workers > 1.

I have referred to various threads that discuss this specific problem: Ref1, Ref2, Ref3. The general advice is that I should avoid invoking torch.cuda.* functions before being invoked. I have been following this by keeping the usage within my …, but unfortunately I am still facing this issue. I even tried running this operation conditionally, only after a few iterations, just to make sure that my own code doesn't initialize CUDA first, but the issue is the same.
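For context, the 'spawn' start method that the error message asks for can be selected per DataLoader through its `multiprocessing_context` argument, so that each worker starts as a fresh interpreter instead of a fork of the rank process that has already initialized CUDA. A minimal sketch, assuming a hypothetical `ToyDataset` in place of the real one. Each spawned worker would then create its own CUDA context (extra GPU memory per worker), the dataset must be picklable, and PyTorch still discourages returning CUDA tensors from workers.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical stand-in for the dataset in the report."""

    def __init__(self, size: int, length: int):
        self.data = torch.randn(length, size)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> torch.Tensor:
        return self.data[index]


def make_spawn_loader(dataset: Dataset, batch_size: int = 4, num_workers: int = 2) -> DataLoader:
    # 'spawn' starts each worker as a new Python process instead of forking the
    # parent, so a worker may initialize CUDA even though the parent already has.
    # multiprocessing_context is only accepted when num_workers > 0.
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        multiprocessing_context="spawn",
        # Spawned workers are slow to start; keep them alive across epochs.
        persistent_workers=True,
    )
```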

What version are you seeing the problem on?


How to reproduce the bug

import torch
import torch.distributed as dist
from torch.utils.data import Dataset


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def _do_some_augmentation(self):
        # I want to do some heavy lifting on the input being loaded, on the GPU device.
        print("rank ", dist.get_rank())
        return torch.randn(1, 2).cuda()  # This line throws the error.

    def __getitem__(self, index):
        self._do_some_augmentation()  # If this line is commented out, it executes successfully.
        return self.data[index]

    def __len__(self):
        return self.len

NOTE: I am not running this in a Jupyter notebook, but as a Python module.
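One possible workaround for the repro, sketched under assumptions: choose the augmentation device at call time with `torch.utils.data.get_worker_info()`, which returns `None` in the main process. With `num_workers=0` the dataset code then runs in the rank process, which already owns a CUDA context, so the GPU path works; inside forked workers it falls back to the CPU instead of raising. `DeviceAwareDataset`, `augmentation_device`, and the doubling "augmentation" are hypothetical placeholders.

```python
import torch
from torch.utils.data import Dataset, get_worker_info


def augmentation_device() -> torch.device:
    # get_worker_info() is None in the main process and a WorkerInfo object
    # inside a DataLoader worker. Forked workers cannot use CUDA once the parent
    # has initialized it, so they stay on the CPU.
    if get_worker_info() is None and torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


class DeviceAwareDataset(Dataset):
    def __init__(self, size: int, length: int):
        self.data = torch.randn(length, size)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> torch.Tensor:
        x = self.data[index].to(augmentation_device())
        x = x * 2.0  # hypothetical augmentation
        # Return a CPU tensor so collation and pin_memory behave the same on both paths.
        return x.cpu()
```

The trade-off is that `num_workers=0` gives up parallel loading, so this only pays off if the GPU augmentation is much faster than its CPU equivalent.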


Current environment
* CUDA:
    - GPU: 8× NVIDIA RTX A6000
    - available: True
    - version: 11.8
* Lightning:
    - lightning-utilities: 0.9.0
    - pytorch-lightning: 2.0.0
    - torch: 2.0.0+cu118
    - torch-poly-lr-decay: 0.0.1
    - torch-stoi: 0.1.2
    - torchaudio: 2.0.0+cu118
    - torchmetrics: 0.11.4
* Packages: - absl-py: 1.4.0 - aiohttp: 3.8.4 - aiohttp-retry: 2.8.3 - aioice: 0.7.6 - aiortc: 1.4.0 - aiosignal: 1.3.1 - alembic: 1.11.1 - altair: 5.0.1 - amqp: 5.1.1 - antlr4-python3-runtime: 4.9.3 - anyio: 3.7.1 - appdirs: 1.4.4 - argbind: 0.3.7 - argon2-cffi: 21.3.0 - argon2-cffi-bindings: 21.2.0 - arrow: 1.2.3 - asgiref: 3.7.2 - astroid: 2.15.6 - asttokens: 2.2.1 - async-timeout: 4.0.2 - asyncssh: 2.13.1 - atpublic: 3.1.1 - attrs: 23.1.0 - audio-events-classification: 0.1.2 - audioread: 3.0.0 - auraloss: 0.4.0 - autopage: 0.5.1 - av: 10.0.0 - backcall: 0.2.0 - beautifulsoup4: 4.12.2 - billiard: - black: 23.7.0 - bleach: 6.0.0 - blinker: 1.6.2 - blis: 0.7.9 - boto3: 1.28.2 - botocore: 1.31.2 - braceexpand: 0.1.7 - bravado: 11.0.3 - bravado-core: 5.17.1 - cachetools: 5.3.1 - catalogue: 2.0.8 - cdifflib: 1.2.6 - celery: 5.2.7 - certifi: 2023.5.7 - cffi: 1.15.1 - cfgv: 3.3.1 - charset-normalizer: 3.2.0 - clang-format: 15.0.7 - click: 8.1.4 - click-didyoumean: 0.3.0 - click-plugins: 1.1.1 - click-repl: 0.2.0 - cliff: 4.3.0 - cmaes: 0.9.1 - cmake: 3.26.4 - cmd2: 2.4.3 - collection: 0.1.6 - colorama: 0.4.6 - coloredlogs: 15.0.1 - colorlog: 6.7.0 - comm: 0.1.3 - confection: 0.1.0 - configobj: 5.0.8 - contourpy: 1.1.0 - coverage: 7.2.7 - cryptography: 40.0.1 - ctcdecode: 1.0.5 - cycler: 0.11.0 - cymem: 2.0.7 - cython: 0.29.36 - cytoolz: 0.12.1 - debugpy: 1.6.7 - decorator: 5.1.1 - defusedxml: 0.7.1 - deprecated: 1.2.14 - descript-audiotools: 0.7.1 - dictdiffer: 0.9.0 - dill: 0.3.6 - dirhash: 0.2.1 - diskcache: 5.6.1 - distlib: 0.3.6 - distro: 1.8.0 - dnspython: 2.3.0
- docstring-parser: 0.15 - dpath: 2.1.5 - dulwich: 0.21.5 - dvc: 2.41.1 - dvc-data: 0.29.0 - dvc-http: 2.30.2 - dvc-objects: 0.14.1 - dvc-render: 0.0.17 - dvc-stratus: 0.3.1 - dvc-studio-client: 0.9.2 - dvc-task: 0.1.9 - dvclive: 2.0.2 - edit-distance: 1.0.6 - editdistance: 0.6.2 - entrypoints: 0.4 - exceptiongroup: 1.1.2 - executing: 1.2.0 - fastapi: 0.70.1 - fastjsonschema: 2.17.1 - ffmpeg-python: 0.2.0 - ffmpy: 0.3.0 - filelock: 3.12.2 - fire: 0.5.0 - flatbuffers: 23.5.26 - flatten-dict: 0.4.2 - flufl.lock: 7.1.1 - fonttools: 4.40.0 - fqdn: 1.5.1 - frozendict: 2.3.8 - frozenlist: 1.3.3 - fsspec: 2023.6.0 - ftfy: 5.9 - funcy: 2.0 - future: 0.18.3 - gitdb: 4.0.10 - gitpython: 3.1.32 - google-auth: 2.22.0 - google-auth-oauthlib: 1.0.0 - google-crc32c: 1.5.0 - grandalf: 0.6 - greenlet: 2.0.2 - grpcio: 1.56.0 - gunicorn: 20.1.0 - h11: 0.14.0 - huggingface-hub: 0.16.4 - humanfriendly: 10.0 - hydra-core: 1.3.2 - identify: 2.5.24 - idna: 3.4 - importlib-metadata: 6.8.0 - importlib-resources: 6.0.0 - infinibatch: 0.1.0 - inflect: 7.0.0 - iniconfig: 2.0.0 - ipykernel: 6.23.2 - ipython: 8.14.0 - ipython-genutils: 0.2.0 - isoduration: 20.11.0 - isort: 5.12.0 - iterative-telemetry: 0.0.6 - jedi: 0.18.2 - jellyfish: 1.0.0 - jinja2: 3.1.2 - jmespath: 1.0.1 - joblib: 1.3.1 - jsonformatter: 0.3.2 - jsonpointer: 2.4 - jsonref: 1.1.0 - jsonschema: 4.18.1 - jsonschema-specifications: 2023.6.1 - julius: 0.2.7 - jupyter-client: 8.2.0 - jupyter-contrib-core: 0.4.2 - jupyter-contrib-nbextensions: 0.7.0 - jupyter-core: 5.3.0 - jupyter-events: 0.6.3 - jupyter-highlight-selected-word: 0.2.0 - jupyter-nbextensions-configurator: 0.6.3 - jupyter-server: 2.6.0 - jupyter-server-terminals: 0.4.4 - jupyterlab-pygments: 0.2.2 - kiwisolver: 1.4.4 - kombu: 5.2.4 - langcodes: 3.3.0 - lazy-object-proxy: 1.9.0 - librosa: 0.9.2 - lightning-utilities: 0.9.0 - limits: 3.5.0 - lit: 16.0.6 - llvmlite: 0.40.1 - lxml: 4.9.2 - mako: 1.2.4 - markdown: 3.4.3 - markdown-it-py: 3.0.0 - markdown2: 2.4.9 - 
markupsafe: 2.1.3 - matplotlib: 3.7.2 - matplotlib-inline: 0.1.6 - mccabe: 0.7.0 - mdurl: 0.1.2 - mistune: 2.0.5 - monotonic: 1.6 - more-itertools: 9.1.0 - mpmath: 1.3.0 - msgpack: 1.0.5 - multidict: 6.0.4 - murmurhash: 1.0.9 - mypy: 0.961 - mypy-extensions: 1.0.0 - nanotime: 0.5.2 - nbclassic: 1.0.0 - nbclient: 0.8.0 - nbconvert: 7.4.0 - nbformat: 5.9.0 - nemo-text-processing: 0.1.8rc0 - nemo-toolkit: 1.9.0 - neptune-client: 0.16.18 - nest-asyncio: 1.5.6 - netifaces: 0.11.0 - networkx: 3.1 - nodeenv: 1.8.0 - notebook: 6.5.4 - notebook-shim: 0.2.3 - numba: 0.57.1 - numpy: 1.23.5 - nvidia-ml-py: 11.525.131 - nvitop: 1.1.2 - oauthlib: 3.2.2 - omegaconf: 2.3.0 - onnx: 1.12.0 - onnxruntime-gpu: 1.13.1 - openai-whisper: 20230314 - optuna: 2.10.1 - overrides: 7.3.1 - packaging: 23.1 - pandas: 2.0.3 - pandocfilters: 1.5.0 - parso: 0.8.3 - pathspec: 0.9.0 - pathy: 0.10.2 - pbr: 5.11.1 - pedalboard: 0.7.4 - pesq: 0.0.4 - pexpect: 4.8.0 - pickleshare: 0.7.5 - pillow: 10.0.0 - pip: 22.3.1 - platformdirs: 3.8.1 - pluggy: 1.2.0 - pooch: 1.7.0 - pre-commit: 3.3.3 - preshed: 3.0.8 - prettytable: 3.8.0 - prometheus-client: 0.17.0 - prompt-toolkit: 3.0.39 - protobuf: 3.20.1 - psutil: 5.9.5 - ptflops: 0.7 - ptyprocess: 0.7.0 - pure-eval: 0.2.2 - pyarrow: 12.0.1 - pyasn1: 0.5.0 - pyasn1-modules: 0.3.0 - pybind11: 2.10.4 - pycparser: 2.21 - pycryptodome: 3.18.0 - pydantic: 1.10.11 - pydeck: 0.8.1b0 - pydot: 1.4.2 - pyee: 9.0.4 - pygit2: 1.12.1 - pygments: 2.15.1 - pygtrie: 2.5.0 - pyjwt: 2.7.0 - pylibsrtp: 0.8.0 - pylint: 2.17.4 - pylint-protobuf: 0.20.2 - pyloudnorm: 0.1.1 - pympler: 1.0.1 - pynini: 2.1.5 - pyopenssl: 23.1.1 - pyparsing: 3.0.9 - pyperclip: 1.8.2 - pyphen: 0.14.0 - pyroomacoustics: 0.5.0 - pyrsistent: 0.19.3 - pystoi: 0.3.3 - pystratus: 0.2.2 - pyte: 0.8.0 - pytest: 7.4.0 - pytest-cov: 4.1.0 - pytest-mock: 3.11.1 - python-dateutil: 2.8.2 - python-dotenv: 1.0.0 - python-json-logger: 2.0.7 - python-magic: 0.4.27 - python-multipart: 0.0.4 - pytorch-lightning: 2.0.0 - 
pytz: 2023.3 - pytz-deprecation-shim: 0.1.0.post0 - pyyaml: 6.0 - pyzmq: 25.1.0 - randomname: 0.2.1 - referencing: 0.29.1 - regex: 2023.6.3 - registrable: 0.0.4 - requests: 2.31.0 - requests-oauthlib: 1.3.1 - resampy: 0.4.2 - rfc3339-validator: 0.1.4 - rfc3986-validator: 0.1.1 - rfc3987: 1.3.8 - rich: 13.4.2 - rpds-py: 0.8.10 - rsa: 4.9 - ruamel.yaml: 0.17.26 - ruamel.yaml.clib: 0.2.7 - s3transfer: 0.6.1 - sacremoses: 0.0.53 - scantree: 0.0.1 - scikit-learn: 1.3.0 - scipy: 1.9.3 - scmrepo: 0.1.5 - seaborn: 0.12.2 - semver: 3.0.1 - send2trash: 1.8.2 - setuptools: 65.6.3 - setuptools-scm: 7.1.0 - shortuuid: 1.0.11 - shtab: 1.6.1 - simplejson: 3.19.1 - six: 1.16.0 - slowapi: 0.1.8 - smart-open: 6.3.0 - smmap: 5.0.0 - sniffio: 1.3.0 - soundfile: 0.10.3.post1 - soupsieve: 2.4.1 - sox: 1.4.1 - spacy-legacy: 3.0.12 - spacy-loggers: 1.0.4 - sqlalchemy: 2.0.18 - srsly: 2.4.6 - srt: 3.5.3 - stack-data: 0.6.2 - starlette: 0.16.0 - stevedore: 5.1.0 - streamlit: 1.19.0 - streamlit-webrtc: 0.44.7 - swagger-spec-validator: 3.0.3 - sympy: 1.12 - tabulate: 0.9.0 - taskipy: 1.11.0 - tensorboard: 2.13.0 - tensorboard-data-server: 0.7.1 - termcolor: 2.3.0 - terminado: 0.17.1 - thefuck: 3.32 - thinc: 8.1.10 - thop: 0.1.1.post2209072238 - threadpoolctl: 3.1.0 - tiktoken: 0.3.3 - tinycss2: 1.2.1 - tokenize-rt: 5.1.0 - tokenizers: 0.10.3 - toml: 0.10.2 - tomli: 2.0.1 - tomlkit: 0.11.8 - toolz: 0.11.2 - torch: 2.0.0+cu118 - torch-poly-lr-decay: 0.0.1 - torch-stoi: 0.1.2 - torchaudio: 2.0.0+cu118 - torchmetrics: 0.11.4 - tornado: 6.3.2 - tqdm: 4.65.0 - traitlets: 5.9.0 - transformers: 4.16.2 - triton: 2.0.0 - typer: 0.9.0 - typing-extensions: 4.7.1 - tzdata: 2023.3 - tzlocal: 5.0.1 - unidecode: 1.3.6 - uri-template: 1.3.0 - urllib3: 1.26.15 - uvicorn: 0.15.0 - validators: 0.20.0 - vine: 5.0.0 - virtualenv: 20.23.1 - voluptuous: 0.13.1 - wasabi: 1.1.2 - watchdog: 3.0.0 - wcwidth: 0.2.6 - webcolors: 1.13 - webdataset: 0.1.103 - webencodings: 0.5.1 - webrtcvad: 2.0.10 - websocket-client: 1.6.1 
- werkzeug: 2.3.6 - wget: 3.2 - wheel: 0.37.1 - wrapt: 1.15.0 - yarl: 1.9.2 - youtube-dl: 2021.2.22 - zc.lockfile: 3.0.post1 - zdnsmos: 0.1.4 - zipp: 3.16.0 - zlogs: 0.2.4 - zspeech: 0.2.2 - zspeech-analytics: 2.0 - zspeech-api: 0.2.2 - zspeech-audio: 0.2.2 - zspeech-audio-common: 0.2.2 - zspeech-audio-transforms: 0.2.2 - zspeech-common: 0.2.2 - zspeech-dev: 0.2.2 - zspeech-inference: 0.2.2 - zspeech-inference-client: 0.2.2 - zspeech-models: 0.2.2 - zspeech-nn: 0.2.2 - zwaf-isc: 0.2.1
* System:
    - OS: Linux
    - architecture:
        - 64bit
        - ELF
    - processor: x86_64
    - python: 3.9.16
    - release: 5.15.0-76-generic
    - version: #83~20.04.1-Ubuntu SMP Wed Jun 21 20:23:31 UTC 2023

More info

Why can't I do this after Lightning transfers the data to the device (on_after_batch_transfer)?

What do I want? If this is known behaviour, I would like to know the recommended way to do these data operations on a GPU within my DataLoader/Dataset, before Lightning's automatic device transfer.

cc @justusschock @awaelchli

awaelchli commented 11 months ago

@stonelazy The issue here is that by the time the Trainer calls the dataloader iterator, the process has already initialized CUDA. The worker processes are forked from it, so they won't be able to initialize CUDA for your operations. PyTorch does not recommend performing CUDA operations in dataloader workers. There is a warning on this page:

It is generally not recommended to return CUDA tensors in multi-process loading because of many subtleties in using CUDA and sharing CUDA tensors in multiprocessing (see CUDA in multiprocessing). Instead, we recommend using automatic memory pinning (i.e., setting pin_memory=True), which enables fast data transfer to CUDA-enabled GPUs.
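Following that recommendation, the workers keep to CPU work and return CPU tensors, and pinning lets the later host-to-device copy run asynchronously. A sketch, assuming a hypothetical `cpu_augment` (per-sample standardization) in place of the real augmentation; `pin_memory` is enabled only when CUDA is available so the same code also runs on a CPU-only machine.

```python
import torch
from torch.utils.data import DataLoader, Dataset


def cpu_augment(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical CPU-only augmentation: standardize each sample.
    # Nothing here touches torch.cuda, so it is safe inside forked workers.
    return (x - x.mean()) / (x.std() + 1e-6)


class AugmentedDataset(Dataset):
    def __init__(self, size: int, length: int):
        self.data = torch.randn(length, size)

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> torch.Tensor:
        return cpu_augment(self.data[index])


def make_loader(dataset: Dataset, batch_size: int = 4, num_workers: int = 2) -> DataLoader:
    return DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        # Page-locked host memory lets .to("cuda", non_blocking=True) overlap with compute.
        pin_memory=torch.cuda.is_available(),
    )
```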

I recommend that you do the processing on the batch returned from the loader instead. For example, you could implement the transfer_batch_to_device hook in the LightningModule and perform your processing there:

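from pytorch_lightning import LightningModule


class MyModel(LightningModule):
    def transfer_batch_to_device(self, batch, device, dataloader_idx):
        # process the CPU batch here
        # move it to the GPU with Lightning's default implementation
        batch = super().transfer_batch_to_device(batch, device, dataloader_idx)
        # transform the GPU batch here, etc.
        return batch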
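The device-side part of that hook can live in a plain function that does not depend on Lightning, which also makes it testable on CPU. A minimal sketch, assuming the batch is a single float tensor and using additive Gaussian noise as a hypothetical stand-in for the real GPU augmentation; inside the hook one would return `gpu_augment(batch, device)`.

```python
import torch


def gpu_augment(batch: torch.Tensor, device: torch.device, noise_std: float = 0.01) -> torch.Tensor:
    # Move the (ideally pinned) CPU batch to the target device. non_blocking only
    # makes the copy asynchronous for pinned memory and a CUDA target.
    batch = batch.to(device, non_blocking=True)
    # Hypothetical augmentation: the noise is created directly on the batch's device,
    # so no extra host-to-device copy is needed.
    noise = torch.randn_like(batch) * noise_std
    return batch + noise
```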

Could this be a viable option for you?

stonelazy commented 11 months ago

@awaelchli Thanks for getting back to me. I had already thought about this, but because of some practical difficulties I didn't want to go that route. Now that you have made it clear that CUDA operations aren't possible any other way, could you please help me with these queries? For query 1, the hook would look like this:

def transfer_batch_to_device(self, batch, device, dataloader_idx):
    # 1.a) Move to GPU and transform
    # 1.b) Move to CPU and transform
    # 1.c) Move to GPU
    return processed_batch
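A Lightning-independent sketch of the 1.a/1.b/1.c ordering, assuming a float tensor batch; the FFT magnitude and the clamp are hypothetical placeholders for the real GPU and CPU steps. With a CUDA device, the copy back to the CPU in step 1.b waits for all queued GPU work, so the round trip stalls the training step for its duration.

```python
import torch


def gpu_then_cpu_then_gpu(batch: torch.Tensor, device: torch.device) -> torch.Tensor:
    # 1.a) Move to the device and run the (hypothetical) heavy batched transform there.
    on_device = batch.to(device, non_blocking=True)
    spectrum = torch.fft.rfft(on_device, dim=-1).abs()

    # 1.b) Bring the result back to the CPU for a step that only exists on CPU.
    # A device-to-host copy is synchronous by default.
    on_cpu = spectrum.cpu()
    cleaned = torch.clamp(on_cpu, max=10.0)  # placeholder for the CPU-only step

    # 1.c) Send the final batch to the device the Trainer asked for.
    return cleaned.to(device, non_blocking=True)
```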

1.) Is it OK to have some long-running logic, and also to modify the data, in the transfer_batch_to_device method? I'm asking because of this note in the docs:

This hook should only transfer the data and not modify it, nor should it move the data to any other device than the one passed in as argument (unless you know what you are doing).

Note: the problem I see with my overall sequence of operations is that I have to run a GPU task first and then a CPU task (weird, yes).

2.) Suppose I use transfer_batch_to_device the way you are suggesting. The CPU task I run requires a single sample as input, not batched data, so the sequence length may vary between samples after the transformation. This means that when stitching the outputs back together, I need to invoke collate_fn manually myself. Is this fine?
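Re-batching by hand amounts to doing what a padding collate_fn would otherwise do. A sketch, assuming 1-D samples and a hypothetical `trim_silence` per-sample CPU transform whose output length varies; it pads with `torch.nn.utils.rnn.pad_sequence` and also returns the true lengths so the model can mask the padding.

```python
from typing import List, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence


def trim_silence(sample: torch.Tensor, threshold: float = 0.1) -> torch.Tensor:
    # Hypothetical per-sample CPU transform with a variable output length:
    # drop every value whose magnitude is below the threshold.
    return sample[sample.abs() >= threshold]


def recollate(samples: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
    # Pad to the longest sample -> shape (batch, max_len), plus the true lengths.
    lengths = torch.tensor([s.numel() for s in samples], dtype=torch.long)
    padded = pad_sequence(samples, batch_first=True, padding_value=0.0)
    return padded, lengths


def per_sample_cpu_step(batch: torch.Tensor, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
    # Split the (batch, time) tensor into samples on the CPU, transform each one,
    # re-collate, then move both tensors to the target device.
    samples = [trim_silence(s) for s in batch.cpu().unbind(0)]
    padded, lengths = recollate(samples)
    return padded.to(device), lengths.to(device)
```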

JiayiShou commented 2 months ago

I have the exact same problem. I wonder if you have found a way to solve it.