amithrm opened this issue 8 months ago
I think what you are trying to do is
I think this won't work out of the box because
The problem here is that I don't think there is an easy way to change the PJRT device config on the fly. @will-cromar @yeounoh in case you guys have better suggestions.
@baoleai I remember you mentioned something about SPMD + PP; wondering if you have some insight here as well.
Currently, SPMD cannot support communication operators at the Python layer. When combining SPMD-TP and PP, we made numerous changes to XLA and the OpenXLA SPMD pass to support send/recv (@yitongh). Supporting the all-reduce communication operator might be more complicated.
Based on previous experience, you will need to do the following things on GPU:
Even with the above handling, the all-reduce operator is currently not well-suited to handle sharded inputs and can only function as a replicated operation.
Similar handling may be required in the TPU environment. Overall, supporting Python-side communication in the SPMD environment doesn't seem to have any particularly elegant solutions at the moment. Perhaps, as JackCaoG suggested, changing the configuration of the PJRT device might be a good approach.
@baoleai @yitongh is the send/recv using XLA Send/Recv? We are using all-reduce instead of send/recv to simplify our stack, and we can assume that only non-sharded tensors will be passed in.
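For context, a rough sketch of what "all-reduce instead of send/recv on non-sharded tensors" means (illustrative only, not our actual stack; `exchange` and `pair_groups` are made-up names):

```python
# Illustrative sketch: emulating a stage-to-stage transfer with all-reduce on
# replicated (non-sharded) tensors. Ranks are paired into replica groups and
# the receiver contributes zeros, so the group-wise sum equals the sender's
# activation.
import torch
import torch_xla.core.xla_model as xm

def exchange(activation, is_sender, pair_groups):
    # pair_groups must cover every replica, e.g. [[0, 1], [2, 3]] for 4 ranks.
    buf = activation if is_sender else torch.zeros_like(activation)
    return xm.all_reduce(xm.REDUCE_SUM, buf, groups=pair_groups)
```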
Is there any way to skip the sharding_propagation pass? This could be an isolated graph (cut off using mark_step), and we could use a custom_call or some attribute to keep the pass from "peeking" into the all-reduce, as in the sketch below.
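A minimal sketch of the graph-isolation part of that idea (the custom_call/attribute hook is hypothetical and not shown):

```python
# Cut the lazy trace before and after the collective so the all-reduce is
# compiled as its own small HLO module.
import torch
import torch_xla.core.xla_model as xm

t = torch.randn(4, 4, device=xm.xla_device())
xm.mark_step()                        # flush everything traced so far
t = xm.all_reduce(xm.REDUCE_SUM, t)   # the cc op we want isolated
xm.mark_step()                        # the all-reduce is now its own graph
```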
@JackCaoG For the cc ops set-up, why do we need to set up PJRT in a different way? All we need is the graph with the correct replica groups, correct? (These can be borrowed from the mesh during SPMD set-up; see the sketch below.) The PJRT runtime would just execute this on all "threads" (we don't need these to be different processes), and the all-reduce would look like any other all-reduce produced by the SPMD partitioner pass.
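To make "borrow the replica groups from the mesh" concrete, an illustrative helper (plain index math over the mesh's device ids; `replica_groups_from_mesh` is not an existing API):

```python
# Sketch: derive all-reduce replica groups from an SPMD mesh shape. Each group
# contains the devices that vary only along the chosen mesh axis.
import numpy as np

def replica_groups_from_mesh(device_ids, mesh_shape, axis):
    mesh = np.asarray(device_ids).reshape(mesh_shape)
    moved = np.moveaxis(mesh, axis, -1)
    return moved.reshape(-1, mesh_shape[axis]).tolist()

# Example: 8 devices in a (2, 4) mesh, groups along axis 0:
# [[0, 4], [1, 5], [2, 6], [3, 7]]
print(replica_groups_from_mesh(list(range(8)), (2, 4), axis=0))
```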
🐛 Describe the bug
The crash seen is the following:
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
F0000 00:00:1709131242.311197 36940 hlo_sharding.cc:1034] Check failed: IsTuple()
Check failure stack trace:
@ 0x7f1e46d752d9 absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
@ 0x7f1e40ed9700 xla::HloSharding::GetSubSharding()
@ 0x7f1e41cadd35 xla::ShardingPropagation::InferShardingFromOperands()
@ 0x7f1e41cb1cec xla::ShardingPropagation::Run()::{lambda()#3}::operator()()
@ 0x7f1e41cb5d43 xla::ShardingPropagation::Run()
@ 0x7f1e41c98355 xla::HloPassPipeline::RunHelper()
@ 0x7f1e41c9933a xla::HloPassPipeline::RunPassesInternal<>()
@ 0x7f1e41c99fa4 xla::HloPassPipeline::Run()
@ 0x7f1e41100d49 neuron::HloOptimization()
@ 0x7f1e410a3ab9 neuron::Optimize()
@ 0x7f1e4109f07e neuron::PJRT_Client_Compile()
@ 0x7f1e410a0638 neuron::Decorator<>::wrapper()
@ 0x7f1e51d966c5 xla::InitializeArgsAndCompile()
@ 0x7f1e51d969e0 xla::PjRtCApiClient::Compile()
@ 0x7f1e4d3411e6 torch_xla::runtime::PjRtComputationClient::Compile()
@ 0x7f1e4d14853e torch_xla::XLAGraphExecutor::Compile()
@ 0x7f1e4d149f49 torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal()
@ 0x7f1e4d14a58b torch_xla::XLAGraphExecutor::SyncTensorsGraph()
@ 0x7f1e4d14a9b8 torch_xla::XLAGraphExecutor::SyncLiveTensorsGraph()
@ 0x7f1e4cf1928a torch_xla::(anonymous namespace)::StepMarker()
@ 0x7f1e4cf196c6 pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
@ 0x7f1e4cef6ed0 pybind11::cpp_function::dispatcher()
@ 0x5d5499 PyCFunction_Call
Aborted (core dumped)
A simple example to reproduce the bug is attached below:
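A minimal sketch of the kind of program that triggers this, assuming the torch_xla 2.1 SPMD API (`xr.use_spmd()`) plus a Python-level all-reduce as discussed above (the exact attached repro may differ):

```python
# Reconstruction only: SPMD mode combined with a Python-layer cc op.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

xr.use_spmd()                                    # enable SPMD execution mode

t = torch.ones(16, 16, device=xm.xla_device())   # replicated (non-sharded) tensor
t = xm.all_reduce(xm.REDUCE_SUM, t)              # Python-layer all-reduce under SPMD
xm.mark_step()                                   # compilation triggers the ShardingPropagation check failure
```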
Versions
Versions of relevant libraries:
[pip3] numpy==1.24.4
[pip3] torch==2.1.2
[pip3] torch-neuronx==2.1.1.2.0.1b0
[pip3] torch-xla==2.1.2
[pip3] torchvision==0.16.2