facebookresearch / detectron2

Detectron2 is a platform for object detection, segmentation and other visual recognition tasks.
https://detectron2.readthedocs.io/en/latest/
Apache License 2.0

Unexpected behaviour when replicating FPN and ResNet using DataParallel #4019

Open jbrachat opened 2 years ago

jbrachat commented 2 years ago

Instructions To Reproduce the Issue:

Running any training script in DataParallel mode on two or more devices eventually causes self.to(device) to be called on FPN and ResNet. Unfortunately, those classes store some of their modules in plain Python lists rather than nn.ModuleList (see for instance FPN.lateral_convs, FPN.output_convs or ResNet.stages), and those modules are not moved when self.to(device) is called, which later triggers an error when self.forward is called with a tensor on another device.
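
To illustrate the general problem, here is a minimal sketch in plain PyTorch (not detectron2 code; the class names are made up): a submodule held only in a plain Python list is never registered as a child, so nn.Module machinery such as .parameters(), .to() and DataParallel replication cannot see it, whereas nn.ModuleList registers it properly.

import torch.nn as nn

class PlainListBlock(nn.Module):
    # Toy module: the conv only lives in a plain Python list,
    # so nn.Module never registers it as a child module.
    def __init__(self):
        super().__init__()
        self.convs = [nn.Conv2d(3, 8, 3)]

class ModuleListBlock(nn.Module):
    # Same layer, but nn.ModuleList registers it as a child.
    def __init__(self):
        super().__init__()
        self.convs = nn.ModuleList([nn.Conv2d(3, 8, 3)])

print(len(list(PlainListBlock().parameters())))   # 0 -- invisible to .parameters()/.to()/replicate
print(len(list(ModuleListBlock().parameters())))  # 2 -- weight and bias are registered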

Expected behavior:

The expected behavior is that all submodules involved in the forward methods of FPN and ResNet are properly moved when the .to() method is called. Using nn.ModuleList instead of a plain list would be a good start towards solving this issue.

Your environment:

platform: win-64

abseil-cpp=20210324.1=h0e60522_0 absl-py=1.0.0=pyhd8ed1ab_0 aiohttp=3.7.0=py37h4ab8f01_0 alembic=1.7.6=pyhd8ed1ab_0 antlr4-python3-runtime=4.8=pypi_0 appdirs=1.4.4=pypi_0 arrow-cpp=4.0.0=py37hb1a8454_3_cpu async-timeout=3.0.1=py_1000 attrs=21.4.0=pyhd8ed1ab_0 autopage=0.5.0=pyhd8ed1ab_0 aws-c-cal=0.5.11=he19cf47_0 aws-c-common=0.6.2=h8ffe710_0 aws-c-event-stream=0.2.7=h70e1b0c_13 aws-c-io=0.10.5=h2fe331c_0 aws-checksums=0.1.11=h1e232aa_7 aws-sdk-cpp=1.8.186=hb0612c5_3 backports=1.0=py_2 backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 black=21.4b2=pypi_0 blas=1.0=mkl blinker=1.4=py_1 brotli=1.0.9=h8ffe710_6 brotli-bin=1.0.9=h8ffe710_6 brotli-python=1.0.9=py37hf2a7229_6 brotlicffi=1.0.9.2=py37hf2a7229_1 brotlipy=0.7.0=py37hcc03f2d_1003 bzip2=1.0.8=h8ffe710_4 c-ares=1.18.1=h8ffe710_0 ca-certificates=2020.10.14=0 cachetools=5.0.0=pyhd8ed1ab_0 certifi=2020.6.20=py37_0 cffi=1.15.0=py37hd8e9650_0 chardet=3.0.4=py37hf50a25e_1008 charset-normalizer=2.0.12=pyhd8ed1ab_0 click=8.0.4=py37h03978a9_0 cliff=3.10.1=pyhd8ed1ab_0 cloudpickle=2.0.0=pypi_0 cmaes=0.8.2=pyh44b312d_0 cmd2=2.3.3=py37h03978a9_1 colorama=0.4.4=pyh9f0ad1d_0 colorlog=6.6.0=py37h03978a9_0 conllu=4.4.1=pyhd8ed1ab_0 cryptography=36.0.1=py37h65266a2_0 cudatoolkit=11.3.1=h59b6b97_2 cycler=0.11.0=pyhd8ed1ab_0 dataclasses=0.8=pyhc8e2a94_3 datasets=1.11.0=pyhd8ed1ab_0 detectron2=0.6=dev_0 dill=0.3.4=pyhd8ed1ab_0 et_xmlfile=1.0.1=py_1001 filelock=3.6.0=pyhd8ed1ab_0 fonttools=4.29.1=py37hcc03f2d_0 freetype=2.10.4=hd328e21_0 fsspec=2022.2.0=pyhd8ed1ab_0 future=0.18.2=pypi_0 fvcore=0.1.5.post20220212=pypi_0 gflags=2.2.2=ha925a31_1004 glog=0.5.0=h4797de2_0 google-auth=2.6.0=pyh6c4a22f_1 google-auth-oauthlib=0.4.6=pyhd8ed1ab_0 greenlet=1.1.2=py37hf2a7229_1 grpc-cpp=1.37.1=h586195c_2 grpcio=1.42.0=py37hc60d5dd_0 huggingface_hub=0.1.0=pyhd8ed1ab_0 hydra-core=1.1.1=pypi_0 icu=68.2=h0e60522_0 idna=3.3=pyhd8ed1ab_0 importlib-metadata=4.11.2=py37h03978a9_0 importlib_metadata=4.11.2=hd8ed1ab_0 importlib_resources=5.4.0=pyhd8ed1ab_0 intel-openmp=2021.4.0=haa95532_3556 iopath=0.1.9=pypi_0 joblib=1.1.0=pyhd8ed1ab_0 jpeg=9d=h2bbff1b_0 kiwisolver=1.3.2=py37h8c56517_1 krb5=1.19.2=h1176d77_4 libblas=3.9.0=1_h8933c1f_netlib libbrotlicommon=1.0.9=h8ffe710_6 libbrotlidec=1.0.9=h8ffe710_6 libbrotlienc=1.0.9=h8ffe710_6 libcblas=3.9.0=5_hd5c7e75_netlib libclang=11.1.0=default_h5c34c98_1 libcurl=7.79.1=h789b8ee_1 libiconv=1.16=he774522_0 liblapack=3.9.0=5_hd5c7e75_netlib libpng=1.6.37=h2a8f88b_0 libprotobuf=3.16.0=h7755175_0 libssh2=1.10.0=h680486a_2 libthrift=0.14.1=h636ae23_2 libtiff=4.2.0=hd0e1b90_0 libutf8proc=2.7.0=hcb41399_0 libuv=1.40.0=he774522_0 libwebp=1.2.2=h2bbff1b_0 libxml2=2.9.12=hf5bbc77_0 libxslt=1.1.33=h65864e5_2 lxml=4.8.0=py37hd07aab1_0 lz4-c=1.9.3=h2bbff1b_1 m2w64-gcc-libgfortran=5.3.0=6 m2w64-gcc-libs=5.3.0=7 m2w64-gcc-libs-core=5.3.0=7 m2w64-gmp=6.1.0=2 m2w64-libwinpthread-git=5.0.0.4634.697f757=2 mako=1.1.6=pyhd8ed1ab_0 markdown=3.3.6=pyhd8ed1ab_0 markupsafe=2.1.0=py37hcc03f2d_0 matplotlib=3.5.1=py37h03978a9_0 matplotlib-base=3.5.1=py37h4a79c79_0 mkl=2021.4.0=haa95532_640 mkl-service=2.4.0=py37h2bbff1b_0 mkl_fft=1.3.1=py37h277e83a_0 mkl_random=1.2.2=py37hf11a4ad_0 msys2-conda-epoch=20160418=1 multidict=6.0.2=py37hcc03f2d_0 multiprocess=0.70.12.2=py37hcc03f2d_1 multivolumefile=0.2.3=pyhd8ed1ab_0 munkres=1.1.4=pyh9f0ad1d_0 mypy-extensions=0.4.3=pypi_0 numpy=1.21.5=py37ha4e8547_0 numpy-base=1.21.5=py37hc2deb75_0 oauthlib=3.2.0=pyhd8ed1ab_0 olefile=0.46=py37_0 omegaconf=2.1.1=pypi_0 openpyxl=3.0.9=pyhd8ed1ab_0 openssl=1.1.1m=h2bbff1b_0 
optuna=2.10.0=pyhd8ed1ab_0 packaging=21.3=pyhd8ed1ab_0 pandas=1.1.3=py37ha925a31_0 parquet-cpp=1.5.1=2 pathspec=0.9.0=pypi_0 pbr=5.8.1=pyhd8ed1ab_0 pillow=8.4.0=py37hd45dc43_0 pip=21.2.4=py37haa95532_0 plotly=5.6.0=py_0 portalocker=2.4.0=pypi_0 prettytable=3.1.1=pyhd8ed1ab_0 protobuf=3.16.0=py37hf2a7229_0 py7zr=0.17.4=pyhd8ed1ab_1 pyarrow=4.0.0=py37h0b73db8_3_cpu pyasn1=0.4.8=py_0 pyasn1-modules=0.2.7=py_0 pybcj=0.5.0=py37hcc03f2d_2 pybcpy=0.0.17=pyhd8ed1ab_0 pycocotools=2.0.4=pypi_0 pycparser=2.21=pyhd8ed1ab_0 pycryptodomex=3.14.1=py37hcc03f2d_0 pydot=1.4.2=pypi_0 pyjwt=2.3.0=pyhd8ed1ab_1 pyopenssl=22.0.0=pyhd8ed1ab_0 pyparsing=3.0.7=pyhd8ed1ab_0 pyperclip=1.8.2=pyhd8ed1ab_2 pyppmd=0.17.3=py37hf2a7229_1 pyqt=5.12.3=py37h03978a9_8 pyqt-impl=5.12.3=py37hf2a7229_8 pyqt5-sip=4.19.18=py37hf2a7229_8 pyqtchart=5.12=py37hf2a7229_8 pyqtwebengine=5.12.1=py37hf2a7229_8 pyreadline=2.1=py37h03978a9_1005 pysocks=1.7.1=py37h03978a9_4 python=3.7.11=h6244533_0 python-dateutil=2.8.1=py_0 python-xxhash=3.0.0=py37hcc03f2d_0 python_abi=3.7=2_cp37m pytorch=1.10.2=py3.7_cuda11.3_cudnn8_0 pytorch-mutex=1.0=cuda pytz=2020.1=py_0 pyu2f=0.1.5=pyhd8ed1ab_0 pywin32=303=pypi_0 pyyaml=6.0=py37hcc03f2d_3 pyzstd=0.15.0=py37hcc03f2d_0 qt=5.12.9=h5909a2a_4 re2=2021.04.01=h0e60522_0 regex=2022.3.2=py37hcc03f2d_0 requests=2.27.1=pyhd8ed1ab_0 requests-oauthlib=1.3.1=pyhd8ed1ab_0 rsa=4.8=pyhd8ed1ab_0 sacremoses=0.0.46=pyhd8ed1ab_0 scikit-learn=1.0.2=py37hcabfae0_0 scipy=1.7.3=py37hb6553fb_0 seaborn=0.11.0=py_0 sentencepiece=0.1.96=py37h8c56517_0 seqeval=1.2.2=pyhd3deb0d_0 setuptools=58.0.4=py37haa95532_0 six=1.16.0=pyhd3eb1b0_1 snappy=1.1.8=ha925a31_3 sqlalchemy=1.4.31=py37hcc03f2d_0 sqlite=3.37.2=h2bbff1b_0 stevedore=3.5.0=py37h03978a9_2 tabulate=0.8.9=pypi_0 tenacity=8.0.1=py37haa95532_0 tensorboard=2.8.0=pyhd8ed1ab_1 tensorboard-data-server=0.6.0=py37h03978a9_1 tensorboard-plugin-wit=1.8.1=pyhd8ed1ab_0 tensorboardx=2.5=pyhd8ed1ab_0 termcolor=1.1.0=pypi_0 texttable=1.6.4=pyhd8ed1ab_0 threadpoolctl=3.1.0=pyh8a188c0_0 tk=8.6.11=h2bbff1b_0 tokenizers=0.10.3=py37h537c2b9_1 toml=0.10.2=pypi_0 torchvision=0.11.3=py37_cu113 tornado=6.1=py37hcc03f2d_2 tqdm=4.49.0=pyh9f0ad1d_0 transformers=4.16.2=pyhd8ed1ab_0 typed-ast=1.5.2=pypi_0 typing-extensions=3.10.0.2=hd3eb1b0_0 typing_extensions=3.10.0.2=pyh06a4308_0 unicodedata2=14.0.0=py37hcc03f2d_0 urllib3=1.26.8=pyhd8ed1ab_1 vc=14.2=h21ff451_1 vs2015_runtime=14.27.29016=h5e58377_2 wcwidth=0.2.5=pyh9f0ad1d_2 werkzeug=2.0.3=pyhd8ed1ab_1 wheel=0.37.1=pyhd3eb1b0_0 win_inet_pton=1.1.0=py37h03978a9_3 wincertstore=0.2=py37haa95532_2 xxhash=0.8.0=h8ffe710_3 xz=5.2.5=h62dcd97_0 yacs=0.1.8=pypi_0 yaml=0.2.5=h8ffe710_2 yarl=1.6.0=py37h4ab8f01_0 zipp=3.7.0=pyhd8ed1ab_1 zlib=1.2.11=h8cc25b3_4 zstd=1.4.9=h19a0ad4_0

github-actions[bot] commented 2 years ago

You've chosen to report an unexpected problem or bug. Unless you already know the root cause of it, please include details about it by filling the issue template. The following information is missing: "Your Environment";

ppwwyyxx commented 2 years ago

I believe model.to(device) works correctly for FPN and ResNet. If you believe otherwise, please show code that reproduces the issue, as the issue template requested.

jbrachat commented 2 years ago

Hello @ppwwyyxx

Thank you for your reply. You are right, .to() works perfectly fine. I spent some time investigating a bit further, and the issue actually comes from nn.DataParallel when replicas of FPN are created on several GPUs. Here is the code I am using on a 2-GPU (Windows) machine:

import torch
import torch.nn as nn
from transformers import LayoutLMv2Config, LayoutLMv2ForTokenClassification

# Build LayoutLMv2 and pull out its detectron2 FPN visual backbone.
model_config = LayoutLMv2Config.from_pretrained(r"microsoft/layoutlmv2-base-uncased")
layout_lm_v2 = LayoutLMv2ForTokenClassification.from_pretrained(r"microsoft/layoutlmv2-base-uncased", config=model_config)
fpn_model = layout_lm_v2.layoutlmv2.visual.backbone

device = torch.device("cuda:0")
fpn_model.to(device)

# Replicate the backbone on GPUs 0 and 1, as DataParallel does internally on every forward pass.
model_parallel = nn.DataParallel(fpn_model)
replicas = model_parallel.replicate(model_parallel.module, [0, 1])

# The same conv, reached through the plain Python list vs. its registered attribute name.
second_replica = replicas[1]
print(second_replica.lateral_convs[0].bias.device)
print(second_replica.fpn_lateral2.bias.device)

Here is what comes out:

cuda:0
cuda:1

You can see that the first conv2d layer stored in the list lateral_convs is not properly moved and remains on cuda:0. If I now, in fpn.py, replace:

self.lateral_convs = lateral_convs[::-1] --> self.lateral_convs = nn.ModuleList(lateral_convs[::-1])

the layer is then moved to the proper device.
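
To illustrate what seems to be happening (a toy sketch, not detectron2 code, and it needs two GPUs like the snippet above): DataParallel.replicate rebuilds only the registered child modules of each replica, while plain Python attributes such as the list are copied by reference and keep pointing at the original cuda:0 convs, so the same layer appears on different devices depending on how you reach it.

import torch
import torch.nn as nn

class ToyFPN(nn.Module):
    # Toy mirroring the FPN pattern: the same conv is registered under an
    # attribute name (so .to() works) and also kept in a plain Python list.
    def __init__(self):
        super().__init__()
        conv = nn.Conv2d(3, 8, 3)
        self.fpn_lateral2 = conv      # registered child module
        self.lateral_convs = [conv]   # plain list, copied by reference on replication

toy = ToyFPN().to("cuda:0")
replica = nn.DataParallel(toy).replicate(toy, [0, 1])[1]
print(replica.fpn_lateral2.bias.device)      # cuda:1 -- the attribute was replicated
print(replica.lateral_convs[0].bias.device)  # cuda:0 -- the list still holds the original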

PS: sorry for the unconventional way of instantiating detectron2 through HuggingFace, but I expect you would observe the same behaviour if you initialize it differently.
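
As a possible interim workaround (an untested sketch; attribute names assumed from the current fpn.py and resnet.py), the plain lists could be re-wrapped on the existing instance before handing it to nn.DataParallel. The ModuleLists share the same conv objects, so no parameters are duplicated in memory, although state_dict() will then contain the affected weights under two keys.

# Untested sketch: re-register the plain lists as ModuleLists so that
# DataParallel.replicate treats their contents as child modules.
fpn_model.lateral_convs = nn.ModuleList(fpn_model.lateral_convs)
fpn_model.output_convs = nn.ModuleList(fpn_model.output_convs)
fpn_model.bottom_up.stages = nn.ModuleList(fpn_model.bottom_up.stages)  # same pattern inside ResNet

model_parallel = nn.DataParallel(fpn_model)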