pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Using CC ops with mark_sharding API throws an error. #6647

Open amithrm opened 8 months ago

amithrm commented 8 months ago

🐛 Describe the bug

The crash seen is the following:

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
F0000 00:00:1709131242.311197 36940 hlo_sharding.cc:1034] Check failed: IsTuple()
Check failure stack trace:
@ 0x7f1e46d752d9 absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
@ 0x7f1e40ed9700 xla::HloSharding::GetSubSharding()
@ 0x7f1e41cadd35 xla::ShardingPropagation::InferShardingFromOperands()
@ 0x7f1e41cb1cec xla::ShardingPropagation::Run()::{lambda()#3}::operator()()
@ 0x7f1e41cb5d43 xla::ShardingPropagation::Run()
@ 0x7f1e41c98355 xla::HloPassPipeline::RunHelper()
@ 0x7f1e41c9933a xla::HloPassPipeline::RunPassesInternal<>()
@ 0x7f1e41c99fa4 xla::HloPassPipeline::Run()
@ 0x7f1e41100d49 neuron::HloOptimization()
@ 0x7f1e410a3ab9 neuron::Optimize()
@ 0x7f1e4109f07e neuron::PJRT_Client_Compile()
@ 0x7f1e410a0638 neuron::Decorator<>::wrapper()
@ 0x7f1e51d966c5 xla::InitializeArgsAndCompile()
@ 0x7f1e51d969e0 xla::PjRtCApiClient::Compile()
@ 0x7f1e4d3411e6 torch_xla::runtime::PjRtComputationClient::Compile()
@ 0x7f1e4d14853e torch_xla::XLAGraphExecutor::Compile()
@ 0x7f1e4d149f49 torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal()
@ 0x7f1e4d14a58b torch_xla::XLAGraphExecutor::SyncTensorsGraph()
@ 0x7f1e4d14a9b8 torch_xla::XLAGraphExecutor::SyncLiveTensorsGraph()
@ 0x7f1e4cf1928a torch_xla::(anonymous namespace)::StepMarker()
@ 0x7f1e4cf196c6 pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
@ 0x7f1e4cef6ed0 pybind11::cpp_function::dispatcher()
@ 0x5d5499 PyCFunction_Call
Aborted (core dumped)

A simple example to reproduce the bug is attached below:

import os
import numpy as np
import torch
import torch_xla
import torch_xla.runtime as xr

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
os.environ["TF_CPP_VMODULE"] ='hlo_optimization=5'
# Enable XLA SPMD execution mode.
os.environ["XLA_IR_DEBUG"] = "1"
os.environ["XLA_FLAGS"]="--xla_force_host_platform_device_count=32 --xla_dump_hlo_as_text --xla_dump_hlo_as_proto --xla_dump_to=./xla_dump --xla_dump_hlo_pass_re='.*spmd.*'"
xr.use_spmd()

import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
import torch_xla.core.xla_model as xm
import torch_xla.experimental.pjrt_backend
import torch_xla.experimental.pjrt as pjrt

os.environ['NEURON_RT_NUM_CORES']='32'
os.environ['NEURON_PJRT_PROCESS_INDEX'] = '0'
os.environ['NEURON_PJRT_PROCESSES_NUM_DEVICES'] = '32'
os.environ['WORLD_SIZE'] = '1'

num_devices = xr.global_runtime_device_count()
print(f'num device: {num_devices}')
mesh_shape = (1, num_devices)
device_ids = np.array(range(num_devices))
# axis_names 'x' and 'y' are optional
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

lin = torch.nn.Linear(8192, 32768, bias=False).to(xm.xla_device())
lin2 = torch.nn.Linear(32768, 8192, bias=False).to(xm.xla_device())
xs.mark_sharding(lin.weight, mesh, ('y', 'x'))
xs.mark_sharding(lin2.weight, mesh, ('x', 'y'))
lin.train()
lin2.train()

print(mesh.get_logical_mesh(), mesh.shape())

t1 = torch.randn(64, 8192).to(xm.xla_device())
t2 = lin(t1)
t3 = lin2(t2)
xs.mark_sharding(t3, mesh, (None, None))
xm.mark_step()
print(t3)
t4 = xm.all_reduce('sum', t3)
xm.mark_step()
print(t4)

Versions

Versions of relevant libraries:
[pip3] numpy==1.24.4
[pip3] torch==2.1.2
[pip3] torch-neuronx==2.1.1.2.0.1b0
[pip3] torch-xla==2.1.2
[pip3] torchvision==0.16.2

JackCaoG commented 8 months ago

I think what you are trying to do is

  1. use SPMD to shard the HLO belonging to the current pipeline stage
  2. use cc ops to communicate across all of the hosts

I think this won't work out of the box because

  1. under your SPMD setup, PyTorch/XLA will only start a single process per host, and that process owns all of the XLA devices on that host.
  2. However, under your cc ops setup, you would need to start multiple processes per host and init the PJRT runtime in a way that it recognizes all of the devices across hosts.

The problem here is that I don't think there is an easy way to change the PJRT device config on the fly. @will-cromar @yeounoh in case you guys have better suggestions.
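For reference, a minimal sketch of what the cc ops setup normally looks like, as opposed to the single-process xr.use_spmd() launch in the repro above (this assumes the usual xmp.spawn entry point; _mp_fn is just an illustrative name):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

def _mp_fn(index):
    # each spawned process owns one XLA device
    t = torch.ones(2, 2, device=xm.xla_device())
    # replica groups default to all devices across all processes/hosts
    t = xm.all_reduce(xm.REDUCE_SUM, t)
    xm.mark_step()

if __name__ == '__main__':
    xmp.spawn(_mp_fn)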

@baoleai I remember you mentioned something about SPMD + PP, wondering if you guys have some insight as well.

baoleai commented 8 months ago

Currently, SPMD cannot support communication operators at the Python layer. When combining SPMD-TP and PP, we made numerous changes to XLA and the OpenXLA SPMD pass to support send/recv @yitongh. Supporting the all-reduce communication operator might be more complicated.

yitongh commented 8 months ago

Based on previous experience, you will need to do the following things on GPU:

  1. Support communication operations such as all-reduce on the Python side within SPMD. For example, support all-reduce in sharding_propagation.cc.
  2. When invoking NCCL communication, correctly handle the communication ranks for all-reduce, because the CollectiveOpGroupMode in the SPMD environment is different from that in replicate mode, and some hacky conversions are needed.

Even with the above handling, the all-reduce operator is currently not well-suited to handle sharded inputs and can only function as a replicated operation.

Similar handling may be required in the TPU environment. Overall, supporting Python-side communication in the SPMD environment doesn't seem to have any particularly elegant solutions at the moment. Perhaps, as JackCaoG suggested, changing the configuration of the PJRT device might be a good approach.

amithrm commented 8 months ago

@baoleai @yitongh is the send/recv using XLA Send/Recv? We are using all-reduce instead of send/recv to simplify our stack, and we can assume that only non-sharded tensors will be passed in.

Is there any way to skip the sharding_propagation pass? This can be an isolated graph (cut off using mark_step), and we could use a custom_call or some attribute to keep the pass from "peeking" into the all-reduce. A rough sketch of the isolation idea is below.
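Rough sketch, continuing from the repro script above (it assumes xs.clear_sharding is available in this torch-xla release and that dropping the annotation before the second graph is traced is enough to keep the pass out of the all-reduce, which may not hold):

t3 = lin2(lin(t1))
xm.mark_step()                         # graph 1: the SPMD-sharded matmuls

xs.clear_sharding(t3)                  # hypothetical: strip the sharding annotation
t4 = xm.all_reduce(xm.REDUCE_SUM, t3)  # graph 2 should then contain only the cc op
xm.mark_step()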

amithrm commented 8 months ago

@JackCaoG For the cc ops set-up, why do we need to set up PJRT in a different way? All we need is the graph with the correct replica groups, correct? (These can be borrowed from the mesh during the SPMD set-up.) The PJRT runtime would just execute this on all "threads" (we don't need these to be different processes), and the all-reduce would look like any other all-reduce coming out of an SPMD partitioner pass.
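Concretely, something like the following is what I have in mind, continuing from the repro above (a sketch only; it assumes Mesh exposes the device_ids it was built from, and whether passing explicit groups this way avoids the sharding_propagation crash is exactly the open question):

# one replica group spanning the whole mesh, borrowed from the SPMD set-up
groups = [[int(d) for d in mesh.device_ids]]

t4 = xm.all_reduce(xm.REDUCE_SUM, t3, groups=groups)
xm.mark_step()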