Thanks @ronghanghu! I will take a look at the issue this week.
Quick update: @ronghanghu I was able to reproduce your issue. We are drilling into the details to learn more.
@miladm Sounds great, thanks for the update!
I was able to confirm that this HLO:
HloModule IrToHlo.87
%AddComputation.70 (x.71: f32[], y.72: f32[]) -> f32[] {
%x.71 = f32[] parameter(0)
%y.72 = f32[] parameter(1)
ROOT %add.73 = f32[] add(f32[] %x.71, f32[] %y.72)
}
%AddComputation.81 (x.82: f32[], y.83: f32[]) -> f32[] {
%x.82 = f32[] parameter(0)
%y.83 = f32[] parameter(1)
ROOT %add.84 = f32[] add(f32[] %x.82, f32[] %y.83)
}
ENTRY %IrToHlo.87 (p0.1: s64[], p1.60: f32[]) -> (f32[199665], f32[25557120], f32[199665], f32[]) {
%constant.4 = s64[] constant(2531011)
%constant.2 = s64[] constant(214013)
%p0.1 = s64[] parameter(0)
%multiply.3 = s64[] multiply(s64[] %constant.2, s64[] %p0.1)
%add.5 = s64[] add(s64[] %constant.4, s64[] %multiply.3)
%convert.16 = u64[] convert(s64[] %add.5)
%reshape.20 = u64[1]{0} reshape(u64[] %convert.16)
%constant.17 = u64[] constant(0)
%reshape.21 = u64[1]{0} reshape(u64[] %constant.17)
%concatenate.22 = u64[2]{0} concatenate(u64[1]{0} %reshape.20, u64[1]{0} %reshape.21), dimensions={0}
%rng-bit-generator.23 = (u64[2]{0}, u32[99833,2]{1,0}) rng-bit-generator(u64[2]{0} %concatenate.22), algorithm=rng_default
%get-tuple-element.25 = u64[2]{0} get-tuple-element((u64[2]{0}, u32[99833,2]{1,0}) %rng-bit-generator.23), index=0
%p1.60 = f32[] parameter(1)
%constant.11 = f32[] constant(0)
%reshape.12 = f32[1]{0} reshape(f32[] %constant.11)
%broadcast.13 = f32[1]{0} broadcast(f32[1]{0} %reshape.12), dimensions={0}
%reshape.14 = f32[] reshape(f32[1]{0} %broadcast.13)
%broadcast.15 = f32[199665]{0} broadcast(f32[] %reshape.14), dimensions={}
%constant.43 = f32[] constant(6.28318548)
%broadcast.44 = f32[99833,1]{1,0} broadcast(f32[] %constant.43), dimensions={}
%get-tuple-element.24 = u32[99833,2]{1,0} get-tuple-element((u64[2]{0}, u32[99833,2]{1,0}) %rng-bit-generator.23), index=1
%constant.26 = u32[] constant(9)
%broadcast.27 = u32[99833,2]{1,0} broadcast(u32[] %constant.26), dimensions={}
%shift-right-logical.28 = u32[99833,2]{1,0} shift-right-logical(u32[99833,2]{1,0} %get-tuple-element.24, u32[99833,2]{1,0} %broadcast.27)
%convert.29 = f32[99833,2]{1,0} convert(u32[99833,2]{1,0} %shift-right-logical.28)
%constant.30 = f32[] constant(1.1920929e-07)
%broadcast.31 = f32[99833,2]{1,0} broadcast(f32[] %constant.30), dimensions={}
%multiply.32 = f32[99833,2]{1,0} multiply(f32[99833,2]{1,0} %convert.29, f32[99833,2]{1,0} %broadcast.31)
%constant.19 = f32[] constant(1)
%constant.18 = f32[] constant(0)
%subtract.33 = f32[] subtract(f32[] %constant.19, f32[] %constant.18)
%broadcast.34 = f32[99833,2]{1,0} broadcast(f32[] %subtract.33), dimensions={}
%multiply.35 = f32[99833,2]{1,0} multiply(f32[99833,2]{1,0} %multiply.32, f32[99833,2]{1,0} %broadcast.34)
%broadcast.36 = f32[99833,2]{1,0} broadcast(f32[] %constant.18), dimensions={}
%add.37 = f32[99833,2]{1,0} add(f32[99833,2]{1,0} %multiply.35, f32[99833,2]{1,0} %broadcast.36)
%slice.39 = f32[99833,1]{1,0} slice(f32[99833,2]{1,0} %add.37), slice={[0:99833], [1:2]}
%multiply.45 = f32[99833,1]{1,0} multiply(f32[99833,1]{1,0} %broadcast.44, f32[99833,1]{1,0} %slice.39)
%sine.51 = f32[99833,1]{1,0} sine(f32[99833,1]{1,0} %multiply.45)
%constant.46 = f32[] constant(-2)
%broadcast.48 = f32[99833,1]{1,0} broadcast(f32[] %constant.46), dimensions={}
%slice.38 = f32[99833,1]{1,0} slice(f32[99833,2]{1,0} %add.37), slice={[0:99833], [0:1]}
%constant.40 = f32[] constant(1e-07)
%broadcast.41 = f32[99833,1]{1,0} broadcast(f32[] %constant.40), dimensions={}
%maximum.42 = f32[99833,1]{1,0} maximum(f32[99833,1]{1,0} %slice.38, f32[99833,1]{1,0} %broadcast.41)
%log.47 = f32[99833,1]{1,0} log(f32[99833,1]{1,0} %maximum.42)
%multiply.49 = f32[99833,1]{1,0} multiply(f32[99833,1]{1,0} %broadcast.48, f32[99833,1]{1,0} %log.47)
%sqrt.50 = f32[99833,1]{1,0} sqrt(f32[99833,1]{1,0} %multiply.49)
%multiply.52 = f32[99833,1]{1,0} multiply(f32[99833,1]{1,0} %sine.51, f32[99833,1]{1,0} %sqrt.50)
%cosine.53 = f32[99833,1]{1,0} cosine(f32[99833,1]{1,0} %multiply.45)
%multiply.54 = f32[99833,1]{1,0} multiply(f32[99833,1]{1,0} %cosine.53, f32[99833,1]{1,0} %sqrt.50)
%concatenate.55 = f32[99833,2]{1,0} concatenate(f32[99833,1]{1,0} %multiply.52, f32[99833,1]{1,0} %multiply.54), dimensions={1}
%reshape.56 = f32[199666]{0} reshape(f32[99833,2]{1,0} %concatenate.55)
%slice.57 = f32[199665]{0} slice(f32[199666]{0} %reshape.56), slice={[0:199665]}
%constant.6 = f32[] constant(1)
%reshape.7 = f32[1]{0} reshape(f32[] %constant.6)
%broadcast.8 = f32[1]{0} broadcast(f32[1]{0} %reshape.7), dimensions={0}
%reshape.9 = f32[] reshape(f32[1]{0} %broadcast.8)
%broadcast.10 = f32[199665]{0} broadcast(f32[] %reshape.9), dimensions={}
%multiply.58 = f32[199665]{0} multiply(f32[199665]{0} %slice.57, f32[199665]{0} %broadcast.10)
%add.59 = f32[199665]{0} add(f32[199665]{0} %broadcast.15, f32[199665]{0} %multiply.58)
%broadcast.61 = f32[199665]{0} broadcast(f32[] %p1.60), dimensions={}
%add.62 = f32[199665]{0} add(f32[199665]{0} %add.59, f32[199665]{0} %broadcast.61)
%all-gather.63 = f32[25557120]{0} all-gather(f32[199665]{0} %add.62), replica_groups={}, dimensions={0}
%constant.64 = s32[] constant(0)
%broadcast.65 = s32[1]{0} broadcast(s32[] %constant.64), dimensions={}
%gather.66 = f32[] gather(f32[25557120]{0} %all-gather.63, s32[1]{0} %broadcast.65), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=0, slice_sizes={1}
%multiply.67 = f32[] multiply(f32[] %p1.60, f32[] %gather.66)
%broadcast.68 = f32[25557120]{0} broadcast(f32[] %multiply.67), dimensions={}
%add.69 = f32[25557120]{0} add(f32[25557120]{0} %all-gather.63, f32[25557120]{0} %broadcast.68)
%reduce-scatter.74 = f32[199665]{0} reduce-scatter(f32[25557120]{0} %add.69), replica_groups={}, constrain_layout=true, dimensions={0}, to_apply=%AddComputation.70
%constant.75 = s32[] constant(0)
%broadcast.76 = s32[1]{0} broadcast(s32[] %constant.75), dimensions={}
%gather.77 = f32[] gather(f32[199665]{0} %reduce-scatter.74, s32[1]{0} %broadcast.76), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=0, slice_sizes={1}
%multiply.78 = f32[] multiply(f32[] %multiply.67, f32[] %gather.77)
%constant.80 = s32[] constant(199665)
%constant.79 = f32[] constant(0)
%reduce.85 = f32[] reduce(f32[199665]{0} %reduce-scatter.74, f32[] %constant.79), dimensions={0}, to_apply=%AddComputation.81
ROOT %tuple.86 = (f32[199665]{0}, f32[25557120]{0}, f32[199665]{0}, f32[]) tuple(f32[199665]{0} %add.59, f32[25557120]{0} %all-gather.63, f32[199665]{0} %reduce-scatter.74, f32[] %reduce.85)
}
crashed the XRT server:
PC: @ 0x7f7cea32918b (unknown) raise
@ 0x7f7c44095d3a 992 (unknown)
@ 0x7f7cea329210 3968 (unknown)
@ 0x7f7c53093bd0 16 tensorflow::internal::LogMessageFatal::~LogMessageFatal()
@ 0x7f7c4d708923 592 tensorflow::tpu::TpuProgramGroup::Initialize()
@ 0x7f7c4d6c6ebe 1488 tensorflow::tpu::TpuCompilationCacheExternal::InitializeEntry()
@ 0x7f7c4d716db1 800 tensorflow::tpu::TpuCompilationCacheInterface::CompileIfKeyAbsentHelper()
@ 0x7f7c4d7168af 496 tensorflow::tpu::TpuCompilationCacheInterface::CompileIfKeyAbsent()
@ 0x7f7c49578ed4 912 tensorflow::XRTCompileOp::Compute()
@ 0x7f7c52b458a1 2128 tensorflow::(anonymous namespace)::ExecutorState<>::Process()
@ 0x7f7c52b47634 48 std::_Function_handler<>::_M_invoke()
@ 0x7f7c5305ccb2 128 Eigen::ThreadPoolTempl<>::WorkerLoop()
@ 0x7f7c530448cc 80 tensorflow::(anonymous namespace)::PThread::ThreadFn()
@ 0x7f7cea2c9609 (unknown) start_thread
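For reference, a dump like the one above can be pulled straight from the Python repro graph. A minimal sketch, assuming the _get_xla_tensors_hlo debug helper exposed by torch_xla._XLAC (setting XLA_SAVE_TENSORS_FILE with XLA_SAVE_TENSORS_FMT=hlo is another route):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

# Build the same pending graph as the repro script ...
t1 = torch.randn(199665, device=xm.xla_device())
t2 = xm.all_gather(t1).flatten()
t3 = xm.reduce_scatter(xm.REDUCE_SUM, t2, scale=1.0, scatter_dim=0,
                       shard_count=xm.xrt_world_size())
# ... and render it as HLO text before it is compiled/executed.
print(torch_xla._XLAC._get_xla_tensors_hlo([t3]))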
Hard to tell why it crashed, but it might just be an OOM. I will circle back next Monday.
I think the real error is
2022-03-30 04:23:53.999554: E tensorflow/core/tpu/kernels/tpu_compilation_cache_external.cc:113] during context [post-optimization]: HloModule has a mix of layout constrained and unconstrained AllReduce instructions.
I will ask the XLA team to take a look.
I think the issue is that we missed the layout pinning for reduce_scatter and all_gather (we pin the layout for all_reduce). You can see this in the HLO above: %reduce-scatter.74 carries constrain_layout=true while %all-gather.63 does not, which matches the "mix of layout constrained and unconstrained" error. I will work on a fix today.
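For what it's worth, later torch_xla releases expose this pinning directly at the Python level. A minimal sketch, assuming the pin_layout keyword that was added to the collective ops after this fix (not available in the 20220308 nightly):

import torch
import torch_xla.core.xla_model as xm

t1 = torch.randn(199665, device=xm.xla_device())
# pin_layout is an assumption about later torch_xla releases; pinning the
# layout on every collective in the graph avoids mixing layout-constrained
# and unconstrained collectives in one HloModule.
t2 = xm.all_gather(t1, dim=0, pin_layout=True).flatten()
t3 = xm.reduce_scatter(xm.REDUCE_SUM, t2, scale=1.0, scatter_dim=0,
                       shard_count=xm.xrt_world_size(), pin_layout=True)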
Thanks a lot @JackCaoG!
🐛 Bug
The nightly PyTorch XLA build (20220308 for torch, torchvision, torch_xla, and libtpu) gives an unexpected GRPC error when all_gather is used together with reduce_scatter. After some debugging, I think this GRPC error is related to the all_gather update in https://github.com/pytorch/xla/pull/3275. The error goes away after I revert to the older all_gather implementation via all_reduce (in https://github.com/pytorch/xla/blob/v1.10.0/torch_xla/core/xla_model.py#L583-L615). Also, the weird thing is that the GRPC error happens only when all_gather is used together with reduce_scatter -- it doesn't happen when either op is used alone.
Specifically, when I run xm.all_gather on a v3-128 pod followed by xm.reduce_scatter, it consistently fails with the GRPC error; the training process hangs and later crashes after the error appears.
This error happens on both v3-128 and v3-256. However, it doesn't happen on v3-8 or v3-32 and doesn't happen when using the older all_gather implementation via all_reduce. See details below.
To Reproduce
1. Allocate a v3-128 TPU VM pod (e.g. ronghanghu-v3-128 below) from the tpu-vm-pt-1.10 runtime and install the nightly 20220308 build of torch, torchvision, torch_xla, and libtpu on all VM nodes (e.g. through gcloud alpha compute tpus tpu-vm ssh --worker all).

2. Save the following content to a Python file (e.g. /home/ronghanghu/test_all_gather_reduce_scatter.py below) and scp it to all the VM nodes:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index):
    world_size = xm.xrt_world_size()
    assert world_size >= 128, "this should be tested on a v3-128 or v3-256"

    t1 = torch.randn(199665, device=xm.xla_device())
    t2 = xm.all_gather(t1).flatten()
    t3 = xm.reduce_scatter(xm.REDUCE_SUM, t2, scale=1.0, scatter_dim=0, shard_count=world_size)
    t4 = t3.sum()
    xm.mark_step()
    print(f"t4: {t4}")


if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=(), nprocs=8)
3. Run the script on the pod:

python3 -m torch_xla.distributed.xla_dist --tpu=ronghanghu-v3-128 --restart-tpuvm-pod-server -- \
  python3 /home/ronghanghu/test_all_gather_reduce_scatter.py
As noted above, the error goes away when reverting to the older all_gather implementation, which can be done by monkey-patching xm.all_gather:

import torch.nn.functional as F
import torch_xla.core.xla_model as xm


def old_all_gather(value, dim=0, groups=None):
    """
    This is the older all_gather implementation via all_reduce in PyTorch XLA 1.10
    in https://github.com/pytorch/xla/blob/v1.10.0/torch_xla/core/xla_model.py#L583-L615
    """
    if dim < 0:
        dim = value.dim() + dim
    size = value.size(dim)
    padding = [0] * (2 * value.dim())
    ordinal = xm.get_ordinal()
    if groups is None:
        left, right = ordinal, xm.xrt_world_size() - 1 - ordinal
    else:
        ordinals = dict()
        for g in groups:
            for i, x in enumerate(g):
                ordinals[x] = (i, len(g) - 1 - i)
        left, right = ordinals[ordinal]
    idx = value.dim() - 1 - dim
    padding[2 * idx] = left * size
    padding[2 * idx + 1] = right * size
    return xm.all_reduce(xm.REDUCE_SUM, F.pad(value, padding), groups=groups)


xm.all_gather = old_all_gather
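This workaround sidesteps the native all-gather lowering entirely: each ordinal zero-pads its shard out to the full gathered length at its own offset, and the REDUCE_SUM all_reduce across replicas then reconstructs the concatenation. Since no all-gather HLO is emitted, the mix of layout-constrained and unconstrained collectives never arises.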