pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

new `all_gather` together w/ `reduce_scatter` causes GRPC error on v3-128 and v3-256 (nightly 20220308) #3423

Closed ronghanghu closed 2 years ago

ronghanghu commented 2 years ago

🐛 Bug

The nightly PyTorch XLA build (20220308 for torch, torchvision, torch_xla, and libtpu) gives an unexpected GRPC error when all_gather is used together with reduce_scatter.

After some debugging, I think this GRPC error is related to the all_gather update in https://github.com/pytorch/xla/pull/3275. The error goes away after I revert to the older all_gather implementation via all_reduce (in https://github.com/pytorch/xla/blob/v1.10.0/torch_xla/core/xla_model.py#L583-L615). Oddly, the GRPC error happens only when all_gather is used together with reduce_scatter; it doesn't happen when either op is used alone.

Specifically, when I run xm.all_gather on a v3-128 pod followed by xm.reduce_scatter, it consistently gives me the following error message:

2022-03-09 06:20:03 10.164.15.223 [8] 2022-03-09 06:20:03.763131: W tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc:157] RPC failed with status = "UNAVAILABLE: Connection reset by peer" and grpc_error_string = "{"created":"@1646806803.763047345","description":"Error received from peer ipv4:10.164.15.223:51011","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Connection reset by peer","grpc_status":14}", maybe retrying the RPC

The training process hangs and later crashes after this error happens.

This error happens on both v3-128 and v3-256. However, it doesn't happen on v3-8 or v3-32 and doesn't happen when using the older all_gather implementation via all_reduce. See details below.

To Reproduce

  1. Allocate a v3-128 TPU VM pod (e.g. ronghanghu-v3-128 below) from the tpu-vm-pt-1.10 runtime and install the nightly 20220308 build on all VM nodes (e.g. through gcloud alpha compute tpus tpu-vm ssh --worker all) via

    sudo pip3 uninstall -y torch torchvision torch_xla libtpu_nightly
    sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220308-cp38-cp38-linux_x86_64.whl
    sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220308-cp38-cp38-linux_x86_64.whl
    sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220308-cp38-cp38-linux_x86_64.whl
    sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220308-py3-none-any.whl
  2. Save the following content to a python file (e.g. /home/ronghanghu/test_all_gather_reduce_scatter.py below) and scp it to all the VM nodes.

    
    import torch
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp

    def _mp_fn(index):
        world_size = xm.xrt_world_size()
        assert world_size >= 128, "this should be tested on a v3-128 or v3-256"
        t1 = torch.randn(199665, device=xm.xla_device())
        t2 = xm.all_gather(t1).flatten()
        t3 = xm.reduce_scatter(xm.REDUCE_SUM, t2, scale=1.0, scatter_dim=0, shard_count=world_size)
        t4 = t3.sum()
        xm.mark_step()
        print(f"t4: {t4}")

    if __name__ == "__main__":
        xmp.spawn(_mp_fn, args=(), nprocs=8)


  3. Run this file on the v3-128 pod:

    python3 -m torch_xla.distributed.xla_dist --tpu=ronghanghu-v3-128 --restart-tpuvm-pod-server -- \
      python3 /home/ronghanghu/test_all_gather_reduce_scatter.py
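For reference, the value the repro should print can be worked out without a pod. The sketch below is a pure-Python simulation that assumes the usual collective semantics: all_gather concatenates every replica's shard, and a SUM reduce_scatter (with scale=1.0) adds all replicas' inputs elementwise and gives chunk k to replica k. The helpers `all_gather` and `reduce_scatter_sum` are hypothetical stand-ins, not torch_xla APIs, and the sizes are scaled down from 128 replicas × 199665 elements.

```python
# Hypothetical pure-Python model of the repro (not torch_xla code).
# Assumed semantics: all_gather concatenates all shards; SUM reduce_scatter
# adds the replicas' inputs elementwise, then replica k gets chunk k.


def all_gather(shards):
    """Return the gathered list each replica would see (identical on all)."""
    gathered = []
    for shard in shards:
        gathered.extend(shard)
    return [list(gathered) for _ in shards]


def reduce_scatter_sum(inputs, shard_count):
    """Sum per-replica inputs elementwise, then split into shard_count chunks."""
    total = [sum(vals) for vals in zip(*inputs)]
    chunk = len(total) // shard_count
    return [total[k * chunk:(k + 1) * chunk] for k in range(shard_count)]


world_size = 4   # stands in for the 128 replicas of a v3-128
shard_len = 3    # stands in for the 199665 elements of t1
t1 = [[float(r * 10 + i) for i in range(shard_len)] for r in range(world_size)]

t2 = all_gather(t1)                                 # same on every replica
t3 = reduce_scatter_sum(t2, shard_count=world_size)

# Every replica holds the same t2, so replica r ends up with world_size * t1[r].
for r in range(world_size):
    assert t3[r] == [world_size * x for x in t1[r]]
```

Under these assumptions, on a v3-128 each replica's t3 is 128 times its own t1, so the printed t4 should be 128 × t1.sum() on that replica, up to floating-point rounding.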


## Expected behavior

The GRPC error should not happen.

## Environment

 - Reproducible on XLA backend [CPU/TPU]: v3-128 TPU VM
 - torch_xla version: 20220308 nightly from `tpu-vm-pt-1.10` (see Step 1 above)

## Additional context

If we revert `xm.all_gather` to the older version implemented via all_reduce (see below), then the GRPC error does not happen.

    import torch.nn.functional as F
    import torch_xla.core.xla_model as xm

    def old_all_gather(value, dim=0, groups=None):
        """
        This is the older all_gather implementation via all_reduce in PyTorch XLA 1.10 in
        https://github.com/pytorch/xla/blob/v1.10.0/torch_xla/core/xla_model.py#L583-L615
        """
        if dim < 0:
            dim = value.dim() + dim
        size = value.size(dim)
        padding = [0] * (2 * value.dim())
        ordinal = xm.get_ordinal()
        if groups is None:
            left, right = ordinal, xm.xrt_world_size() - 1 - ordinal
        else:
            ordinals = dict()
            for g in groups:
                for i, x in enumerate(g):
                    ordinals[x] = (i, len(g) - 1 - i)
            left, right = ordinals[ordinal]
        # F.pad takes (left, right) pairs starting from the last dimension.
        idx = value.dim() - 1 - dim
        padding[2 * idx] = left * size
        padding[2 * idx + 1] = right * size
        return xm.all_reduce(xm.REDUCE_SUM, F.pad(value, padding), groups=groups)

    xm.all_gather = old_all_gather
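Why the padding trick in `old_all_gather` produces a gather can be checked without a TPU. In this pure-Python sketch, each replica zero-pads its shard so that it sits at offset `ordinal * size` (the `groups=None` case), and an elementwise SUM all_reduce then reproduces plain concatenation, because every other replica contributes zeros at those positions. The helpers `pad_for_ordinal` and `all_reduce_sum` are hypothetical stand-ins for `F.pad` and `xm.all_reduce`.

```python
# Hypothetical pure-Python model of old_all_gather (groups=None case).


def pad_for_ordinal(shard, ordinal, world_size):
    """Mimic F.pad with (left, right) = (ordinal * size, (world_size - 1 - ordinal) * size)."""
    size = len(shard)
    left = [0.0] * (ordinal * size)
    right = [0.0] * ((world_size - 1 - ordinal) * size)
    return left + list(shard) + right


def all_reduce_sum(inputs):
    """Elementwise sum across replicas; every replica receives the same result."""
    return [sum(vals) for vals in zip(*inputs)]


world_size = 4
shards = [[float(r), float(r) + 0.5] for r in range(world_size)]

padded = [pad_for_ordinal(s, r, world_size) for r, s in enumerate(shards)]
gathered = all_reduce_sum(padded)

expected = [x for s in shards for x in s]  # plain concatenation of all shards
assert gathered == expected
```

The cost of this approach is that every replica feeds a full world_size × size buffer, mostly zeros, into the all_reduce, whereas the nightly build emits a native `all-gather` HLO instruction (visible in the HLO dump later in this thread).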

miladm commented 2 years ago

Thanks @ronghanghu! I will take a look at the issue this week.

miladm commented 2 years ago

Quick update: @ronghanghu I was able to reproduce your issue. We are drilling into the details to learn more.

ronghanghu commented 2 years ago

@miladm Sounds great, thanks for the update!

JackCaoG commented 2 years ago

I was able to confirm that the following HLO

HloModule IrToHlo.87

%AddComputation.70 (x.71: f32[], y.72: f32[]) -> f32[] {
  %x.71 = f32[] parameter(0)
  %y.72 = f32[] parameter(1)
  ROOT %add.73 = f32[] add(f32[] %x.71, f32[] %y.72)
}

%AddComputation.81 (x.82: f32[], y.83: f32[]) -> f32[] {
  %x.82 = f32[] parameter(0)
  %y.83 = f32[] parameter(1)
  ROOT %add.84 = f32[] add(f32[] %x.82, f32[] %y.83)
}

ENTRY %IrToHlo.87 (p0.1: s64[], p1.60: f32[]) -> (f32[199665], f32[25557120], f32[199665], f32[]) {
  %constant.4 = s64[] constant(2531011)
  %constant.2 = s64[] constant(214013)
  %p0.1 = s64[] parameter(0)
  %multiply.3 = s64[] multiply(s64[] %constant.2, s64[] %p0.1)
  %add.5 = s64[] add(s64[] %constant.4, s64[] %multiply.3)
  %convert.16 = u64[] convert(s64[] %add.5)
  %reshape.20 = u64[1]{0} reshape(u64[] %convert.16)
  %constant.17 = u64[] constant(0)
  %reshape.21 = u64[1]{0} reshape(u64[] %constant.17)
  %concatenate.22 = u64[2]{0} concatenate(u64[1]{0} %reshape.20, u64[1]{0} %reshape.21), dimensions={0}
  %rng-bit-generator.23 = (u64[2]{0}, u32[99833,2]{1,0}) rng-bit-generator(u64[2]{0} %concatenate.22), algorithm=rng_default
  %get-tuple-element.25 = u64[2]{0} get-tuple-element((u64[2]{0}, u32[99833,2]{1,0}) %rng-bit-generator.23), index=0
  %p1.60 = f32[] parameter(1)
  %constant.11 = f32[] constant(0)
  %reshape.12 = f32[1]{0} reshape(f32[] %constant.11)
  %broadcast.13 = f32[1]{0} broadcast(f32[1]{0} %reshape.12), dimensions={0}
  %reshape.14 = f32[] reshape(f32[1]{0} %broadcast.13)
  %broadcast.15 = f32[199665]{0} broadcast(f32[] %reshape.14), dimensions={}
  %constant.43 = f32[] constant(6.28318548)
  %broadcast.44 = f32[99833,1]{1,0} broadcast(f32[] %constant.43), dimensions={}
  %get-tuple-element.24 = u32[99833,2]{1,0} get-tuple-element((u64[2]{0}, u32[99833,2]{1,0}) %rng-bit-generator.23), index=1
  %constant.26 = u32[] constant(9)
  %broadcast.27 = u32[99833,2]{1,0} broadcast(u32[] %constant.26), dimensions={}
  %shift-right-logical.28 = u32[99833,2]{1,0} shift-right-logical(u32[99833,2]{1,0} %get-tuple-element.24, u32[99833,2]{1,0} %broadcast.27)
  %convert.29 = f32[99833,2]{1,0} convert(u32[99833,2]{1,0} %shift-right-logical.28)
  %constant.30 = f32[] constant(1.1920929e-07)
  %broadcast.31 = f32[99833,2]{1,0} broadcast(f32[] %constant.30), dimensions={}
  %multiply.32 = f32[99833,2]{1,0} multiply(f32[99833,2]{1,0} %convert.29, f32[99833,2]{1,0} %broadcast.31)
  %constant.19 = f32[] constant(1)
  %constant.18 = f32[] constant(0)
  %subtract.33 = f32[] subtract(f32[] %constant.19, f32[] %constant.18)
  %broadcast.34 = f32[99833,2]{1,0} broadcast(f32[] %subtract.33), dimensions={}
  %multiply.35 = f32[99833,2]{1,0} multiply(f32[99833,2]{1,0} %multiply.32, f32[99833,2]{1,0} %broadcast.34)
  %broadcast.36 = f32[99833,2]{1,0} broadcast(f32[] %constant.18), dimensions={}
  %add.37 = f32[99833,2]{1,0} add(f32[99833,2]{1,0} %multiply.35, f32[99833,2]{1,0} %broadcast.36)
  %slice.39 = f32[99833,1]{1,0} slice(f32[99833,2]{1,0} %add.37), slice={[0:99833], [1:2]}
  %multiply.45 = f32[99833,1]{1,0} multiply(f32[99833,1]{1,0} %broadcast.44, f32[99833,1]{1,0} %slice.39)
  %sine.51 = f32[99833,1]{1,0} sine(f32[99833,1]{1,0} %multiply.45)
  %constant.46 = f32[] constant(-2)
  %broadcast.48 = f32[99833,1]{1,0} broadcast(f32[] %constant.46), dimensions={}
  %slice.38 = f32[99833,1]{1,0} slice(f32[99833,2]{1,0} %add.37), slice={[0:99833], [0:1]}
  %constant.40 = f32[] constant(1e-07)
  %broadcast.41 = f32[99833,1]{1,0} broadcast(f32[] %constant.40), dimensions={}
  %maximum.42 = f32[99833,1]{1,0} maximum(f32[99833,1]{1,0} %slice.38, f32[99833,1]{1,0} %broadcast.41)
  %log.47 = f32[99833,1]{1,0} log(f32[99833,1]{1,0} %maximum.42)
  %multiply.49 = f32[99833,1]{1,0} multiply(f32[99833,1]{1,0} %broadcast.48, f32[99833,1]{1,0} %log.47)
  %sqrt.50 = f32[99833,1]{1,0} sqrt(f32[99833,1]{1,0} %multiply.49)
  %multiply.52 = f32[99833,1]{1,0} multiply(f32[99833,1]{1,0} %sine.51, f32[99833,1]{1,0} %sqrt.50)
  %cosine.53 = f32[99833,1]{1,0} cosine(f32[99833,1]{1,0} %multiply.45)
  %multiply.54 = f32[99833,1]{1,0} multiply(f32[99833,1]{1,0} %cosine.53, f32[99833,1]{1,0} %sqrt.50)
  %concatenate.55 = f32[99833,2]{1,0} concatenate(f32[99833,1]{1,0} %multiply.52, f32[99833,1]{1,0} %multiply.54), dimensions={1}
  %reshape.56 = f32[199666]{0} reshape(f32[99833,2]{1,0} %concatenate.55)
  %slice.57 = f32[199665]{0} slice(f32[199666]{0} %reshape.56), slice={[0:199665]}
  %constant.6 = f32[] constant(1)
  %reshape.7 = f32[1]{0} reshape(f32[] %constant.6)
  %broadcast.8 = f32[1]{0} broadcast(f32[1]{0} %reshape.7), dimensions={0}
  %reshape.9 = f32[] reshape(f32[1]{0} %broadcast.8)
  %broadcast.10 = f32[199665]{0} broadcast(f32[] %reshape.9), dimensions={}
  %multiply.58 = f32[199665]{0} multiply(f32[199665]{0} %slice.57, f32[199665]{0} %broadcast.10)
  %add.59 = f32[199665]{0} add(f32[199665]{0} %broadcast.15, f32[199665]{0} %multiply.58)
  %broadcast.61 = f32[199665]{0} broadcast(f32[] %p1.60), dimensions={}
  %add.62 = f32[199665]{0} add(f32[199665]{0} %add.59, f32[199665]{0} %broadcast.61)
  %all-gather.63 = f32[25557120]{0} all-gather(f32[199665]{0} %add.62), replica_groups={}, dimensions={0}
  %constant.64 = s32[] constant(0)
  %broadcast.65 = s32[1]{0} broadcast(s32[] %constant.64), dimensions={}
  %gather.66 = f32[] gather(f32[25557120]{0} %all-gather.63, s32[1]{0} %broadcast.65), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=0, slice_sizes={1}
  %multiply.67 = f32[] multiply(f32[] %p1.60, f32[] %gather.66)
  %broadcast.68 = f32[25557120]{0} broadcast(f32[] %multiply.67), dimensions={}
  %add.69 = f32[25557120]{0} add(f32[25557120]{0} %all-gather.63, f32[25557120]{0} %broadcast.68)
  %reduce-scatter.74 = f32[199665]{0} reduce-scatter(f32[25557120]{0} %add.69), replica_groups={}, constrain_layout=true, dimensions={0}, to_apply=%AddComputation.70
  %constant.75 = s32[] constant(0)
  %broadcast.76 = s32[1]{0} broadcast(s32[] %constant.75), dimensions={}
  %gather.77 = f32[] gather(f32[199665]{0} %reduce-scatter.74, s32[1]{0} %broadcast.76), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=0, slice_sizes={1}
  %multiply.78 = f32[] multiply(f32[] %multiply.67, f32[] %gather.77)
  %constant.80 = s32[] constant(199665)
  %constant.79 = f32[] constant(0)
  %reduce.85 = f32[] reduce(f32[199665]{0} %reduce-scatter.74, f32[] %constant.79), dimensions={0}, to_apply=%AddComputation.81
  ROOT %tuple.86 = (f32[199665]{0}, f32[25557120]{0}, f32[199665]{0}, f32[]) tuple(f32[199665]{0} %add.59, f32[25557120]{0} %all-gather.63, f32[199665]{0} %reduce-scatter.74, f32[] %reduce.85)
}

crashed the XRT server:

PC: @     0x7f7cea32918b  (unknown)  raise
    @     0x7f7c44095d3a        992  (unknown)
    @     0x7f7cea329210       3968  (unknown)
    @     0x7f7c53093bd0         16  tensorflow::internal::LogMessageFatal::~LogMessageFatal()
    @     0x7f7c4d708923        592  tensorflow::tpu::TpuProgramGroup::Initialize()
    @     0x7f7c4d6c6ebe       1488  tensorflow::tpu::TpuCompilationCacheExternal::InitializeEntry()
    @     0x7f7c4d716db1        800  tensorflow::tpu::TpuCompilationCacheInterface::CompileIfKeyAbsentHelper()
    @     0x7f7c4d7168af        496  tensorflow::tpu::TpuCompilationCacheInterface::CompileIfKeyAbsent()
    @     0x7f7c49578ed4        912  tensorflow::XRTCompileOp::Compute()
    @     0x7f7c52b458a1       2128  tensorflow::(anonymous namespace)::ExecutorState<>::Process()
    @     0x7f7c52b47634         48  std::_Function_handler<>::_M_invoke()
    @     0x7f7c5305ccb2        128  Eigen::ThreadPoolTempl<>::WorkerLoop()
    @     0x7f7c530448cc         80  tensorflow::(anonymous namespace)::PThread::ThreadFn()
    @     0x7f7cea2c9609  (unknown)  start_thread

It is hard to tell why it crashed, but it might just be an OOM. I will circle back next Monday.
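To put the OOM hypothesis in scale, the largest buffer in the HLO dump above can be sized by hand. This sketch assumes 128 replicas (a v3-128) and 4-byte f32 elements. It only checks the shapes printed in the dump and estimates one buffer; it does not account for the compiler's other allocations.

```python
# Size of the all-gather output in the HLO dump (assumes a v3-128 and f32).
world_size = 128
shard_elems = 199665                      # f32[199665] input to all-gather.63
gathered_elems = shard_elems * world_size
assert gathered_elems == 25557120         # matches f32[25557120] in the dump

bytes_per_f32 = 4
gathered_mib = gathered_elems * bytes_per_f32 / 2**20
print(f"all-gather output: {gathered_mib:.1f} MiB per replica")
```

So the gathered tensor (and the same-sized add.69 that feeds the reduce-scatter) is on the order of 100 MiB per replica.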

JackCaoG commented 2 years ago

I think the real error is

2022-03-30 04:23:53.999554: E tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc:113] during context [post-optimization]: HloModule has a mix of layout constrained and unconstrained AllReduce instructions.

I will ask the XLA team to take a look.

JackCaoG commented 2 years ago

I think the issue is that we missed the layout pinning for reduce_scatter and all_gather (we pin the layout for all_reduce). I will work on a fix today.
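In the HLO dump above, `reduce-scatter.74` carries `constrain_layout=true` while `all-gather.63` has no such attribute, so the module mixes pinned and unpinned collectives. The sketch below is a hypothetical text-level helper (not part of torch_xla or XLA) that flags such a mix when scanning dumped HLO. It is only a heuristic; it does not reproduce the exact rule XLA's verifier applies.

```python
import re

# Hypothetical diagnostic helper: NOT a torch_xla or XLA API.
COLLECTIVES = ("all-reduce", "all-gather", "reduce-scatter", "all-to-all")
# Matches e.g. "  %all-gather.63 = f32[25557120]{0} all-gather(" and captures
# the instruction name and the opcode.
INSTR_RE = re.compile(r"\s*(?:ROOT\s+)?%(\S+)\s*=\s*\S+\s+([a-z-]+)\(")


def layout_constraints(hlo_text):
    """Map collective instruction name -> (opcode, has constrain_layout=true)."""
    found = {}
    for line in hlo_text.splitlines():
        m = INSTR_RE.match(line)
        if m and m.group(2) in COLLECTIVES:
            found[m.group(1)] = (m.group(2), "constrain_layout=true" in line)
    return found


def has_mixed_layout_constraints(hlo_text):
    """True if some collectives are layout-pinned and others are not."""
    flags = {pinned for _, pinned in layout_constraints(hlo_text).values()}
    return flags == {True, False}


# Two instructions copied from the HLO dump in this thread.
sample = """
  %all-gather.63 = f32[25557120]{0} all-gather(f32[199665]{0} %add.62), replica_groups={}, dimensions={0}
  %reduce-scatter.74 = f32[199665]{0} reduce-scatter(f32[25557120]{0} %add.69), replica_groups={}, constrain_layout=true, dimensions={0}, to_apply=%AddComputation.70
"""
print(layout_constraints(sample))
print(has_mixed_layout_constraints(sample))  # True
```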

ronghanghu commented 2 years ago

Thanks a lot @JackCaoG!