Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0
28.53k stars 3.39k forks

Training process freezes on step 2 when training with manual optimization. #15395

Open dtmoodie opened 2 years ago

dtmoodie commented 2 years ago

Bug description

I'm using manual optimization to work with two datasets for multi-task learning. Due to memory usage limitations, I want to do a forward and backward pass with a batch from one dataset, then a forward and backward pass with the other dataset.

Even with manual optimization enabled on just one dataset, my training process freezes on step 2 if I log scalars in the on_after_backward hook with sync_dist=True set on the logging call.
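For context, here is a hedged illustration of how this kind of freeze can arise in general. It is not a confirmed diagnosis of this issue. With sync_dist=True, self.log reduces the value across processes, which is a collective operation that every rank has to enter. If ranks issue a different number of such calls (for example, if the number of logged gradients differed between ranks), the extra call never completes. The sketch below imitates this with threading.Barrier from the standard library standing in for the all-reduce. The names run_rank and simulate and the timeout are hypothetical choices for the demo; a real collective would block indefinitely rather than time out.

```python
import threading


def run_rank(rank, num_sync_calls, barrier, results, timeout=0.5):
    """Simulate one rank issuing `num_sync_calls` blocking collective calls."""
    completed = 0
    try:
        for _ in range(num_sync_calls):
            # Stand-in for an all-reduce: returns only once every rank arrives.
            barrier.wait(timeout=timeout)
            completed += 1
        results[rank] = f"completed {completed} syncs"
    except threading.BrokenBarrierError:
        # The timeout exists only so the demo terminates; a real rank would hang.
        results[rank] = f"stuck after {completed} syncs"


def simulate(calls_per_rank):
    """Run one thread per simulated rank and report how each one ended."""
    barrier = threading.Barrier(len(calls_per_rank))
    results = {}
    threads = [
        threading.Thread(target=run_rank, args=(rank, calls, barrier, results))
        for rank, calls in enumerate(calls_per_rank)
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
    return results


if __name__ == "__main__":
    print(simulate([2, 2]))  # both ranks issue the same number of syncs
    print(simulate([2, 3]))  # rank 1 issues one sync that rank 0 never joins
```

In the matched case both simulated ranks finish. In the mismatched case the rank with the extra call is left waiting, which is the same shape as one process sitting inside a sync call while the others have moved on.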

How to reproduce the bug

```python
def on_after_backward(self) -> None:
    # Log the mean gradient of every task decoder that records gradients.
    for task_name, task in self.model.task_decoders.items():
        if isinstance(task, LoggingTaskDecoder):
            for i, grad in enumerate(task.grads):
                grad = grad.detach().mean()
                name = f'task_grads/{task_name}_{i}'
                # if sync_dist is set to False, training succeeds.
                self.log(name, value=grad, sync_dist=True, on_epoch=True)

    super().on_after_backward()

def training_step(self, batch, batch_idx):
    if not self.automatic_optimization:
        opt = self.optimizers()
        opt.zero_grad()

    total, loss_dict = self.lossFromBatch(batch)

    if not self.automatic_optimization:
        self.manual_backward(total)
        opt.step()
        sch = self.lr_schedulers()
        sch.step()
    # NOTE: `img` is not defined anywhere in this excerpt.
    return dict(loss=total * self.batch_size, img=img)
```

```python
from torch.optim import SGD, lr_scheduler

def configure_optimizers(self):
    optimizer = SGD(self.parameters(), ...)
    num_steps = self.steps_per_epoch
    # The OneCycleLR schedule is stepped once per training step.
    scheduler = lr_scheduler.OneCycleLR(optimizer=optimizer,
                                        steps_per_epoch=num_steps,
                                        epochs=self.trainer.max_epochs,
                                        **self.hyp['scheduler'])
    scheduler = {"scheduler": scheduler, "interval": "step"}
    return [optimizer], [scheduler]
```

Error messages and logs

When this occurs, everything freezes with the GPUs at 100% utilization. No error messages or logs are available. Attaching a debugger sometimes works, and with it I found that one process was frozen in a sync_ddp call inside the self.log call.
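When a hang produces no logs and a debugger only attaches some of the time, one option is to have each process dump its own stack traces after a timeout. The sketch below uses faulthandler.dump_traceback_later from the Python standard library. The helper names arm_watchdog, disarm_watchdog, and capture_hang_dump are hypothetical, and time.sleep stands in for a step that is stuck; nothing here is taken from the original training code.

```python
import faulthandler
import tempfile
import time


def arm_watchdog(timeout_s, log_file):
    """Dump the stacks of all threads to `log_file` after `timeout_s` seconds,
    unless the watchdog is disarmed first."""
    faulthandler.dump_traceback_later(timeout_s, repeat=False, file=log_file)


def disarm_watchdog():
    """Cancel the pending dump; call this when the step finished in time."""
    faulthandler.cancel_dump_traceback_later()


def capture_hang_dump(timeout_s=0.2, hang_s=0.6):
    """Demo: simulate a step that outlives the timeout and return the dump."""
    with tempfile.TemporaryFile(mode="w+") as log_file:
        arm_watchdog(timeout_s, log_file)
        try:
            time.sleep(hang_s)  # stand-in for a step stuck in a collective call
        finally:
            disarm_watchdog()
        log_file.seek(0)
        return log_file.read()


if __name__ == "__main__":
    print(capture_hang_dump())
```

In a multi-process run, the same idea would be to arm the watchdog at the start of each training step and disarm it at the end, writing to one file per rank so the traces show where each process was waiting. The per-rank file naming is left to the reader.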

Environment


* CUDA:
        - GPU:
                - NVIDIA TITAN RTX
                - NVIDIA TITAN RTX
                - NVIDIA GeForce GTX 1080 Ti
        - available:         True
        - version:           11.7
* Lightning:
        - pytorch-lightning: 1.7.0
        - pytorch-quantization: 2.1.2
        - torch:             1.13.0a0+340c412
        - torch-tensorrt:    1.1.0a0
        - torchmetrics:      0.9.3
        - torchtext:         0.13.0a0
        - torchvision:       0.13.0a0
* Packages:
        - absl-py:           1.1.0
        - actionlib:         1.13.2
        - aiohttp:           3.8.1
        - aiosignal:         1.2.0
        - alabaster:         0.7.12
        - angles:            1.9.13
        - apex:              0.1
        - appdirs:           1.4.4
        - argon2-cffi:       21.3.0
        - argon2-cffi-bindings: 21.2.0
        - asttokens:         2.0.5
        - async-timeout:     4.0.2
        - attrs:             21.4.0
        - audioread:         2.1.9
        - babel:             2.10.1
        - backcall:          0.2.0
        - backports.functools-lru-cache: 1.6.4
        - beautifulsoup4:    4.11.1
        - bleach:            5.0.0
        - blis:              0.7.7
        - bondpy:            1.8.6
        - brotlipy:          0.7.0
        - cachetools:        5.2.0
        - camera-calibration-parsers: 1.12.0
        - catalogue:         2.0.6
        - catkin:            0.8.10
        - catkin-pkg:        0.5.2
        - catkin-tools:      0.9.0
        - certifi:           2022.5.18.1
        - cffi:              1.15.0
        - chardet:           4.0.0
        - charset-normalizer: 2.0.12
        - click:             8.0.4
        - cloudpickle:       2.1.0
        - codecov:           2.1.12
        - colorama:          0.4.4
        - comet-ml:          3.31.15
        - conda:             4.13.0
        - conda-build:       3.21.9
        - conda-package-handling: 1.8.1
        - configobj:         5.0.6
        - controller-manager: 0.19.5
        - controller-manager-msgs: 0.19.5
        - coverage:          6.4.1
        - cryptography:      37.0.2
        - cuda-python:       11.7.0
        - cudf:              22.4.0a0+306.g0cb75a4913
        - cugraph:           22.4.0a0+102.g4106a188
        - cuml:              22.4.0a0+108.g2be11269d
        - cupy-cuda115:      9.6.0
        - cv-bridge:         1.15.0
        - cycler:            0.11.0
        - cymem:             2.0.6
        - cython:            0.29.30
        - dask:              2022.3.0
        - dask-cuda:         22.4.0
        - dask-cudf:         22.4.0a0+306.g0cb75a4913
        - dataclasses:       0.8
        - debugpy:           1.6.0
        - decorator:         5.1.1
        - defusedxml:        0.7.1
        - descartes:         1.1.0
        - diagnostic-analysis: 1.11.0
        - diagnostic-common-diagnostics: 1.11.0
        - diagnostic-updater: 1.11.0
        - distributed:       2022.3.0
        - distro:            1.7.0
        - docutils:          0.17.1
        - dulwich:           0.20.46
        - dynamic-reconfigure: 1.7.3
        - empy:              3.3.4
        - entrypoints:       0.4
        - everett:           3.0.0
        - executing:         0.8.3
        - expecttest:        0.1.3
        - fastjsonschema:    2.15.3
        - fastrlock:         0.8
        - filelock:          3.7.1
        - fire:              0.4.0
        - flask:             2.1.2
        - fonttools:         4.33.3
        - frozenlist:        1.3.0
        - fsspec:            2022.5.0
        - future:            0.18.2
        - gencpp:            0.6.5
        - geneus:            3.0.0
        - genlisp:           0.4.18
        - genmsg:            0.5.16
        - gennodejs:         2.0.2
        - genpy:             0.6.15
        - glob2:             0.7
        - google-auth:       2.7.0
        - google-auth-oauthlib: 0.4.6
        - graphsurgeon:      0.4.5
        - grpcio:            1.46.3
        - heapdict:          1.0.1
        - hypothesis:        4.50.8
        - idna:              3.3
        - image-geometry:    1.15.0
        - imagesize:         1.3.0
        - importlib-metadata: 4.11.4
        - importlib-resources: 5.7.1
        - iniconfig:         1.1.1
        - interactive-markers: 1.12.0
        - ipykernel:         6.14.0
        - ipython:           8.4.0
        - ipython-genutils:  0.2.0
        - ipywidgets:        8.0.2
        - itsdangerous:      2.1.2
        - jedi:              0.18.1
        - jinja2:            3.1.2
        - joblib:            1.1.0
        - joint-state-publisher: 1.15.1
        - joint-state-publisher-gui: 1.15.1
        - json5:             0.9.8
        - jsonschema:        4.6.0
        - jupyter:           1.0.0
        - jupyter-client:    7.3.4
        - jupyter-console:   6.4.4
        - jupyter-core:      4.10.0
        - jupyter-tensorboard: 0.2.0
        - jupyterlab:        2.3.2
        - jupyterlab-pygments: 0.2.2
        - jupyterlab-server: 1.2.0
        - jupyterlab-widgets: 3.0.3
        - jupytext:          1.13.8
        - kiwisolver:        1.4.3
        - langcodes:         3.3.0
        - libarchive-c:      4.0
        - librosa:           0.8.1
        - llvmlite:          0.36.0
        - lmdb:              1.3.0
        - locket:            1.0.0
        - markdown:          3.3.7
        - markdown-it-py:    2.1.0
        - markupsafe:        2.1.1
        - matplotlib:        3.5.2
        - matplotlib-inline: 0.1.3
        - mdit-py-plugins:   0.3.0
        - mdurl:             0.1.1
        - message-filters:   1.15.14
        - mish-cuda:         0.0.3
        - mistune:           0.8.4
        - mock:              4.0.3
        - msgpack:           1.0.4
        - multidict:         6.0.2
        - murmurhash:        1.0.7
        - nbclient:          0.6.4
        - nbconvert:         6.5.0
        - nbformat:          5.4.0
        - nest-asyncio:      1.5.5
        - networkx:          2.6.3
        - nltk:              3.7
        - notebook:          6.4.10
        - numba:             0.53.1
        - numpy:             1.22.4
        - nuscenes-devkit:   1.1.9
        - nvidia-dali-cuda110: 1.14.0
        - nvidia-pyindex:    1.0.9
        - nvtx:              0.2.5
        - oauthlib:          3.2.0
        - onnx:              1.11.0
        - opencv-python:     4.6.0.66
        - osrf-pycommon:     2.0.2
        - packaging:         21.3
        - pandas:            1.3.5
        - pandocfilters:     1.5.0
        - parameterized:     0.8.1
        - parso:             0.8.3
        - partd:             1.2.0
        - pathy:             0.6.1
        - pexpect:           4.8.0
        - pickleshare:       0.7.5
        - pillow:            9.0.1
        - pip:               21.2.4
        - pkginfo:           1.8.3
        - pluggy:            1.0.0
        - polygraphy:        0.33.0
        - pooch:             1.6.0
        - preshed:           3.0.6
        - prettytable:       3.3.0
        - prometheus-client: 0.14.1
        - prompt-toolkit:    3.0.29
        - protobuf:          3.19.4
        - psutil:            5.9.1
        - ptyprocess:        0.7.0
        - pure-eval:         0.2.2
        - py:                1.11.0
        - pyarrow:           6.0.1
        - pyasn1:            0.4.8
        - pyasn1-modules:    0.2.8
        - pybind11:          2.9.2
        - pycocotools:       2.0.5
        - pycosat:           0.6.3
        - pycparser:         2.21
        - pydantic:          1.8.2
        - pydeprecate:       0.3.2
        - pydot:             1.4.2
        - pygments:          2.12.0
        - pynvml:            11.4.1
        - pyopenssl:         22.0.0
        - pyparsing:         3.0.9
        - pyquaternion:      0.9.9
        - pyrsistent:        0.18.1
        - pysocks:           1.7.1
        - pytest:            6.2.5
        - pytest-cov:        3.0.0
        - pytest-pythonpath: 0.7.4
        - python-dateutil:   2.8.2
        - python-hostlist:   1.21
        - python-nvd3:       0.15.0
        - python-qt-binding: 0.4.4
        - python-slugify:    6.1.2
        - pytorch-lightning: 1.7.0
        - pytorch-quantization: 2.1.2
        - pytz:              2022.1
        - pyyaml:            6.0
        - pyzmq:             23.1.0
        - qt-dotgraph:       0.4.2
        - qt-gui:            0.4.2
        - qt-gui-cpp:        0.4.2
        - qt-gui-py-common:  0.4.2
        - qtconsole:         5.3.2
        - qtpy:              2.2.1
        - raft:              22.4.0a0+113.gf5d2627
        - regex:             2022.6.2
        - requests:          2.27.1
        - requests-oauthlib: 1.3.1
        - requests-toolbelt: 0.10.0
        - resampy:           0.2.2
        - revtok:            0.0.3
        - rmm:               22.4.0a0+50.gf82d458
        - rosbag:            1.15.14
        - rosboost-cfg:      1.15.8
        - rosclean:          1.15.8
        - roscreate:         1.15.8
        - rosgraph:          1.15.14
        - roslaunch:         1.15.14
        - roslib:            1.15.8
        - roslint:           0.12.0
        - roslz4:            1.15.14
        - rosmake:           1.15.8
        - rosmaster:         1.15.14
        - rosmsg:            1.15.14
        - rosnode:           1.15.14
        - rosparam:          1.15.14
        - rospkg:            1.4.0
        - rospy:             1.15.14
        - rosservice:        1.15.14
        - rostest:           1.15.14
        - rostopic:          1.15.14
        - rosunit:           1.15.8
        - roswtf:            1.15.14
        - rqt-action:        0.4.9
        - rqt-bag:           0.5.1
        - rqt-bag-plugins:   0.5.1
        - rqt-console:       0.4.11
        - rqt-dep:           0.4.12
        - rqt-graph:         0.4.14
        - rqt-gui:           0.5.3
        - rqt-gui-py:        0.5.3
        - rqt-image-view:    0.4.16
        - rqt-launch:        0.4.9
        - rqt-logger-level:  0.4.11
        - rqt-moveit:        0.5.10
        - rqt-msg:           0.4.10
        - rqt-nav-view:      0.5.7
        - rqt-plot:          0.4.13
        - rqt-pose-view:     0.5.11
        - rqt-publisher:     0.4.10
        - rqt-py-common:     0.5.3
        - rqt-py-console:    0.4.10
        - rqt-reconfigure:   0.5.5
        - rqt-robot-dashboard: 0.5.8
        - rqt-robot-monitor: 0.5.14
        - rqt-robot-steering: 0.5.12
        - rqt-runtime-monitor: 0.5.9
        - rqt-service-caller: 0.4.10
        - rqt-shell:         0.4.11
        - rqt-srv:           0.4.9
        - rqt-tf-tree:       0.6.3
        - rqt-top:           0.4.10
        - rqt-topic:         0.4.13
        - rqt-web:           0.4.10
        - rsa:               4.8
        - ruamel-yaml-conda: 0.15.80
        - sacremoses:        0.0.53
        - scikit-learn:      0.24.2
        - scipy:             1.6.3
        - semantic-version:  2.10.0
        - send2trash:        1.8.0
        - sensor-msgs:       1.13.1
        - sentry-sdk:        1.10.1
        - setuptools:        58.0.0
        - shapely:           1.8.5.post1
        - shellingham:       1.4.0
        - six:               1.16.0
        - smach:             2.5.0
        - smach-ros:         2.5.0
        - smart-open:        5.2.1
        - smclib:            1.8.6
        - snowballstemmer:   2.2.0
        - sortedcontainers:  2.4.0
        - soundfile:         0.10.3.post1
        - soupsieve:         2.3.1
        - spacy:             3.3.1
        - spacy-legacy:      3.0.9
        - spacy-loggers:     1.0.2
        - sphinx:            5.0.1
        - sphinx-glpi-theme: 0.3
        - sphinx-rtd-theme:  1.0.0
        - sphinxcontrib-applehelp: 1.0.2
        - sphinxcontrib-devhelp: 1.0.2
        - sphinxcontrib-htmlhelp: 2.0.0
        - sphinxcontrib-jsmath: 1.0.1
        - sphinxcontrib-qthelp: 1.0.3
        - sphinxcontrib-serializinghtml: 1.1.5
        - srsly:             2.4.3
        - stack-data:        0.2.0
        - tabulate:          0.8.9
        - tblib:             1.7.0
        - tensorboard:       2.9.1
        - tensorboard-data-server: 0.6.1
        - tensorboard-plugin-wit: 1.8.1
        - tensorrt:          8.2.5.1
        - termcolor:         2.0.1
        - terminado:         0.15.0
        - text-unidecode:    1.3
        - tf:                1.13.2
        - tf-conversions:    1.13.2
        - tf2-geometry-msgs: 0.7.5
        - tf2-kdl:           0.7.5
        - tf2-py:            0.7.5
        - tf2-ros:           0.7.5
        - thinc:             8.0.17
        - threadpoolctl:     3.1.0
        - timm:              0.6.7
        - tinycss2:          1.1.1
        - toml:              0.10.2
        - tomli:             2.0.1
        - toolz:             0.11.2
        - topic-tools:       1.15.14
        - torch:             1.13.0a0+340c412
        - torch-tensorrt:    1.1.0a0
        - torchmetrics:      0.9.3
        - torchtext:         0.13.0a0
        - torchvision:       0.13.0a0
        - tornado:           6.1
        - tqdm:              4.64.0
        - traitlets:         5.2.2.post1
        - treelite:          2.3.0
        - treelite-runtime:  2.3.0
        - typer:             0.4.1
        - typing-extensions: 4.2.0
        - ucx-py:            0.25.0a0+13.ga16f8a2
        - uff:               0.6.9
        - urllib3:           1.26.12
        - wasabi:            0.9.1
        - wcwidth:           0.2.5
        - webencodings:      0.5.1
        - websocket-client:  1.3.3
        - werkzeug:          2.1.2
        - wheel:             0.37.1
        - widgetsnbextension: 4.0.3
        - wrapt:             1.14.1
        - wurlitzer:         3.0.2
        - xacro:             1.14.13
        - xgboost:           1.5.2
        - yarl:              1.8.1
        - zict:              2.2.0
        - zipp:              3.8.0
* System:
        - OS:                Linux
        - architecture:
                - 64bit
                - ELF
        - processor:         x86_64
        - python:            3.8.13
        - version:           #147~18.04.1-Ubuntu SMP Sat Oct 15 13:10:18 UTC 2022

More info

No response

cc @awaelchli @rohitgr7 @akihironitta @carmocca @edward-io @ananthsub @Blaizzy

carmocca commented 2 years ago

A rank is hanging. Can you create a reproducible snippet by adapting this script? https://github.com/Lightning-AI/lightning/blob/master/examples/pl_bug_report/bug_report_model.py

Also, how many devices are you using? I see your machine has 2 different GPU types. Does it happen if you only use the RTXs?

stale[bot] commented 1 year ago

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions - the Lightning Team!