CompVis / latent-diffusion

High-Resolution Image Synthesis with Latent Diffusion Models
MIT License

torch.cuda.OutOfMemoryError: CUDA out of memory. in PTL 2.2.2 #360

Closed: code-ishwar closed this issue 7 months ago

code-ishwar commented 7 months ago

Hi,

I am trying to upgrade the code to PyTorch Lightning (PTL) 2.2.2, and I made the following changes to suit the new version:


from pytorch_lightning import Trainer

# trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)  # old trainer (removed in PTL 2.0)
trainer = Trainer(devices="auto")  # new trainer; trainer_opt and **trainer_kwargs are no longer passed
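`Trainer.from_argparse_args` was removed in Lightning 2.0, and `Trainer(devices="auto")` on its own no longer receives any of the options held in `trainer_opt` or `trainer_kwargs`. Below is a minimal sketch of carrying those options over. It assumes `trainer_opt` is an `argparse.Namespace` with PTL 1.x-style keys such as `gpus="0,"`, `accelerator="ddp"` and `resume_from_checkpoint`. The helper `to_ptl2_trainer_kwargs` is hypothetical and maps only those three keys; any other option that 2.x removed would still need its own handling.

```python
from argparse import Namespace
from typing import Any, Dict, Optional, Tuple


def to_ptl2_trainer_kwargs(trainer_opt: Namespace) -> Tuple[Dict[str, Any], Optional[str]]:
    """Translate a few PTL 1.x Trainer options into PTL 2.x form (hypothetical helper).

    Returns the Trainer keyword arguments and the checkpoint path, which
    PTL 2.x expects in trainer.fit(..., ckpt_path=...) instead of the Trainer.
    """
    opts = dict(vars(trainer_opt))  # copy so the Namespace itself is left untouched

    # resume_from_checkpoint was removed from the Trainer in PTL 2.0.
    ckpt_path = opts.pop("resume_from_checkpoint", None)

    # PTL 1.x accepted accelerator="ddp"; in 2.x "ddp" is a strategy, not an accelerator.
    if opts.get("accelerator") == "ddp":
        opts["accelerator"] = "gpu"
        opts["strategy"] = "ddp"

    # The `gpus` argument was removed; the same device selection goes into `devices`.
    gpus = opts.pop("gpus", None)
    if gpus is not None:
        opts.setdefault("accelerator", "gpu")
        opts["devices"] = gpus

    # All remaining keys are passed through unchanged.
    return opts, ckpt_path
```

Under the same assumptions, usage would be `ptl2_kwargs, ckpt_path = to_ptl2_trainer_kwargs(trainer_opt)`, then `trainer = Trainer(**ptl2_kwargs, **trainer_kwargs)` and `trainer.fit(model, data, ckpt_path=ckpt_path)`.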

The program crashes both with and without a GPU. I am running this on a machine with four A100 GPUs.

Can you tell whether the issue is in the code or on the machine? Thanks in advance.

Error

Traceback (most recent call last):
  File "main.py", line 723, in <module>
    trainer.fit(model, data)
  File "/home/user/miniforge-pypy3/envs/ldm2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 544, in fit
    call._call_and_handle_interrupt(
  File "/home/user/miniforge-pypy3/envs/ldm2/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/user/miniforge-pypy3/envs/ldm2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/user/miniforge-pypy3/envs/ldm2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 963, in _run
    self.strategy.setup(self)
  File "/home/user/miniforge-pypy3/envs/ldm2/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py", line 155, in setup
    self.model_to_device()
  File "/home/user/miniforge-pypy3/envs/ldm2/lib/python3.8/site-packages/pytorch_lightning/strategies/single_device.py", line 79, in model_to_device
    self.model.to(self.root_device)
  File "/home/user/miniforge-pypy3/envs/ldm2/lib/python3.8/site-packages/lightning_fabric/utilities/device_dtype_mixin.py", line 55, in to
    return super().to(*args, **kwargs)
  File "/home/user/miniforge-pypy3/envs/ldm2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1160, in to
    return self._apply(convert)
  File "/home/user/miniforge-pypy3/envs/ldm2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 810, in _apply
    module._apply(fn)
  File "/home/user/miniforge-pypy3/envs/ldm2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 857, in _apply
    self._buffers[key] = fn(buf)
  File "/home/user/miniforge-pypy3/envs/ldm2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1158, in convert
    return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacty of 39.44 GiB of which 15.06 MiB is free. Process 1761164 has 27.62 GiB memory in use. Process 2643215 has 8.48 GiB memory in use. Including non-PyTorch memory, this process has 3.08 GiB memory in use. Of the allocated memory 2.04 GiB is allocated by PyTorch, and 507.50 KiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
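In this traceback the allocation fails inside `self.model.to(self.root_device)`, i.e. while Lightning moves the model onto GPU 0 during `strategy.setup`, before the training loop starts. The message also reports two other processes on that GPU. As a quick check, here is a small standard-library sketch that pulls the per-process figures out of the message. The regular expressions are written against this exact wording (including PyTorch's own "capacty" spelling), which is an assumption that may not hold for other PyTorch versions.

```python
import re

# The relevant part of the OOM message above, copied verbatim.
OOM_MESSAGE = (
    "GPU 0 has a total capacty of 39.44 GiB of which 15.06 MiB is free. "
    "Process 1761164 has 27.62 GiB memory in use. "
    "Process 2643215 has 8.48 GiB memory in use. "
    "Including non-PyTorch memory, this process has 3.08 GiB memory in use."
)


def memory_by_process(message: str) -> dict:
    """Return {pid or "this process": GiB in use} parsed from the OOM message."""
    usage = {}
    # Other processes are reported as "Process <pid> has <x> GiB memory in use".
    for pid, gib in re.findall(r"Process (\d+) has ([\d.]+) GiB memory in use", message):
        usage[int(pid)] = float(gib)
    # The failing process itself is reported without a pid.
    own = re.search(r"this process has ([\d.]+) GiB memory in use", message)
    if own:
        usage["this process"] = float(own.group(1))
    return usage


usage = memory_by_process(OOM_MESSAGE)
others = sum(v for k, v in usage.items() if k != "this process")
print(f"other processes: {others:.2f} GiB, this process: {usage['this process']:.2f} GiB")
# prints: other processes: 36.10 GiB, this process: 3.08 GiB
```

On this message the two other processes hold about 36.10 GiB of the 39.44 GiB total, which may matter more here than the batch size; `nvidia-smi` lists the processes using each GPU.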

Library Versions

absl-py==2.1.0
accelerate==0.26.1
aiohttp==3.9.1
aiosignal==1.3.1
albumentations==0.4.3
altair==5.3.0
annotated-types==0.6.0
antlr4-python3-runtime==4.8
anyio==4.2.0
archspec @ file:///home/conda/feedstock_root/build_artifacts/archspec_1708969572489/work
array-record==0.4.0
astunparse==1.6.3
async-timeout==4.0.3
attrs==23.2.0
Babel==2.13.1
blinker==1.7.0
boltons @ file:///home/conda/feedstock_root/build_artifacts/boltons_1711936407380/work
boto==2.49.0
Brotli @ file:///home/conda/feedstock_root/build_artifacts/brotli-split_1695989787169/work
cached-property==1.5.2
cachetools==5.3.3
certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1707022139797/work/certifi
cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1696001773319/work
charset-normalizer==3.3.2
chex==0.1.7
click==8.1.7
-e git+https://github.com/openai/CLIP.git@a1d071733d7111c9c014f024669f959182114e33#egg=clip
clu==0.0.10
colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1666700638685/work
comet-ml==3.12.2
conda @ file:///home/conda/feedstock_root/build_artifacts/conda_1711445837045/work
conda-libmamba-solver @ file:///home/conda/feedstock_root/build_artifacts/conda-libmamba-solver_1706566000184/work/src
conda-package-handling @ file:///home/conda/feedstock_root/build_artifacts/conda-package-handling_1691048088238/work
conda_package_streaming @ file:///home/conda/feedstock_root/build_artifacts/conda-package-streaming_1691009212940/work
configobj==5.0.8
contextlib2==21.6.0
dataclasses-json==0.6.3
distro==1.9.0
dm-tree==0.1.8
docstring-parser==0.15
dulwich==0.21.7
editdistance==0.6.2
einops==0.7.0
etils==1.3.0
everett==3.3.0
exceptiongroup==1.2.0
faiss-cpu==1.7.4
fastapi==0.109.0
filelock @ file:///home/conda/feedstock_root/build_artifacts/filelock_1712686151958/work
filetype==1.2.0
flatbuffers==23.5.26
flax==0.7.2
frozenlist==1.4.1
fsspec==2023.12.1
ftfy==6.2.0
future==1.0.0
gast==0.4.0
gevent==23.9.1
gin-config==0.5.0
gitdb==4.0.11
GitPython==3.1.41
gmpy2 @ file:///home/conda/feedstock_root/build_artifacts/gmpy2_1666808683138/work
google-api-core==2.15.0
google-api-python-client==2.110.0
google-auth==2.29.0
google-auth-httplib2==0.1.1
google-auth-oauthlib==1.0.0
google-cloud-core==2.4.1
google-cloud-storage==2.13.0
google-compute-engine==2.8.13
google-crc32c==1.5.0
google-pasta==0.2.0
google-resumable-media==2.6.0
googleapis-common-protos==1.62.0
greenlet==3.0.2
grpcio==1.62.1
h11==0.14.0
h5py==3.10.0
httpcore==1.0.2
httplib2==0.22.0
httpx==0.26.0
httpx-sse==0.4.0
huggingface-hub==0.19.4
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1713279365350/work
imageio==2.9.0
imageio-ffmpeg==0.4.2
imgaug==0.2.6
immutabledict==4.0.0
importlib-resources==6.1.1
importlib_metadata==7.1.0
jax==0.4.13
jaxlib==0.4.13
Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1704966972576/work
joblib==1.4.0
jsonpatch==1.33
jsonpointer @ file:///home/conda/feedstock_root/build_artifacts/jsonpointer_1695397259927/work
jsonschema==4.21.1
jsonschema-specifications==2023.12.1
keras==2.13.1
kornia==0.6.4
langchain==0.1.1
langchain-cli==0.0.20
langchain-community==0.0.13
langchain-core==0.1.13
langchain-openai==0.0.3
langserve==0.0.39
langsmith==0.0.83
-e git+ssh://git@github.com/Cyberium-Inc/latenet_sanskrit_diffusion.git@42d0ebea80d67c4a9a2200a569562c48e820a6ef#egg=latent_diffusion
lazy_loader==0.4
libclang==16.0.6
libmambapy @ file:///home/conda/feedstock_root/build_artifacts/mamba-split_1711394305528/work/libmambapy
lightning-utilities==0.11.2
lxml==4.9.3
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe @ file:///home/conda/feedstock_root/build_artifacts/markupsafe_1706899923320/work
marshmallow==3.20.2
mdurl==0.1.2
menuinst @ file:///home/conda/feedstock_root/build_artifacts/menuinst_1705068268702/work
mesh-tensorflow==0.1.21
mkl-service==2.4.1
mkl_fft==1.3.8
mkl_random @ file:///home/conda/feedstock_root/build_artifacts/mkl_random_1707959978248/work
ml-collections==0.1.1
ml-dtypes==0.2.0
mpmath @ file:///home/conda/feedstock_root/build_artifacts/mpmath_1678228039184/work
msgpack==1.0.7
multidict==6.0.4
mypy-extensions==1.0.0
nest-asyncio==1.5.8
networkx @ file:///home/conda/feedstock_root/build_artifacts/networkx_1680692919326/work
numpy @ file:///work/mkl/numpy_and_numpy_base_1682953417311/work
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-ml-py3==7.352.0
nvidia-nccl-cu12==2.19.3
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.1.105
oauth2client==4.1.3
oauthlib==3.2.2
omegaconf==2.1.1
openai==1.8.0
opencv-python==4.1.2.30
opencv-python-headless==4.9.0.80
opt-einsum==3.3.0
optax==0.1.7
orbax-checkpoint==0.2.3
orjson==3.9.12
packaging==21.3
pandas==2.0.3
pillow @ file:///croot/pillow_1707233021655/work
pkgutil_resolve_name==1.3.10
platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1706713388748/work
pluggy @ file:///home/conda/feedstock_root/build_artifacts/pluggy_1706116770704/work
portalocker==2.8.2
promise==2.3
protobuf==4.25.3
psutil==5.9.6
pudb==2019.2
pyarrow==15.0.2
pyasn1==0.6.0
pyasn1_modules==0.4.0
pycosat @ file:///home/conda/feedstock_root/build_artifacts/pycosat_1696355775111/work
pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1711811537435/work
pydantic==2.5.3
pydantic_core==2.14.6
pydeck==0.8.1b0
pyDeprecate==0.3.1
pyglove==0.4.3
Pygments==2.17.2
pyparsing==3.1.1
PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1661604839144/work
python-dateutil==2.9.0.post0
python-dotenv==0.18.0
pytorch-lightning==2.2.2
pytz==2024.1
PyWavelets==1.4.1
PyYAML==6.0.1
referencing==0.34.0
regex==2024.4.16
requests==2.31.0
requests-oauthlib==2.0.0
requests-toolbelt==1.0.0
rich==13.7.0
rouge-score==0.1.2
rpds-py==0.18.0
rsa==4.9
ruamel.yaml @ file:///home/conda/feedstock_root/build_artifacts/ruamel.yaml_1707298136555/work
ruamel.yaml.clib @ file:///home/conda/feedstock_root/build_artifacts/ruamel.yaml.clib_1707314498139/work
sacrebleu==2.3.3
sacremoses==0.1.1
safetensors==0.4.1
scikit-image==0.20.0
scikit-learn==1.3.2
scipy==1.9.1
semantic-version==2.10.0
sentence-transformers==2.2.2
sentencepiece==0.1.99
seqio-nightly==0.0.17.dev20231209
shellingham==1.5.4
six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
smmap==5.0.1
sniffio==1.3.0
SQLAlchemy==2.0.25
sse-starlette==1.8.2
starlette==0.35.1
streamlit==1.33.0
sympy @ file:///home/conda/feedstock_root/build_artifacts/sympy_1684180540116/work
t5==0.9.4
tabulate==0.9.0
# Editable install with no version control (taming-transformers==0.0.1)
-e /home/user/miniforge-pypy3/envs/ldm2/lib/python3.8/site-packages
taming-transformers-rom1504==0.0.6
tenacity==8.2.3
tensorboard==2.13.0
tensorboard-data-server==0.7.2
tensorflow==2.13.1
tensorflow-datasets==4.9.2
tensorflow-estimator==2.13.0
tensorflow-hub==0.15.0
tensorflow-io-gcs-filesystem==0.34.0
tensorflow-metadata==1.14.0
tensorflow-text==2.13.0
tensorstore==0.1.45
termcolor==2.4.0
test-tube==0.7.5
tfds-nightly==4.9.2.dev202308090034
threadpoolctl==3.2.0
tifffile==2023.7.10
tiktoken==0.5.2
tokenizers==0.19.1
toml==0.10.2
tomlkit==0.12.3
toolz==0.12.0
torch @ file:///home/conda/feedstock_root/build_artifacts/libtorch_1711323198313/work
torch-fidelity==0.3.0
torchmetrics==0.7.3
torchvision @ file:///home/conda/feedstock_root/build_artifacts/torchvision-split_1619073604192/work
tornado==6.4
tqdm @ file:///home/conda/feedstock_root/build_artifacts/tqdm_1707598593068/work
transformers==4.40.0
triton==2.2.0
typer==0.9.0
typing-inspect==0.9.0
typing_extensions==4.9.0
tzdata==2024.1
uritemplate==4.1.1
urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1708239446578/work
urwid==2.6.10
uvicorn==0.23.2
watchdog==4.0.0
wcwidth==0.2.13
websocket-client==1.7.0
Werkzeug==3.0.2
wrapt==1.16.0
wurlitzer==3.0.3
yarl==1.9.4
zipp==3.18.1
zope.event==5.0
zstandard==0.22.0

code-ishwar commented 7 months ago

I changed the trainer line as follows, and it works fine on the CPU, so it seems to be a batch-size issue on the GPU.

trainer = Trainer(devices="auto", accelerator="cpu")
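The CPU fallback avoids the error but leaves the GPUs unused. Because the OOM message only describes GPU 0, one option worth trying is to restrict the run to GPUs that currently have enough free memory. This is an illustrative sketch under stated assumptions: the helper names `free_bytes_per_gpu` and `pick_gpus` and the 30 GiB default threshold are hypothetical. It relies on `torch.cuda.mem_get_info`, which returns `(free, total)` in bytes for a device.

```python
from typing import List

GIB = 1024 ** 3


def free_bytes_per_gpu() -> List[int]:
    """Query free memory, in bytes, on each visible CUDA device (hypothetical helper)."""
    import torch  # imported lazily so pick_gpus() can be used and tested without torch

    return [torch.cuda.mem_get_info(i)[0] for i in range(torch.cuda.device_count())]


def pick_gpus(free_bytes: List[int], min_free_gib: float = 30.0) -> List[int]:
    """Return indices of GPUs with at least `min_free_gib` GiB free, most free first.

    The 30 GiB default is a placeholder, not a measured requirement of the model.
    """
    ok = [i for i, free in enumerate(free_bytes) if free >= min_free_gib * GIB]
    return sorted(ok, key=lambda i: free_bytes[i], reverse=True)
```

If `pick_gpus(free_bytes_per_gpu())` returns a non-empty list such as `[1, 2]`, it can be passed as `Trainer(accelerator="gpu", devices=[1, 2])`. An empty list means no GPU currently has that much free memory. Other processes can still claim memory between this check and the start of training, so the check reduces the risk of an OOM rather than ruling it out.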