pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Single Node Distributed Training #3700

Open codeislife99 opened 2 years ago

codeislife99 commented 2 years ago

🐛 Bug

When PyTorch is compiled with CUDA enabled and torch_xla is compiled against it afterwards, running distributed training launches n×n processes instead of n and fails with the following OOM error:

https://logpaste.com/Bld77rSW

To Reproduce

  1. Compile PyTorch with CUDA enabled.
  2. Compile PyTorch/XLA as usual.
  3. From the xla/test folder, run GPU_NUM_DEVICES=8 python3 test_train_mp_mnist_amp.py --fake_data

Expected behavior

The script should launch 8 processes, one per GPU. Instead it launches 64 of them, 8 on each device, and crashes with the OOM error shown in the log paste linked above.

Environment

PyTorch 1.12 compiled with CUDA
PT-XLA 1.12 compiled with CUDA

Additional context

Single-node distributed training fails with the 1.12 binaries: an n-GPU run launches n×n processes. The reason is that, starting with 1.12, import torch initializes CUDA on all visible devices. Before this, you could set the environment variable inside the code and trick each process into believing that only one GPU was available. XLA uses this trick for single-node distributed training, as seen here.

Because import torch now enumerates all CUDA devices, setting CUDA_VISIBLE_DEVICES afterwards has no effect. Every process can see every device and launches n processes internally, and since there are n such processes, the total is n×n. Previously, each process only saw the CUDA device assigned to it through CUDA_VISIBLE_DEVICES, so process 0 saw device 0, process 2 saw device 2, and so on. Now process 0 sees devices 0 to n-1, process 1 sees 1 to n-1, and so on. In other words, upstream PyTorch moved CUDA device initialization to import time (import torch) instead of the point where a CUDA device is actually used or a tensor is allocated, and this broke the assumption XLA relied on for single-node distributed training.

This error does NOT occur with the publicly released images, because they are built with CPU-only PyTorch rather than PyTorch + CUDA, so import torch does not check for CUDA devices there.

Hence this problem is unique to CUDA builds of PyTorch. I think the solution is to change how the single-node distributed training path initializes and assigns GPU IDs, but I wanted to have a broader discussion about that first.
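The n versus n×n arithmetic can be reproduced without a GPU. The sketch below is a toy model, not PyTorch or torch_xla code: the hypothetical FakeCudaRuntime class stands in for a CUDA runtime that reads CUDA_VISIBLE_DEVICES exactly once, either eagerly (when the library is "imported", the 1.12 behavior reported in this issue) or lazily (on first use). Each of the n workers applies its one-GPU mask only after the "import", then counts the devices it would launch workers on.

```python
from typing import List, Optional


class FakeCudaRuntime:
    """Hypothetical stand-in for a CUDA runtime, for illustration only.

    The device list is computed from CUDA_VISIBLE_DEVICES once and then cached,
    so changes to the mask after enumeration are ignored.
    """

    def __init__(self, physical_devices: int, env: dict, eager: bool):
        self.physical_devices = physical_devices
        self.env = env  # shared with the caller, so later edits are visible
        self._devices: Optional[List[int]] = None
        if eager:
            # Eager init: enumerate devices as soon as the library is "imported".
            self._devices = self._enumerate()

    def _enumerate(self) -> List[int]:
        mask = self.env.get("CUDA_VISIBLE_DEVICES")
        if mask is None:
            return list(range(self.physical_devices))
        return [int(d) for d in mask.split(",") if d.strip()]

    def device_count(self) -> int:
        if self._devices is None:
            # Lazy init: enumerate on first use, after the mask was applied.
            self._devices = self._enumerate()
        return len(self._devices)


def total_workers(n: int, eager: bool) -> int:
    """Each of n processes masks itself to one GPU *after* the 'import',
    then launches one worker per device it can see."""
    total = 0
    for rank in range(n):
        env: dict = {}                              # fresh per-process environment
        runtime = FakeCudaRuntime(n, env, eager)    # plays the role of "import torch"
        env["CUDA_VISIBLE_DEVICES"] = str(rank)     # mask applied post facto
        total += runtime.device_count()
    return total


print(total_workers(8, eager=False))  # lazy init: 8
print(total_workers(8, eager=True))   # eager init: 64
```

The same model covers the appendix script: with eager enumeration, setting the mask to "0,1,2,3" after construction still reports 8 devices, while lazy enumeration reports 4.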

Special Thanks to @ymwangg for productive discussions on this issue.

Appendix

Consider this script, which prints 8 instead of 4 when run on a machine with 8 GPUs:

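import os
import torch
# The mask is set only after torch has already been imported.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
print(torch.cuda.device_count())  # prints 8 on an 8-GPU machine; 4 was expected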

I think this is why so many processes are being spawned.

ymwangg commented 2 years ago

@codeislife99 FYI, here is a previous thread about the same issue: https://github.com/pytorch/xla/issues/3347. It looks like recent PyTorch versions require CUDA_VISIBLE_DEVICES to be set before importing torch.

codeislife99 commented 2 years ago

@JackCaoG Do you have any insights on this issue?

JackCaoG commented 2 years ago

Hmm... can you set CUDA_VISIBLE_DEVICES before importing PyTorch? In our use case we always use the non-CUDA version of PyTorch, so we have not seen this issue before. Is it inconvenient because figuring out CUDA_VISIBLE_DEVICES per process is non-trivial?

How does PyTorch solve it, though? Don't they also have n processes for n GPUs?

codeislife99 commented 2 years ago

PyTorch can do it because it does not use the trick XLA relies on for distributed training. For PyTorch, distributed training is native: it uses DDP modules instead of assigning a different visible device to each process. Can you elaborate on how CUDA_VISIBLE_DEVICES could be set before importing torch? The code is structured so that the processes are spawned first and the CUDA device is attached to them afterwards.
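One possible answer to that question is to apply the mask when each process is created, not inside a worker that is already running. The sketch below is an assumption about how a launcher could do this, not how torch_xla's spawn path actually works: the parent copies its environment, sets CUDA_VISIBLE_DEVICES to the worker's rank, and starts a fresh interpreter with subprocess, so the mask is in place before the child could ever run import torch. The hypothetical WORKER_CODE only echoes the variable so the example runs without a GPU or PyTorch installed.

```python
import os
import subprocess
import sys

# Hypothetical worker body; a real launcher would run the training script instead.
# The mask is already in os.environ when this code starts, so an `import torch`
# placed here would only see the single GPU assigned to this rank.
WORKER_CODE = "import os; print(os.environ['CUDA_VISIBLE_DEVICES'])"


def launch_workers(num_devices: int) -> list:
    procs = []
    for rank in range(num_devices):
        env = os.environ.copy()
        env["CUDA_VISIBLE_DEVICES"] = str(rank)  # set BEFORE the child starts
        procs.append(
            subprocess.Popen(
                [sys.executable, "-c", WORKER_CODE],
                env=env,
                stdout=subprocess.PIPE,
                text=True,
            )
        )
    # Collect each worker's view of the mask, in rank order.
    outputs = []
    for p in procs:
        out, _ = p.communicate()
        outputs.append(out.strip())
    return outputs


if __name__ == "__main__":
    print(launch_workers(4))  # ['0', '1', '2', '3']
```

Starting a fresh interpreter per rank sidesteps the import-order problem entirely, at the cost of the child not inheriting any in-memory state from the parent.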

JackCaoG commented 2 years ago

Well, we also have a DDP module contributed by @hjm-aws in https://github.com/pytorch/xla/blob/master/torch_xla/distributed/xla_backend.py, and I think we got it working with GPU at some point. However, AFAIK this DDP module just forwards the collective communication (cc) ops to the XLA Python versions of those ops.

JackCaoG commented 2 years ago

An alternative is to use the non-GPU version of PyTorch. We are trying to migrate to the new PJRT runtime; the CUDA_VISIBLE_DEVICES trick was mostly for XRT. We haven't planned much for the PJRT:GPU path yet. @cicirori, in case you run into the same issue.

ymwangg commented 2 years ago

@codeislife99 It looks like this was a bug in PyTorch 1.12 that has since been fixed: https://github.com/pytorch/pytorch/issues/80876. As Jack suggested, we still need to figure out how to stop relying on CUDA_VISIBLE_DEVICES to set the device in the future.
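A common way to avoid depending on CUDA_VISIBLE_DEVICES, used by native PyTorch DDP scripts, is to keep every GPU visible to every process and have each process select its own device ordinal explicitly. The sketch below assumes the launcher exports a LOCAL_RANK variable (torchrun does this; whether torch_xla's launcher would is an open assumption), and the hypothetical select_device helper only computes the device string so it runs without PyTorch.

```python
import os


def select_device(env=None, num_visible: int = 8) -> str:
    """Pick this process's GPU by explicit ordinal instead of masking.

    Assumes the launcher exports LOCAL_RANK; all GPUs stay visible to every
    process, so the result does not depend on when CUDA was initialized.
    """
    env = os.environ if env is None else env
    local_rank = int(env.get("LOCAL_RANK", "0"))
    if not 0 <= local_rank < num_visible:
        raise ValueError(f"LOCAL_RANK={local_rank} out of range for {num_visible} GPUs")
    # With PyTorch installed, this string could be passed to torch.device(...).
    return f"cuda:{local_rank}"


print(select_device({"LOCAL_RANK": "3"}))  # cuda:3
```

Because the choice is made by ordinal rather than by hiding devices, eager CUDA initialization at import time no longer changes how many workers each process sees.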

codeislife99 commented 2 years ago

Excellent! That should solve the problem for now. I will try it out and post an update here, but it should definitely work. It is true that the CUDA_VISIBLE_DEVICES trick seems like a hack, and a better solution would make the code more robust.