rwth-i6 / returnn

The RWTH extensible training framework for universal recurrent neural networks
http://returnn.readthedocs.io/

PyTorch CUDA OOM in distributed training #1482

Open albertz opened 10 months ago

albertz commented 10 months ago
RETURNN starting up, version 1.20231230.164342+git.f353135e, date/time 2023-12-31-13-21-05 (UTC+0000), pid 2003528, cwd /work/asr4/zeyer/setups-data/combined/2021-05-31/work/i6_core/returnn/training/ReturnnTrainingJob.lmbYlKeoU6kT/work, Python /work/tools/users/zeyer/py-envs/py3.11-torch2.1/bin/python3.11
RETURNN command line options: ['/u/zeyer/setups/combined/2021-05-31/work/i6_core/returnn/training/ReturnnTrainingJob.lmbYlKeoU6kT/output/returnn.config']
...
Torch: Hostname cn-236, pid 2003531, using GPU 3.
CUDA_VISIBLE_DEVICES is set to '0,1,2,3'.
Available CUDA devices:
  1/4: cuda:0
       name: NVIDIA GeForce GTX 1080 Ti
       total_memory: 10.9GB
       capability: 6.1
       device_index: 0
  2/4: cuda:1
       name: NVIDIA GeForce GTX 1080 Ti
       total_memory: 10.9GB
       capability: 6.1
       device_index: 1
  3/4: cuda:2
       name: NVIDIA GeForce GTX 1080 Ti
       total_memory: 10.9GB
       capability: 6.1
       device_index: 2
  4/4: cuda:3
       name: NVIDIA GeForce GTX 1080 Ti
       total_memory: 10.9GB
       capability: 6.1
       device_index: 3
PyTorch: 2.1.0+cu121 (7bcf7da3a268b435777fe87c7794c382f444e86d) (<site-package> in /work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch)
...
ep 1 train, step 94, acc 0.005, loss 8.755, loss_att 9.005, loss_ctc 8.171, total 8.755, mem_usage:cuda:1 7.4GB
ep 1 train, step 94, acc 0.003, loss 8.599, loss_att 8.744, loss_ctc 8.261, total 8.599, mem_usage:cuda:0 7.4GB
ep 1 train, step 94, acc 0.003, loss 8.492, loss_att 8.639, loss_ctc 8.152, total 8.492, mem_usage:cuda:2 7.3GB
ep 1 train, step 95, acc 0.004, loss 8.657, loss_att 8.896, loss_ctc 8.100, total 8.657, mem_usage:cuda:1 7.4GB
ep 1 train, step 95, acc 0.007, loss 9.245, loss_att 9.607, loss_ctc 8.400, total 9.245, mem_usage:cuda:3 7.4GB
ep 1 train, step 95, acc 0.003, loss 8.452, loss_att 8.596, loss_ctc 8.116, total 8.452, mem_usage:cuda:0 7.4GB
ep 1 train, step 95, acc 0.004, loss 8.648, loss_att 8.824, loss_ctc 8.238, total 8.648, mem_usage:cuda:2 7.3GB
MEMORY: sub proc watch memory(2003717) increased RSS: rss=52.4MB pss=30.8MB uss=30.6MB shared=21.7MB
ep 1 train, step 96, acc 0.003, loss 8.667, loss_att 8.789, loss_ctc 8.382, total 8.667, mem_usage:cuda:3 7.4GB
ep 1 train, step 96, acc 0.001, loss 8.325, loss_att 8.352, loss_ctc 8.261, total 8.325, mem_usage:cuda:1 7.4GB
ep 1 train, step 96, acc 0.005, loss 8.874, loss_att 9.176, loss_ctc 8.168, total 8.874, mem_usage:cuda:0 7.4GB
ep 1 train, step 96, acc 0.001, loss 8.337, loss_att 8.333, loss_ctc 8.346, total 8.337, mem_usage:cuda:2 7.3GB
MEMORY: total (main 2003529, 2023-12-31, 13:25:48, 20 procs): pss=8.4GB uss=6.2GB
ep 1 train, step 97, acc 0.003, loss 8.599, loss_att 8.779, loss_ctc 8.178, total 8.599, mem_usage:cuda:0 7.4GB
ep 1 train, step 97, acc 0.003, loss 8.641, loss_att 8.805, loss_ctc 8.257, total 8.641, mem_usage:cuda:3 7.4GB
ep 1 train, step 97, acc 0.004, loss 8.589, loss_att 8.768, loss_ctc 8.170, total 8.589, mem_usage:cuda:1 7.4GB
ep 1 train, step 97, acc 0.003, loss 8.568, loss_att 8.694, loss_ctc 8.272, total 8.568, mem_usage:cuda:2 7.3GB
ep 1 train, step 98, acc 0.001, loss 8.286, loss_att 8.344, loss_ctc 8.151, total 8.286, mem_usage:cuda:0 7.4GB
ep 1 train, step 98, acc 0.005, loss 8.941, loss_att 9.186, loss_ctc 8.369, total 8.941, mem_usage:cuda:3 7.4GB
ep 1 train, step 98, acc 0.002, loss 8.423, loss_att 8.466, loss_ctc 8.324, total 8.423, mem_usage:cuda:1 7.4GB
ep 1 train, step 98, acc 0.004, loss 8.574, loss_att 8.728, loss_ctc 8.216, total 8.574, mem_usage:cuda:2 7.3GB
Unhandled exception <class 'RuntimeError'> in thread <_MainThread(MainThread, started 139662461636608)>, proc 2003528.
...
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/torch/distributed.py", line 98, in DistributedContext.step_after_param_update
    line: _sync_params_avg(module=module)
    locals:
      _sync_params_avg = <global> <function _sync_params_avg at 0x7f04dc6ba520>
      module = <local> ESPnetASRModel(
                         (frontend): DefaultFrontend(
                           (stft): Stft(n_fft=512, win_length=512, hop_length=160, center=True, normalized=False, onesided=True)
                           (frontend): Frontend()
                           (logmel): LogMel(sr=16000, n_fft=512, n_mels=80, fmin=0, fmax=8000.0, htk=False)
                         )
                         (specaug): SpecAug(
                           (t...
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/torch/distributed.py", line 152, in _sync_params_avg
    line: dist.all_reduce(param.data, op=reduce_op)
    locals:
      dist = <local> <module 'torch.distributed' from '/work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/distributed/__init__.py'>
      dist.all_reduce = <local> <function all_reduce at 0x7f05078fcae0>
      param = <local> Parameter containing:
                      Parameter[512, 1, 3, 3] n=4608 (18Kb) x∈[-0.333, 0.333] μ=0.000 σ=0.195 grad cuda:0
      param.data = <local> tensor[512, 1, 3, 3] n=4608 (18Kb) x∈[-0.333, 0.333] μ=0.000 σ=0.195 cuda:0
      op = <not found>
      reduce_op = <local> <RedOpType.AVG: 1>
  File "/work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/distributed/c10d_logger.py", line 47, in send
    line: return func(*args, **kwargs)
    locals:
      func = <local> <function all_reduce at 0x7f05078fca40>
      args = <local> (tensor[512, 1, 3, 3] n=4608 (18Kb) x∈[-0.333, 0.333] μ=0.000 σ=0.195 cuda:0,)
      kwargs = <local> {'op': <RedOpType.AVG: 1>}
  File "/work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 2050, in all_reduce
    line: work = group.allreduce([tensor], opts)
    locals:
      work = <not found>
      group = <local> <torch.distributed.distributed_c10d.ProcessGroup object at 0x7f04de15dc70>
      group.allreduce = <local> <bound method PyCapsule.allreduce of <torch.distributed.distributed_c10d.ProcessGroup object at 0x7f04de15dc70>>
      tensor = <local> tensor[512, 1, 3, 3] n=4608 (18Kb) x∈[-0.333, 0.333] μ=0.000 σ=0.195 cuda:0
      opts = <local> <torch.distributed.distributed_c10d.AllreduceOptions object at 0x7f059dce46f0>
RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Note that RuntimeError: CUDA error: out of memory is not the usual OutOfMemoryError exception (which also provides some stats on reserved memory etc.); this one comes from torch distributed and unfortunately lacks any further stats.

It's a bit strange: looking at the training log right before the OOM, it uses around 7.4GB (allocated, so somewhat more reserved), and according to the initial log, all the device memory seems to be available?
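
Since the distributed error itself carries no memory stats, one could at least log the allocator and driver numbers right before the param sync. A debugging sketch using standard torch.cuda APIs:

```python
import torch


def log_cuda_mem(prefix: str, device: int = 0) -> None:
    """Print PyTorch allocator stats plus free/total memory as seen by the driver."""
    alloc = torch.cuda.memory_allocated(device) / 1e9
    reserved = torch.cuda.memory_reserved(device) / 1e9
    free, total = (x / 1e9 for x in torch.cuda.mem_get_info(device))
    print(f"{prefix}: allocated {alloc:.2f}GB, reserved {reserved:.2f}GB, "
          f"free {free:.2f}GB / total {total:.2f}GB")
```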

albertz commented 10 months ago

This is very deterministic: when I restart, I get exactly the same crash, also on other nodes.

albertz commented 10 months ago

Potentially related: https://github.com/pytorch/pytorch/issues/116177 https://github.com/NVlabs/tiny-cuda-nn/issues/387 https://github.com/NVIDIA/nccl/issues/962

albertz commented 10 months ago

I get the same problem with the Gloo backend as well, i.e. also a CUDA OOM, although then it crashes in a different way, with an abort.

...
ep 1 train, step 97, acc 0.004, loss 8.624, loss_att 8.769, loss_ctc 8.285, total 8.624, mem_usage:cuda:2 8.8GB, 0.855 sec/step
ep 1 train, step 98, acc 0.007, loss 9.071, loss_att 9.377, loss_ctc 8.356, total 9.071, mem_usage:cuda:1 8.6GB, 0.797 sec/step
ep 1 train, step 98, acc 0.004, loss 8.664, loss_att 8.846, loss_ctc 8.239, total 8.664, mem_usage:cuda:3 8.5GB, 0.801 sec/step
ep 1 train, step 98, acc 0.005, loss 8.674, loss_att 8.856, loss_ctc 8.248, total 8.674, mem_usage:cuda:0 8.9GB, 0.892 sec/step
ep 1 train, step 98, acc 0.003, loss 8.459, loss_att 8.575, loss_ctc 8.190, total 8.459, mem_usage:cuda:2 8.8GB, 0.834 sec/step
terminate called after throwing an instance of 'c10::Error'
  what():  CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:44 (most recent call first): 
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7fda36535617 in /work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/lib/libc10.so) 
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7fda364f098d in /work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/lib/libc10.so) 
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7fda365f09f8 in /work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/lib/libc10_cuda.so) 
frame #3: <unknown function> + 0x1d104 (0x7fda365c0104 in /work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/lib/libc10_cuda.so)
frame #4: <unknown function> + 0x4bc384a (0x7fd9e5be384a in /work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so)
frame #5: <unknown function> + 0x559d0a8 (0x7fd9e65bd0a8 in /work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so)
frame #6: c10d::ProcessGroupGloo::AsyncWork::execute(c10::intrusive_ptr<c10d::ProcessGroupGloo::AsyncWork, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroupGloo::AsyncWork> >) + 0x3b (0x7fd9e65cbf8b in /work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so)
frame #7: c10d::ProcessGroupGloo::runLoop(int) + 0xe9 (0x7fd9e65cc099 in /work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so)
frame #8: <unknown function> + 0xdba24 (0x7fda369dda24 in /work/tools/users/zeyer/linuxbrew/lib/gcc/11/libstdc++.so.6)
frame #9: <unknown function> + 0x8523e (0x7fda6157023e in /work/tools/users/zeyer/linuxbrew/opt/glibc/lib/libc.so.6)
frame #10: <unknown function> + 0x10617c (0x7fda615f117c in /work/tools/users/zeyer/linuxbrew/opt/glibc/lib/libc.so.6)

Fatal Python error: Aborted

Thread 0x00007fda4f220640 (most recent call first):
  <no Python frame>

Thread 0x00007fd90f6ae640 (most recent call first):
  <no Python frame>

Thread 0x00007fd90cead640 (most recent call first):
  <no Python frame>

Thread 0x00007fd911eaf640 (most recent call first):
  <no Python frame>

Thread 0x00007fd9006ac640 (most recent call first):
  File "/work/tools/users/zeyer/linuxbrew/opt/python@3.11/lib/python3.11/threading.py", line 320 in wait
  File "/work/tools/users/zeyer/linuxbrew/opt/python@3.11/lib/python3.11/multiprocessing/queues.py", line 231 in _feed
  File "/work/tools/users/zeyer/linuxbrew/opt/python@3.11/lib/python3.11/threading.py", line 975 in run
  File "/work/tools/users/zeyer/linuxbrew/opt/python@3.11/lib/python3.11/threading.py", line 1038 in _bootstrap_inner
  File "/work/tools/users/zeyer/linuxbrew/opt/python@3.11/lib/python3.11/threading.py", line 995 in _bootstrap

Thread 0x00007fda614ea000 (most recent call first):
  File "/work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 2055 in all_reduce
  File "/work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/distributed/c10d_logger.py", line 47 in wrapper
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/torch/distributed.py", line 160 in _sync_params_avg
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/torch/distributed.py", line 99 in step_after_param_update
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/torch/engine.py", line 389 in train_epoch
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/torch/engine.py", line 239 in train
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/__main__.py", line 465 in execute_main_task
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/__main__.py", line 659 in main
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/rnn.py", line 11 in <module>

Extension modules: numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._c
...
Signal handler: signal 6:
/var/tmp/zeyer/returnn_native/native_signal_handler/476dd6f1a7/native_signal_handler.so(signal_handler+0x4b)[0x7fda2a87320b]
/work/tools/users/zeyer/linuxbrew/opt/glibc/lib/libc.so.6(+0x3cf40)[0x7fda61527f40]
/work/tools/users/zeyer/linuxbrew/opt/glibc/lib/libc.so.6(+0x86e6f)[0x7fda61571e6f]
/work/tools/users/zeyer/linuxbrew/opt/glibc/lib/libc.so.6(raise+0x12)[0x7fda61527ea2]
/work/tools/users/zeyer/linuxbrew/opt/glibc/lib/libc.so.6(+0x3cf40)[0x7fda61527f40]
/work/tools/users/zeyer/linuxbrew/opt/glibc/lib/libc.so.6(+0x86e6f)[0x7fda61571e6f]
/work/tools/users/zeyer/linuxbrew/opt/glibc/lib/libc.so.6(raise+0x12)[0x7fda61527ea2]
/work/tools/users/zeyer/linuxbrew/opt/glibc/lib/libc.so.6(abort+0xc2)[0x7fda6151345c]
/work/tools/users/zeyer/linuxbrew/lib/gcc/11/libstdc++.so.6(+0xa586a)[0x7fda369a786a]
/work/tools/users/zeyer/linuxbrew/lib/gcc/11/libstdc++.so.6(+0xb107a)[0x7fda369b307a]
/work/tools/users/zeyer/linuxbrew/lib/gcc/11/libstdc++.so.6(+0xb10e5)[0x7fda369b30e5]
/work/tools/users/zeyer/linuxbrew/lib/gcc/11/libstdc++.so.6(+0xb1338)[0x7fda369b3338]
/work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/lib/libc10.so(_ZN3c106detail14torchCheckFailEPKcS2_jRKSs+0x94)[0x7fda364f09bd]
/work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/lib/libc10_cuda.so(_ZN3c104cuda29c10_cuda_check_implementationEiPKcS2_ib+0x118)[0x7fda365f09f8]
/work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/lib/libc10_cuda.so(+0x1d104)[0x7fda365c0104]
/work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so(+0x4bc384a)[0x7fd9e5be384a]
/work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so(+0x559d0a8)[0x7fd9e65bd0a8]
/work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so(_ZN4c10d16ProcessGroupGloo9AsyncWork7executeEN3c1013intrusive_ptrIS1_NS2_6detail34intrusive_target_default_null_typeIS1_EEEE+0x3b)[0x7fd9e65cbf8b]
/work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so(_ZN4c10d16ProcessGroupGloo7runLoopEi+0xe9)[0x7fd9e65cc099]
/work/tools/users/zeyer/linuxbrew/lib/gcc/11/libstdc++.so.6(+0xdba24)[0x7fda369dda24]
/work/tools/users/zeyer/linuxbrew/opt/glibc/lib/libc.so.6(+0x8523e)[0x7fda6157023e]
/work/tools/users/zeyer/linuxbrew/opt/glibc/lib/libc.so.6(+0x10617c)[0x7fda615f117c]
terminate called after throwing an instance of 'c10::Error'
  what():  CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:44 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7f5418512617 in /work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7f54184cd98d in /work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7f54185cd9f8 in /work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/lib/libc10_cuda.so)
frame #3: <unknown function> + 0x1d104 (0x7f541859d104 in /work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/lib/libc10_cuda.so)
frame #4: <unknown function> + 0x4bc384a (0x7f53d37e384a in /work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so)
frame #5: <unknown function> + 0x559d0a8 (0x7f53d41bd0a8 in /work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so)
frame #6: c10d::ProcessGroupGloo::AsyncWork::execute(c10::intrusive_ptr<c10d::ProcessGroupGloo::AsyncWork, c10::detail::intrusive_target_default_null_type<c10d::ProcessGroupGloo::AsyncWork> >) + 0x3b (0x7f53d41cbf8b in /work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so)
frame #7: c10d::ProcessGroupGloo::runLoop(int) + 0xe9 (0x7f53d41cc099 in /work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/lib/libtorch_cpu.so)
frame #8: <unknown function> + 0xdba24 (0x7f54244b8a24 in /work/tools/users/zeyer/linuxbrew/lib/gcc/11/libstdc++.so.6)
frame #9: <unknown function> + 0x8523e (0x7f544f05123e in /work/tools/users/zeyer/linuxbrew/opt/glibc/lib/libc.so.6)
frame #10: <unknown function> + 0x10617c (0x7f544f0d217c in /work/tools/users/zeyer/linuxbrew/opt/glibc/lib/libc.so.6)

Fatal Python error: Aborted

Thread 0x00007f543add4640 (most recent call first):
  <no Python frame>

Thread 0x00007f52fd7af640 (most recent call first):
  <no Python frame>

Thread 0x00007f52f87ad640 (most recent call first):
  <no Python frame>

Thread 0x00007f52fafae640 (most recent call first):
  <no Python frame>

Thread 0x00007f52f5fac640 (most recent call first):
  File "/work/tools/users/zeyer/linuxbrew/opt/python@3.11/lib/python3.11/threading.py", line 320 in wait
  File "/work/tools/users/zeyer/linuxbrew/opt/python@3.11/lib/python3.11/multiprocessing/queues.py", line 231 in _feed
  File "/work/tools/users/zeyer/linuxbrew/opt/python@3.11/lib/python3.11/threading.py", line 975 in run
  File "/work/tools/users/zeyer/linuxbrew/opt/python@3.11/lib/python3.11/threading.py", line 1038 in _bootstrap_inner
  File "/work/tools/users/zeyer/linuxbrew/opt/python@3.11/lib/python3.11/threading.py", line 995 in _bootstrap

Thread 0x00007f544efcb000 (most recent call first):
  File "/work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/distributed/distributed_c10d.py", line 2055 in all_reduce
  File "/work/tools/users/zeyer/py-envs/py3.11-torch2.1/lib/python3.11/site-packages/torch/distributed/c10d_logger.py", line 47 in wrapper
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/torch/distributed.py", line 160 in _sync_params_avg
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/torch/distributed.py", line 99 in step_after_param_update
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/torch/engine.py", line 389 in train_epoch
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/torch/engine.py", line 239 in train
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/__main__.py", line 465 in execute_main_task
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/returnn/__main__.py", line 659 in main
  File "/u/zeyer/setups/combined/2021-05-31/tools/returnn/rnn.py", line 11 in <module>

Extension modules: numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._c
...

In this case, as you see, all the workers crash in the same way.

albertz commented 10 months ago

This is very deterministic: when I restart, I get exactly the same crash, also on other nodes.

I realized that this is using "torch_distributed": {"reduce_type": "param", "param_sync_step": 100}, and it had not yet printed the log output for the current step, which is step 99, so this is exactly the first step where it performs the param sync.
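
For context, a minimal sketch of the relevant RETURNN config entry (the surrounding config details are assumed here):

```python
# RETURNN config (a Python file): sketch of the distributed-training setting in use.
torch_distributed = {
    "reduce_type": "param",   # periodically average the full parameters across workers
    "param_sync_step": 100,   # perform the param sync every 100 steps
}
```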

albertz commented 10 months ago

One workaround is using the newly introduced torch_distributed sync_on_cpu=True option, which first moves all params to CPU, then does the sync (which then uses Gloo on CPU), and then moves them back to GPU.
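
Conceptually, that workaround does something like the following (a simplified sketch, not the actual RETURNN implementation; note that Gloo has no AVG op, so it sums and divides):

```python
import torch
import torch.distributed as dist


def sync_params_avg_via_cpu(module: torch.nn.Module, cpu_group: dist.ProcessGroup) -> None:
    """Average the parameters across workers on CPU (Gloo), then copy them back to GPU."""
    world_size = dist.get_world_size(group=cpu_group)
    for param in module.parameters():
        cpu_tensor = param.data.to("cpu")  # GPU -> CPU
        dist.all_reduce(cpu_tensor, op=dist.ReduceOp.SUM, group=cpu_group)
        cpu_tensor /= world_size  # SUM + divide == AVG
        param.data.copy_(cpu_tensor)  # CPU -> GPU
```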

But why does this work? What do NCCL/Gloo do differently when the param is on GPU? This is a GeForce GTX 1080, so there is no NVLink. I was therefore assuming it would internally move the tensor to CPU anyway, do the allreduce on CPU, and then copy it back to GPU. But probably not? Maybe it copies all params to CPU, sends them over the network to all workers, then copies each worker's copy of the param to GPU, so it holds num_workers copies of the param in memory, and then does the reduce (AVG or SUM) on GPU? That might explain it. But I was assuming that all_reduce is somewhat more clever, e.g. doing the reduction hierarchically, rather than using this naive logic, which is neither the most efficient nor memory-friendly.

albertz commented 10 months ago

Note, the 1080 has 10.9GB of memory, while all the parameters together take only 615.9MB.

The all_reduce is in blocking mode (just the default), and we do this separately for each parameter. The biggest parameter might be the embedding (512 x 100025), although that is not where it crashes. In any case, even if we had 4 copies of such a big parameter in memory, there should still be far more than enough memory available, so this does not really explain it.
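
A rough back-of-the-envelope check of those numbers (assuming float32 parameters):

```python
# Largest single parameter: the embedding, 512 x 100025, in float32 (4 bytes per value).
embedding_bytes = 512 * 100025 * 4
print(f"embedding: {embedding_bytes / 1e6:.1f} MB")      # ~204.9 MB
print(f"4 copies:  {4 * embedding_bytes / 1e6:.1f} MB")  # ~819.4 MB, far below 10.9 GB
```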

albertz commented 10 months ago

I also asked in the forums: https://discuss.pytorch.org/t/cuda-oom-in-distributed-training-without-nvlink/194704

albertz commented 1 month ago

Note, in https://github.com/pytorch/pytorch/issues/116177 (and https://github.com/NVIDIA/nccl/issues/1197), there was the hint to use NCCL_NVLS_ENABLE=0 as another workaround for this. (I have not tried this yet.)
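
If one wanted to try that, the env var just needs to be set before NCCL is initialized, e.g. from the startup code (sketch):

```python
import os

# Must be set before NCCL creates its communicator, i.e. before the first collective
# (or before an eagerly initializing init_process_group call).
os.environ["NCCL_NVLS_ENABLE"] = "0"
```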

albertz commented 1 month ago

With more NCCL debug info:

...
ep 13 train, step 97, ctc_4 4.683, ctc_8 4.594, ctc 4.674, aed_ce 5.543, aed_fer 0.817, num_seqs 9, max_size:time 241561, max_size:out-spatial 63, mem_usage:cuda:3 8.8GB, 0.888 sec/step
ep 13 train, step 98, ctc_4 4.578, ctc_8 4.435, ctc 4.499, aed_ce 5.339, aed_fer 0.823, num_seqs 12, max_size:time 156520, max_size:out-spatial 53, mem_usage:cuda:0 8.9GB, 0.812 sec/step
ep 13 train, step 98, ctc_4 4.089, ctc_8 3.892, ctc 3.930, aed_ce 5.001, aed_fer 0.776, num_seqs 10, max_size:time 220720, max_size:out-spatial 55, mem_usage:cuda:1 9.0GB, 0.882 sec/step
ep 13 train, step 98, ctc_4 4.404, ctc_8 4.170, ctc 4.206, aed_ce 5.589, aed_fer 0.831, num_seqs 9, max_size:time 243673, max_size:out-spatial 47, mem_usage:cuda:2 9.0GB, 0.872 sec/step
ep 13 train, step 98, ctc_4 3.731, ctc_8 3.489, ctc 3.528, aed_ce 5.102, aed_fer 0.806, num_seqs 9, max_size:time 241561, max_size:out-spatial 53, mem_usage:cuda:3 8.8GB, 0.872 sec/step
cn-241:3260080:3260080 [0] NCCL INFO Bootstrap : Using enp5s0:10.6.9.41<0>
cn-241:3260080:3260080 [0] NCCL INFO NET/Plugin : Plugin load (libnccl-net.so) returned 2 : libnccl-net.so: cannot open shared object file: No such file or directory
cn-241:3260080:3260080 [0] NCCL INFO NET/Plugin : No plugin found, using internal implementation
cn-241:3260080:3260080 [0] NCCL INFO cudaDriverVersion 12010
NCCL version 2.18.1+cuda12.1
cn-241:3260080:3262931 [0] NCCL INFO NET/IB : No device found.
cn-241:3260080:3262931 [0] NCCL INFO NET/Socket : Using [0]enp5s0:10.6.9.41<0>
cn-241:3260080:3262931 [0] NCCL INFO Using network Socket
DistBackendError: NCCL error in: ../torch/csrc/distributed/c10d/NCCLUtils.hpp:219, unhandled cuda error (run with NCCL_DEBUG=INFO for details), NCCL version 2.18.1
ncclUnhandledCudaError: Call to CUDA function failed.
Last error:
Cuda failure 'out of memory'

albertz commented 1 month ago

I think we can handle that a bit better: NCCL can be initialized such that it reserves the needed memory in advance. From https://github.com/pytorch/pytorch/issues/116177#issuecomment-2343822534:

One thing I recommend is to eagerly initialize nccl and then check the free GPU memory before doing a collective. To eagerly initialize nccl, simply pass device_id=torch.device("cuda:0") or whatever device index you want, to torch.distributed.init_process_group(). When doing this, nccl initialization will happen during that API call, and then NCCL should not consume additional memory on the first allreduce call.

But passing device_id is only possible in newer PyTorch versions.
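
A sketch of that recommendation (LOCAL_RANK is the usual torchrun env var; the device_id argument requires a sufficiently new PyTorch):

```python
import os
import torch
import torch.distributed as dist

local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

# Eagerly initialize NCCL so its memory is reserved up front rather than at the first collective.
dist.init_process_group(backend="nccl", device_id=torch.device(f"cuda:{local_rank}"))

# Free GPU memory can now be checked before doing any collective.
free, total = torch.cuda.mem_get_info(local_rank)
print(f"free {free / 1e9:.2f}GB / total {total / 1e9:.2f}GB")
```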

I also read that this is dependent on the NCCL version. Newer NCCL versions might require less memory: https://github.com/NVIDIA/nccl/issues/1197#issuecomment-1980391319:

NCCL 2.21 will reduce the NVLS memory usage significantly as we've found that NVLS memory usage was a problem for codes which were already close to using all memory. It will still use more memory than with NVLS disabled though; we're working on reducing memory usage even further in NCCL 2.22.

albertz commented 1 month ago

We might need to redesign the way we handle distributed computing in Torch a bit. Currently we do a single dist.init_process_group(backend=None) call to initialize a global process group. I think we may want to create explicit process groups for CPU and CUDA instead.
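
A hypothetical sketch of what that could look like (not the current RETURNN code):

```python
import torch
import torch.distributed as dist

# Default group as done now: backend=None picks Gloo for CPU and NCCL for CUDA tensors.
dist.init_process_group(backend=None)

# Explicit per-backend groups, so the caller can choose where a reduction runs.
cpu_group = dist.new_group(backend="gloo")   # for tensors on CPU
cuda_group = dist.new_group(backend="nccl")  # for tensors on CUDA


def all_reduce_auto(t: torch.Tensor) -> None:
    """Pick the process group based on the tensor's device."""
    group = cpu_group if t.device.type == "cpu" else cuda_group
    dist.all_reduce(t, group=group)
```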