Open stonelazy opened 1 year ago
@stonelazy The issue here is that at the time the Trainer calls the dataloader iterator, the process has already initialized CUDA. The worker processes won't be able to init CUDA for your operations. Performing CUDA operations in dataloader workers is not recommended by PyTorch. There is a warning on this page:
It is generally not recommended to return CUDA tensors in multi-process loading because of many subtleties in using CUDA and sharing CUDA tensors in multiprocessing (see CUDA in multiprocessing). Instead, we recommend using automatic memory pinning (i.e., setting pin_memory=True), which enables fast data transfer to CUDA-enabled GPUs.
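For reference, the pinned-memory path from the docs looks roughly like this (a minimal sketch; the dataset here is just a placeholder):

import torch
from torch.utils.data import DataLoader, TensorDataset

# placeholder dataset; pin_memory=True makes the loader return page-locked CPU
# batches, which speeds up the host-to-GPU copy done below in the main process
dataset = TensorDataset(torch.randn(1000, 16))
loader = DataLoader(dataset, batch_size=32, num_workers=4, pin_memory=True)

for (batch,) in loader:
    batch = batch.to("cuda", non_blocking=True)  # transfer happens outside the workers
    ...  # GPU-side work goes here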
I recommend that you do the processing on the batch returned from the loader instead. For example, you could implement the transfer_batch_to_device hook in the LightningModule and perform your processing there:
from pytorch_lightning import LightningModule

class MyModel(LightningModule):
    ...

    def transfer_batch_to_device(self, batch, device, dataloader_idx):
        # process the CPU batch
        # move it to the GPU (e.g. batch.to(device) or super().transfer_batch_to_device(...))
        # transform the GPU batch
        # etc.
        return processed_batch
Could this be a viable option for you?
@awaelchli Thanks for getting back. I already thought about this, but because of some practical difficulties I didn't want to go ahead with it. Now that you have made it clear that it isn't possible to do CUDA operations any other way, could you please help me with these queries?
def transfer_batch_to_device(self, batch, device, dataloader_idx):
    # 1.a) Move to GPU and transform
    # 1.b) Move to CPU and transform
    # 1.c) Move to GPU
    return processed_batch
1.) Is it OK if we have some long-running logic and also modify the data in this transfer_batch_to_device method? The reason I am asking is this note in the docs:
This hook should only transfer the data and not modify it, nor should it move the data to any other device than the one passed in as argument (unless you know what you are doing).
Note - the problem I see with my overall sequence of operations is that I have to do a GPU task first and then a CPU task! (Weird, yes.)
2.) Suppose I use transfer_batch_to_device in the way you are suggesting - the CPU task I do requires that I send a single sample as input rather than batched data, so the sequence length may vary between samples after the transformation, which means that when stitching the output back together I need to invoke collate_fn manually on my own. Is this fine?
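For concreteness, the kind of manual collation I have in mind would look roughly like this (a hypothetical sketch using pad_sequence; the names are placeholders, not my actual code):

import torch
from torch.nn.utils.rnn import pad_sequence

def my_collate(samples):
    # samples: list of 1D tensors whose lengths differ after the per-sample CPU transform
    lengths = torch.tensor([s.size(0) for s in samples])
    padded = pad_sequence(samples, batch_first=True)  # shape: (batch, max_len)
    return padded, lengths

# inside transfer_batch_to_device, after the per-sample CPU step:
# batch = my_collate(transformed_samples)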
I have the exact same problem. I wonder if you have found a way to solve it....
@stonelazy you can go the lazy (and faster-to-prototype) route: it's a little involved and the CPU-GPU device transfers might be a bottleneck, but you'd need to do some profiling to quantify what the bottlenecks in this data processing pipeline actually are.
I think the only ways to truly optimize this would be:
1) Preprocess the whole dataset and store it on disk. Assuming your data doesn't change much and you're iterating on different architectures or model hyperparameters, this would be well worth the initial sunk cost.
2) Find a way to change your transform logic so that the CPU ops can run first, followed by the GPU ops. Then put all the CPU ops inside the Dataset.__getitem__() function and set num_workers > 1 so that PyTorch can multi-process those CPU ops for you. You would still do the GPU ops on the half-processed batch returned by the DataLoader, outside of the Dataset/DataLoader, since CUDA doesn't gel well with Dataset multiprocessing (rough sketches of both options below).
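Rough sketches of both options, assuming a toy waveform dataset; all names, paths and transforms here are made up for illustration:

import torch
from pathlib import Path
from torch.utils.data import Dataset, DataLoader

# Option 1: run the expensive transform once and cache the results to disk
def preprocess_and_cache(raw_items, cache_dir="cache"):
    Path(cache_dir).mkdir(exist_ok=True)
    for i, x in enumerate(raw_items):
        x = (x - x.mean()) / (x.std() + 1e-8)       # stand-in for the expensive transform
        torch.save(x.cpu(), f"{cache_dir}/{i}.pt")  # training later only loads these files

# Option 2: CPU ops in __getitem__ (parallelized by workers), GPU ops on the full batch
class MyDataset(Dataset):
    def __init__(self, waveforms):
        self.waveforms = waveforms  # list of CPU tensors

    def __len__(self):
        return len(self.waveforms)

    def __getitem__(self, idx):
        x = self.waveforms[idx]
        return (x - x.mean()) / (x.std() + 1e-8)    # CPU-only work, safe in workers

loader = DataLoader(MyDataset([torch.randn(16000) for _ in range(64)]),
                    batch_size=8, num_workers=4, pin_memory=True)

for batch in loader:
    batch = batch.to("cuda", non_blocking=True)     # main process owns all CUDA work
    batch = torch.fft.rfft(batch)                   # placeholder for the real GPU transform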
hope this helps.
Please see my comment above.
Disclaimer: I'm not from the Lightning AI team, but I use PyTorch a lot.
Bug description
tl;dr: Create a tensor, move it to the GPU during data loading, run some operation on it, and a re-initialize CUDA error is thrown.
I want to do some CUDA operations in my DataLoader. I need to execute these CUDA operations to prepare the batched data, i.e. even before Lightning invokes transfer_batch_to_device. But I am facing an error when I invoke torch.cuda.* functions.
Note - my training config: Strategy=DDP, DataLoader num_workers > 1.
I have referred to various threads that discuss this specific problem, Ref1, Ref2, Ref3. The general advice is that I should avoid invoking torch.cuda.* functions before trainer.fit is invoked. I have been following that by keeping the usage within my torch.utils.data.Dataset, but unfortunately even then I am facing this issue. I even tried doing this operation conditionally after a few iterations, just to ensure that my code doesn't initialize the CUDA process, but I still hit the same issue.
What version are you seeing the problem on?
v2.0
How to reproduce the bug
NOTE - I am not running this in a Jupyter notebook, but as a Python module.
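A minimal sketch of the pattern that triggers the error for me (a simplified stand-in for my real pipeline; the dataset, model and transform below are placeholders):

import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl

class CudaDataset(Dataset):
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        x = torch.randn(16000)
        # doing the transform on the GPU inside a worker process raises
        # "RuntimeError: Cannot re-initialize CUDA in forked subprocess" on my setup
        return torch.stft(x.cuda(), n_fft=512, return_complex=True).cpu()

class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(257, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch.abs().transpose(1, 2)).mean()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

if __name__ == "__main__":
    loader = DataLoader(CudaDataset(), batch_size=4, num_workers=2)
    trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="ddp", max_epochs=1)
    trainer.fit(Model(), loader)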
Environment
Current environment
* CUDA: - GPU: - NVIDIA RTX A6000 - NVIDIA RTX A6000 - NVIDIA RTX A6000 - NVIDIA RTX A6000 - NVIDIA RTX A6000 - NVIDIA RTX A6000 - NVIDIA RTX A6000 - NVIDIA RTX A6000 - available: True - version: 11.8 * Lightning: - lightning-utilities: 0.9.0 - pytorch-lightning: 2.0.0 - torch: 2.0.0+cu118 - torch-poly-lr-decay: 0.0.1 - torch-stoi: 0.1.2 - torchaudio: 2.0.0+cu118 - torchmetrics: 0.11.4 * Packages: - absl-py: 1.4.0 - aiohttp: 3.8.4 - aiohttp-retry: 2.8.3 - aioice: 0.7.6 - aiortc: 1.4.0 - aiosignal: 1.3.1 - alembic: 1.11.1 - altair: 5.0.1 - amqp: 5.1.1 - antlr4-python3-runtime: 4.9.3 - anyio: 3.7.1 - appdirs: 1.4.4 - argbind: 0.3.7 - argon2-cffi: 21.3.0 - argon2-cffi-bindings: 21.2.0 - arrow: 1.2.3 - asgiref: 3.7.2 - astroid: 2.15.6 - asttokens: 2.2.1 - async-timeout: 4.0.2 - asyncssh: 2.13.1 - atpublic: 3.1.1 - attrs: 23.1.0 - audio-events-classification: 0.1.2 - audioread: 3.0.0 - auraloss: 0.4.0 - autopage: 0.5.1 - av: 10.0.0 - backcall: 0.2.0 - beautifulsoup4: 4.12.2 - billiard: 3.6.4.0 - black: 23.7.0 - bleach: 6.0.0 - blinker: 1.6.2 - blis: 0.7.9 - boto3: 1.28.2 - botocore: 1.31.2 - braceexpand: 0.1.7 - bravado: 11.0.3 - bravado-core: 5.17.1 - cachetools: 5.3.1 - catalogue: 2.0.8 - cdifflib: 1.2.6 - celery: 5.2.7 - certifi: 2023.5.7 - cffi: 1.15.1 - cfgv: 3.3.1 - charset-normalizer: 3.2.0 - clang-format: 15.0.7 - click: 8.1.4 - click-didyoumean: 0.3.0 - click-plugins: 1.1.1 - click-repl: 0.2.0 - cliff: 4.3.0 - cmaes: 0.9.1 - cmake: 3.26.4 - cmd2: 2.4.3 - collection: 0.1.6 - colorama: 0.4.6 - coloredlogs: 15.0.1 - colorlog: 6.7.0 - comm: 0.1.3 - confection: 0.1.0 - configobj: 5.0.8 - contourpy: 1.1.0 - coverage: 7.2.7 - cryptography: 40.0.1 - ctcdecode: 1.0.5 - cycler: 0.11.0 - cymem: 2.0.7 - cython: 0.29.36 - cytoolz: 0.12.1 - debugpy: 1.6.7 - decorator: 5.1.1 - defusedxml: 0.7.1 - deprecated: 1.2.14 - descript-audiotools: 0.7.1 - dictdiffer: 0.9.0 - dill: 0.3.6 - dirhash: 0.2.1 - diskcache: 5.6.1 - distlib: 0.3.6 - distro: 1.8.0 - dnspython: 2.3.0 - docstring-parser: 0.15 - dpath: 2.1.5 - dulwich: 0.21.5 - dvc: 2.41.1 - dvc-data: 0.29.0 - dvc-http: 2.30.2 - dvc-objects: 0.14.1 - dvc-render: 0.0.17 - dvc-stratus: 0.3.1 - dvc-studio-client: 0.9.2 - dvc-task: 0.1.9 - dvclive: 2.0.2 - edit-distance: 1.0.6 - editdistance: 0.6.2 - entrypoints: 0.4 - exceptiongroup: 1.1.2 - executing: 1.2.0 - fastapi: 0.70.1 - fastjsonschema: 2.17.1 - ffmpeg-python: 0.2.0 - ffmpy: 0.3.0 - filelock: 3.12.2 - fire: 0.5.0 - flatbuffers: 23.5.26 - flatten-dict: 0.4.2 - flufl.lock: 7.1.1 - fonttools: 4.40.0 - fqdn: 1.5.1 - frozendict: 2.3.8 - frozenlist: 1.3.3 - fsspec: 2023.6.0 - ftfy: 5.9 - funcy: 2.0 - future: 0.18.3 - gitdb: 4.0.10 - gitpython: 3.1.32 - google-auth: 2.22.0 - google-auth-oauthlib: 1.0.0 - google-crc32c: 1.5.0 - grandalf: 0.6 - greenlet: 2.0.2 - grpcio: 1.56.0 - gunicorn: 20.1.0 - h11: 0.14.0 - huggingface-hub: 0.16.4 - humanfriendly: 10.0 - hydra-core: 1.3.2 - identify: 2.5.24 - idna: 3.4 - importlib-metadata: 6.8.0 - importlib-resources: 6.0.0 - infinibatch: 0.1.0 - inflect: 7.0.0 - iniconfig: 2.0.0 - ipykernel: 6.23.2 - ipython: 8.14.0 - ipython-genutils: 0.2.0 - isoduration: 20.11.0 - isort: 5.12.0 - iterative-telemetry: 0.0.6 - jedi: 0.18.2 - jellyfish: 1.0.0 - jinja2: 3.1.2 - jmespath: 1.0.1 - joblib: 1.3.1 - jsonformatter: 0.3.2 - jsonpointer: 2.4 - jsonref: 1.1.0 - jsonschema: 4.18.1 - jsonschema-specifications: 2023.6.1 - julius: 0.2.7 - jupyter-client: 8.2.0 - jupyter-contrib-core: 0.4.2 - jupyter-contrib-nbextensions: 0.7.0 - jupyter-core: 5.3.0 - jupyter-events: 0.6.3 - 
jupyter-highlight-selected-word: 0.2.0 - jupyter-nbextensions-configurator: 0.6.3 - jupyter-server: 2.6.0 - jupyter-server-terminals: 0.4.4 - jupyterlab-pygments: 0.2.2 - kiwisolver: 1.4.4 - kombu: 5.2.4 - langcodes: 3.3.0 - lazy-object-proxy: 1.9.0 - librosa: 0.9.2 - lightning-utilities: 0.9.0 - limits: 3.5.0 - lit: 16.0.6 - llvmlite: 0.40.1 - lxml: 4.9.2 - mako: 1.2.4 - markdown: 3.4.3 - markdown-it-py: 3.0.0 - markdown2: 2.4.9 - markupsafe: 2.1.3 - matplotlib: 3.7.2 - matplotlib-inline: 0.1.6 - mccabe: 0.7.0 - mdurl: 0.1.2 - mistune: 2.0.5 - monotonic: 1.6 - more-itertools: 9.1.0 - mpmath: 1.3.0 - msgpack: 1.0.5 - multidict: 6.0.4 - murmurhash: 1.0.9 - mypy: 0.961 - mypy-extensions: 1.0.0 - nanotime: 0.5.2 - nbclassic: 1.0.0 - nbclient: 0.8.0 - nbconvert: 7.4.0 - nbformat: 5.9.0 - nemo-text-processing: 0.1.8rc0 - nemo-toolkit: 1.9.0 - neptune-client: 0.16.18 - nest-asyncio: 1.5.6 - netifaces: 0.11.0 - networkx: 3.1 - nodeenv: 1.8.0 - notebook: 6.5.4 - notebook-shim: 0.2.3 - numba: 0.57.1 - numpy: 1.23.5 - nvidia-ml-py: 11.525.131 - nvitop: 1.1.2 - oauthlib: 3.2.2 - omegaconf: 2.3.0 - onnx: 1.12.0 - onnxruntime-gpu: 1.13.1 - openai-whisper: 20230314 - optuna: 2.10.1 - overrides: 7.3.1 - packaging: 23.1 - pandas: 2.0.3 - pandocfilters: 1.5.0 - parso: 0.8.3 - pathspec: 0.9.0 - pathy: 0.10.2 - pbr: 5.11.1 - pedalboard: 0.7.4 - pesq: 0.0.4 - pexpect: 4.8.0 - pickleshare: 0.7.5 - pillow: 10.0.0 - pip: 22.3.1 - platformdirs: 3.8.1 - pluggy: 1.2.0 - pooch: 1.7.0 - pre-commit: 3.3.3 - preshed: 3.0.8 - prettytable: 3.8.0 - prometheus-client: 0.17.0 - prompt-toolkit: 3.0.39 - protobuf: 3.20.1 - psutil: 5.9.5 - ptflops: 0.7 - ptyprocess: 0.7.0 - pure-eval: 0.2.2 - pyarrow: 12.0.1 - pyasn1: 0.5.0 - pyasn1-modules: 0.3.0 - pybind11: 2.10.4 - pycparser: 2.21 - pycryptodome: 3.18.0 - pydantic: 1.10.11 - pydeck: 0.8.1b0 - pydot: 1.4.2 - pyee: 9.0.4 - pygit2: 1.12.1 - pygments: 2.15.1 - pygtrie: 2.5.0 - pyjwt: 2.7.0 - pylibsrtp: 0.8.0 - pylint: 2.17.4 - pylint-protobuf: 0.20.2 - pyloudnorm: 0.1.1 - pympler: 1.0.1 - pynini: 2.1.5 - pyopenssl: 23.1.1 - pyparsing: 3.0.9 - pyperclip: 1.8.2 - pyphen: 0.14.0 - pyroomacoustics: 0.5.0 - pyrsistent: 0.19.3 - pystoi: 0.3.3 - pystratus: 0.2.2 - pyte: 0.8.0 - pytest: 7.4.0 - pytest-cov: 4.1.0 - pytest-mock: 3.11.1 - python-dateutil: 2.8.2 - python-dotenv: 1.0.0 - python-json-logger: 2.0.7 - python-magic: 0.4.27 - python-multipart: 0.0.4 - pytorch-lightning: 2.0.0 - pytz: 2023.3 - pytz-deprecation-shim: 0.1.0.post0 - pyyaml: 6.0 - pyzmq: 25.1.0 - randomname: 0.2.1 - referencing: 0.29.1 - regex: 2023.6.3 - registrable: 0.0.4 - requests: 2.31.0 - requests-oauthlib: 1.3.1 - resampy: 0.4.2 - rfc3339-validator: 0.1.4 - rfc3986-validator: 0.1.1 - rfc3987: 1.3.8 - rich: 13.4.2 - rpds-py: 0.8.10 - rsa: 4.9 - ruamel.yaml: 0.17.26 - ruamel.yaml.clib: 0.2.7 - s3transfer: 0.6.1 - sacremoses: 0.0.53 - scantree: 0.0.1 - scikit-learn: 1.3.0 - scipy: 1.9.3 - scmrepo: 0.1.5 - seaborn: 0.12.2 - semver: 3.0.1 - send2trash: 1.8.2 - setuptools: 65.6.3 - setuptools-scm: 7.1.0 - shortuuid: 1.0.11 - shtab: 1.6.1 - simplejson: 3.19.1 - six: 1.16.0 - slowapi: 0.1.8 - smart-open: 6.3.0 - smmap: 5.0.0 - sniffio: 1.3.0 - soundfile: 0.10.3.post1 - soupsieve: 2.4.1 - sox: 1.4.1 - spacy-legacy: 3.0.12 - spacy-loggers: 1.0.4 - sqlalchemy: 2.0.18 - srsly: 2.4.6 - srt: 3.5.3 - stack-data: 0.6.2 - starlette: 0.16.0 - stevedore: 5.1.0 - streamlit: 1.19.0 - streamlit-webrtc: 0.44.7 - swagger-spec-validator: 3.0.3 - sympy: 1.12 - tabulate: 0.9.0 - taskipy: 1.11.0 - tensorboard: 2.13.0 - 
tensorboard-data-server: 0.7.1 - termcolor: 2.3.0 - terminado: 0.17.1 - thefuck: 3.32 - thinc: 8.1.10 - thop: 0.1.1.post2209072238 - threadpoolctl: 3.1.0 - tiktoken: 0.3.3 - tinycss2: 1.2.1 - tokenize-rt: 5.1.0 - tokenizers: 0.10.3 - toml: 0.10.2 - tomli: 2.0.1 - tomlkit: 0.11.8 - toolz: 0.11.2 - torch: 2.0.0+cu118 - torch-poly-lr-decay: 0.0.1 - torch-stoi: 0.1.2 - torchaudio: 2.0.0+cu118 - torchmetrics: 0.11.4 - tornado: 6.3.2 - tqdm: 4.65.0 - traitlets: 5.9.0 - transformers: 4.16.2 - triton: 2.0.0 - typer: 0.9.0 - typing-extensions: 4.7.1 - tzdata: 2023.3 - tzlocal: 5.0.1 - unidecode: 1.3.6 - uri-template: 1.3.0 - urllib3: 1.26.15 - uvicorn: 0.15.0 - validators: 0.20.0 - vine: 5.0.0 - virtualenv: 20.23.1 - voluptuous: 0.13.1 - wasabi: 1.1.2 - watchdog: 3.0.0 - wcwidth: 0.2.6 - webcolors: 1.13 - webdataset: 0.1.103 - webencodings: 0.5.1 - webrtcvad: 2.0.10 - websocket-client: 1.6.1 - werkzeug: 2.3.6 - wget: 3.2 - wheel: 0.37.1 - wrapt: 1.15.0 - yarl: 1.9.2 - youtube-dl: 2021.2.22 - zc.lockfile: 3.0.post1 - zdnsmos: 0.1.4 - zipp: 3.16.0 - zlogs: 0.2.4 - zspeech: 0.2.2 - zspeech-analytics: 2.0 - zspeech-api: 0.2.2 - zspeech-audio: 0.2.2 - zspeech-audio-common: 0.2.2 - zspeech-audio-transforms: 0.2.2 - zspeech-common: 0.2.2 - zspeech-dev: 0.2.2 - zspeech-inference: 0.2.2 - zspeech-inference-client: 0.2.2 - zspeech-models: 0.2.2 - zspeech-nn: 0.2.2 - zwaf-isc: 0.2.1 * System: - OS: Linux - architecture: - 64bit - ELF - processor: x86_64 - python: 3.9.16 - release: 5.15.0-76-generic - version: #83~20.04.1-Ubuntu SMP Wed Jun 21 20:23:31 UTC 2023
More info
Why can't I do this after Lightning transfers the data to the device (on_after_batch_transfer)? I can't use on_after_batch_transfer for this, as that may involve moving devices on my own (steps c, d) after Lightning has already moved the batch automatically.
What do I want? If this is a known behaviour, I would like to know the recommended way to do these data operations on the GPU within my DataLoader/Dataset before the automatic device transfer by Lightning.
cc @justusschock @awaelchli