pritamdamania87 opened this issue 4 months ago (status: Open)
Hey @pritamdamania87, nice to see you here. Can you dump the full error message? It should crash with some C++/Python stack traces; from the current log it isn't clear to me where the crash comes from.
Nice to see you as well @JackCaoG :) I updated the original issue summary with the stack trace as well.
Hmm, from the log it seems like XLA:GPU failed when compiling the above HLO. This is really weird, as the HLO is pretty straightforward. To sanity check, would you mind setting PJRT_DEVICE=CPU and seeing whether the same program compiles with XLA:CPU?
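One way to run that comparison without editing the script is to launch it in a fresh process with PJRT_DEVICE overridden. A minimal sketch; run_under_pjrt_device and repro.py are hypothetical names, and it assumes torch_xla reads PJRT_DEVICE when its runtime initializes, which is why the variable is set in the child's environment before it starts rather than changed mid-run:

```python
import os
import subprocess
import sys


def run_under_pjrt_device(script_args, device="CPU", timeout=600):
    """Run `python <script_args>` in a child process with PJRT_DEVICE=<device>.

    Hypothetical helper: the parent's environment is copied and only
    PJRT_DEVICE is overridden, so every other setting stays the same.
    """
    env = dict(os.environ, PJRT_DEVICE=device)
    return subprocess.run(
        [sys.executable, *script_args],
        env=env,
        capture_output=True,
        text=True,
        timeout=timeout,
    )


# Usage (repro.py is a placeholder for the failing script):
# result = run_under_pjrt_device(["repro.py"], device="CPU")
# print(result.returncode, result.stderr[-2000:])
```

A negative return code from the child (for example -11 on Linux) indicates it was killed by a signal such as SIGSEGV rather than exiting with a Python error.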
Setting PJRT_DEVICE=CPU results in a segfault:
File "torch_xla/core/xla_model.py", line 1056 in mark_step
File "torch_xla/core/dynamo_bridge.py", line 542 in extract_compiled_graph
File "torch/_dynamo/backends/torchxla.py", line 51 in fwd
File "torch/_functorch/_aot_autograd/utils.py", line 89 in g
File "torch/_functorch/_aot_autograd/utils.py", line 113 in call_func_at_runtime_with_args
File "torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 505 in forward
File "torch/autograd/function.py", line 598 in apply
File "torch/_functorch/_aot_autograd/utils.py", line 89 in g
File "torch/_functorch/_aot_autograd/utils.py", line 113 in call_func_at_runtime_with_args
File "torch/_functorch/_aot_autograd/runtime_wrappers.py", line 88 in runtime_wrapper
File "torch/_functorch/_aot_autograd/utils.py", line 89 in g
File "torch/_functorch/aot_autograd.py", line 917 in forward
File "torch/_dynamo/external_utils.py", line 36 in inner
File "torch/_dynamo/eval_frame.py", line 451 in _fn
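The Python frames above use the "line N in func" format (no comma) that the standard library's faulthandler module writes when the process receives a fatal signal. If a repro crashes inside native code without leaving any Python-level trace, a minimal sketch for enabling it at the top of the script (equivalent to running python -X faulthandler or setting PYTHONFAULTHANDLER=1):

```python
import faulthandler

# Install handlers for fatal signals (SIGSEGV, SIGFPE, SIGABRT, SIGBUS, SIGILL).
# With the defaults, faulthandler writes the traceback of every thread to
# sys.stderr before the process dies, so a crash inside native code (such as
# the LLVM backend used by XLA:CPU) still leaves a Python-level trace.
faulthandler.enable()

# ... import torch_xla and run the failing code after this point.
```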
gdb backtrace:
#0 0x000015547b7bbd28 in llvm::FoldingSet<llvm::SDNode>::NodeEquals(llvm::FoldingSetBase const*, llvm::FoldingSetBase::Node*, llvm::FoldingSetNodeID const&, unsigned int, llvm::FoldingSetNodeID&) ()
from llvmlite/binding/../../../../libLLVM-14.so
#1 0x000015523b109e88 in llvm::FoldingSetBase::FindNodeOrInsertPos(llvm::FoldingSetNodeID const&, void*&, llvm::FoldingSetBase::FoldingSetInfo const&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#2 0x0000155237280ec8 in llvm::SelectionDAG::FindNodeOrInsertPos(llvm::FoldingSetNodeID const&, llvm::SDLoc const&, void*&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#3 0x00001552372b4399 in llvm::SelectionDAG::getConstant(llvm::ConstantInt const&, llvm::SDLoc const&, llvm::EVT, bool, bool) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#4 0x00001552372fba59 in llvm::SelectionDAGBuilder::visitGetElementPtr(llvm::User const&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#5 0x0000155237326c20 in llvm::SelectionDAGBuilder::visit(llvm::Instruction const&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#6 0x000015523733b953 in llvm::SelectionDAGISel::SelectBasicBlock(llvm::ilist_iterator_w_bits<llvm::ilist_detail::node_options<llvm::Instruction, false, false, void, true>, false, true>, llvm::ilist_iterator_w_bits<llvm::ilist_detail::node_options<llvm::Instruction, false, false, void, true>, false, true>, bool&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#7 0x000015523733db9f in llvm::SelectionDAGISel::SelectAllBasicBlocks(llvm::Function const&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#8 0x000015523733fd01 in llvm::SelectionDAGISel::runOnMachineFunction(llvm::MachineFunction&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#9 0x00001552365e05c6 in (anonymous namespace)::X86DAGToDAGISel::runOnMachineFunction(llvm::MachineFunction&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#10 0x0000155236f239b4 in llvm::MachineFunctionPass::runOnFunction(llvm::Function&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#11 0x000015523a1b5b59 in llvm::FPPassManager::runOnFunction(llvm::Function&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#12 0x000015523a1b6003 in llvm::FPPassManager::runOnModule(llvm::Module&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#13 0x000015523a1b5500 in llvm::legacy::PassManagerImpl::run(llvm::Module&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#14 0x0000155234086144 in xla::cpu::CompilerFunctor::operator()(llvm::Module&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#15 0x00001552368d4388 in llvm::orc::IRCompileLayer::emit(std::unique_ptr<llvm::orc::MaterializationResponsibility, std::default_delete<llvm::orc::MaterializationResponsibility> >, llvm::orc::ThreadSafeModule) ()
from _XLAC.cpython-310-x86_64-linux-gnu.so
#16 0x00001552368faac5 in llvm::orc::BasicIRLayerMaterializationUnit::materialize(std::unique_ptr<llvm::orc::MaterializationResponsibility, std::default_delete<llvm::orc::MaterializationResponsibility> >) ()
from _XLAC.cpython-310-x86_64-linux-gnu.so
#17 0x000015523687ea9e in llvm::orc::MaterializationTask::run() () from _XLAC.cpython-310-x86_64-linux-gnu.so
#18 0x000015547ceec5c8 in void llvm::detail::UniqueFunctionBase<void, std::unique_ptr<llvm::orc::Task, std::default_delete<llvm::orc::Task> > >::CallImpl<void (*)(std::unique_ptr<llvm::orc::Task, std::default_delete<llvm::orc::Task> >)>(void*, std::unique_ptr<llvm::orc::Task, std::default_delete<llvm::orc::Task> >&) () from llvmlite/binding/../../../../libLLVM-14.so
#19 0x000015523687f5e8 in llvm::orc::ExecutionSession::dispatchOutstandingMUs() () from _XLAC.cpython-310-x86_64-linux-gnu.so
#20 0x0000155236888d99 in llvm::orc::ExecutionSession::OL_completeLookup(std::unique_ptr<llvm::orc::InProgressLookupState, std::default_delete<llvm::orc::InProgressLookupState> >, std::shared_ptr<llvm::orc::AsynchronousSymbolQuery>, std::function<void (llvm::DenseMap<llvm::orc::JITDylib*, llvm::DenseSet<llvm::orc::SymbolStringPtr, llvm::DenseMapInfo<llvm::orc::SymbolStringPtr, void> >, llvm::DenseMapInfo<llvm::orc::JITDylib*, void>, llvm::detail::DenseMapPair<llvm::orc::JITDylib*, llvm::DenseSet<llvm::orc::SymbolStringPtr, llvm::DenseMapInfo<llvm::orc::SymbolStringPtr, void> > > > const&)>) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#21 0x000015523688990e in llvm::orc::InProgressFullLookupState::complete(std::unique_ptr<llvm::orc::InProgressLookupState, std::default_delete<llvm::orc::InProgressLookupState> >) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#22 0x000015523687aeb1 in llvm::orc::ExecutionSession::OL_applyQueryPhase1(std::unique_ptr<llvm::orc::InProgressLookupState, std::default_delete<llvm::orc::InProgressLookupState> >, llvm::Error) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#23 0x000015523687fb8a in llvm::orc::ExecutionSession::lookup(llvm::orc::LookupKind, std::vector<std::pair<llvm::orc::JITDylib*, llvm::orc::JITDylibLookupFlags>, std::allocator<std::pair<llvm::orc::JITDylib*, llvm::orc::JITDylibLookupFlags> > > const&, llvm::orc::SymbolLookupSet, llvm::orc::SymbolState, llvm::unique_function<void (llvm::Expected<llvm::DenseMap<llvm::orc::SymbolStringPtr, llvm::orc::ExecutorSymbolDef, llvm::DenseMapInfo<llvm::orc::SymbolStringPtr, void>, llvm::detail::DenseMapPair<llvm::orc::SymbolStringPtr, llvm::orc::ExecutorSymbolDef> > >)>, std::function<void (llvm::DenseMap<llvm::orc::JITDylib*, llvm::DenseSet<llvm::orc::SymbolStringPtr, llvm::DenseMapInfo<llvm::orc::SymbolStringPtr, void> >, llvm::DenseMapInfo<llvm::orc::JITDylib*, void>, llvm::detail::DenseMapPair<llvm::orc::JITDylib*, llvm::DenseSet<llvm::orc::SymbolStringPtr, llvm::DenseMapInfo<llvm::orc::SymbolStringPtr, void> > > > const&)>) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#24 0x000015523688084a in llvm::orc::ExecutionSession::lookup(std::vector<std::pair<llvm::orc::JITDylib*, llvm::orc::JITDylibLookupFlags>, std::allocator<std::pair<llvm::orc::JITDylib*, llvm::orc::JITDylibLookupFlags> > > const&, llvm::orc::SymbolLookupSet, llvm::orc::LookupKind, llvm::orc::SymbolState, std::function<void (llvm::DenseMap<llvm::orc::JITDylib*, llvm::DenseSet<llvm::orc::SymbolStringPtr, llvm::DenseMapInfo<llvm::orc::SymbolStringPtr, void> >, llvm::DenseMapInfo<llvm::orc::JITDylib*, void>, llvm::detail::DenseMapPair<llvm::orc::JITDylib*, llvm::DenseSet<llvm::orc::SymbolStringPtr, llvm::DenseMapInfo<llvm::orc::SymbolStringPtr, void> > > > const&)>) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#25 0x0000155236880d8a in llvm::orc::ExecutionSession::lookup(std::vector<std::pair<llvm::orc::JITDylib*, llvm::orc::JITDylibLookupFlags>, std::allocator<std::pair<llvm::orc::JITDylib*, llvm::orc::JITDylibLookupFlags> > > const&, llvm::orc::SymbolStringPtr, llvm::orc::SymbolState) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#26 0x0000155236880fb4 in llvm::orc::ExecutionSession::lookup(llvm::ArrayRef<llvm::orc::JITDylib*>, llvm::orc::SymbolStringPtr, llvm::orc::SymbolState) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#27 0x00001552368811d1 in llvm::orc::ExecutionSession::lookup(llvm::ArrayRef<llvm::orc::JITDylib*>, llvm::StringRef, llvm::orc::SymbolState) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#28 0x0000155234080669 in xla::cpu::SimpleOrcJIT::FindCompiledSymbol(std::string const&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#29 0x0000155234078781 in xla::cpu::CpuExecutable::Create(std::unique_ptr<xla::cpu::SimpleOrcJIT, std::default_delete<xla::cpu::SimpleOrcJIT> >, std::unique_ptr<xla::BufferAssignment const, std::default_delete<xla::BufferAssignment const> >, std::unique_ptr<xla::HloModule, std::default_delete<xla::HloModule> >, std::string const&, std::unique_ptr<xla::HloProfilePrinterData, std::default_delete<xla::HloProfilePrinterData> >, std::unique_ptr<xla::HloProfileIndexMap, std::default_delete<xla::HloProfileIndexMap> >) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#30 0x0000155233aa2976 in xla::cpu::CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr<xla::HloModule, std::default_delete<xla::HloModule> >) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#31 0x0000155233aa303f in xla::cpu::CpuCompiler::RunBackend(std::unique_ptr<xla::HloModule, std::default_delete<xla::HloModule> >, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#32 0x0000155233a5d332 in xla::TfrtCpuClient::Compile(xla::XlaComputation const&, xla::CompileOptions) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#33 0x00001552317c1177 in torch_xla::runtime::PjRtComputationClient::Compile(std::vector<torch_xla::runtime::ComputationClient::CompileInstance, std::allocator<torch_xla::runtime::ComputationClient::CompileInstance> >) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#34 0x00001552315b7623 in torch_xla::XLAGraphExecutor::Compile(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >&, absl::lts_20230802::Span<std::string const>, torch::lazy::LazyGraphExecutor::SyncTensorCollection const&, torch::lazy::LazyGraphExecutor::PostOrderData*, std::vector<torch::lazy::Value, std::allocator<torch::lazy::Value> > const&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#35 0x00001552315b953f in torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::string const>, torch::lazy::LazyGraphExecutor::SyncTensorsConfig const&, bool) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#36 0x00001552315b9b41 in torch_xla::XLAGraphExecutor::SyncTensorsGraph(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20230802::Span<std::string const>, bool, bool, bool) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#37 0x00001552315b9d78 in torch_xla::XLAGraphExecutor::SyncLiveTensorsGraph(torch::lazy::BackendDevice const*, c10::ArrayRef<std::string>, bool) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#38 0x00001552313702ba in torch_xla::(anonymous namespace)::StepMarker(std::string const&, std::vector<std::string, std::allocator<std::string> > const&, bool) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#39 0x0000155231370766 in pybind11::cpp_function::initialize<torch_xla::(anonymous namespace)::InitXlaModuleBindings(pybind11::module_)::{lambda(std::string const&, std::vector<std::string, std::allocator<std::string> > const&, bool)#65}, void, std::string const&, std::vector<std::string, std::allocator<std::string> > const&, bool, pybind11::name, pybind11::scope, pybind11::sibling, pybind11::arg_v, pybind11::arg, pybind11::arg_v>(torch_xla::(anonymous namespace)::InitXlaModuleBindings(pybind11::module_)::{lambda(std::string const&, std::vector<std::string, std::allocator<std::string> > const&, bool)#65}&&, void (*)(std::string const&, std::vector<std::string, std::allocator<std::string> > const&, bool), pybind11::name const&, pybind11::scope const&, pybind11::sibling const&, pybind11::arg_v const&, pybind11::arg const&, pybind11::arg_v const&)::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call&) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#40 0x0000155231341fc6 in pybind11::cpp_function::dispatcher(_object*, _object*, _object*) () from _XLAC.cpython-310-x86_64-linux-gnu.so
#41 0x0000555555698516 in cfunction_call (func=0x1552476ae2f0, args=<optimized out>, kwargs=<optimized out>) at /usr/local/src/conda/python-3.10.12/Objects/methodobject.c:543
#42 0x0000555555691a6b in _PyObject_MakeTpCall (tstate=0x555555908e60, callable=0x1552476ae2f0, args=<optimized out>, nargs=2, keywords=0x1552476d8490) at /usr/local/src/conda/python-3.10.12/Objects/call.c:215
#43 0x000055555568dc39 in _PyObject_VectorcallTstate (kwnames=0x1552476d8490, nargsf=<optimized out>, args=<optimized out>, callable=0x1552476ae2f0, tstate=<optimized out>) at /usr/local/src/conda/python-3.10.12/Include/cpython/abstract.h:112
Could you share the model code you tried to torch.compile so I can double check on my end?
The model is proprietary; do you have suggestions for building a minimal repro? The issue summary includes a GraphModule section, but I'm not sure whether it can be used as-is for a repro.
Well, I have the FX graph, and if you dump the input shapes of xla_args in https://github.com/pytorch/xla/blob/6f0b61e5d782913a0fc7743812f2a8e522189111/torch_xla/core/dynamo_bridge.py#L337-L339 I should be able to repro it (otherwise I will have to manually map the HLO input shapes back to tensor input shapes and guess the order, lol).
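A minimal sketch of such a dump, assuming it is pasted at the linked lines where xla_args is in scope. describe_args is a hypothetical helper; it uses Python 3.8+ self-documenting f-strings, which produce lines in the same type(arg)=... / arg.size()=... form as the output shared in the reply:

```python
def describe_args(xla_args):
    """Return 'type(arg)=...' and 'arg.size()=...' lines for each argument.

    Hypothetical debugging helper: anything with a callable .size() (such as a
    torch.Tensor or nn.Parameter) reports its shape; any other value is
    reported with its repr instead.
    """
    lines = []
    for arg in xla_args:
        lines.append(f"{type(arg)=}")
        if callable(getattr(arg, "size", None)):
            lines.append(f"{arg.size()=}")
        else:
            lines.append(f"{arg=}")
    return lines


# Usage (assumed placement inside torch_xla/core/dynamo_bridge.py):
# print("\n".join(describe_args(xla_args)))
```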
Thanks for the pointer; these are the shapes for xla_args:
type(arg)=<class 'torch.Tensor'>
arg.size()=torch.Size([28, 4, 480, 640])
type(arg)=<class 'torch.Tensor'>
arg.size()=torch.Size([])
type(arg)=<class 'torch.Tensor'>
arg.size()=torch.Size([1, 2, 1, 1])
type(arg)=<class 'torch.Tensor'>
arg.size()=torch.Size([1])
type(arg)=<class 'torch.Tensor'>
arg.size()=torch.Size([3, 3, 1, 1])
type(arg)=<class 'torch.Tensor'>
arg.size()=torch.Size([3])
type(arg)=<class 'torch.nn.parameter.Parameter'>
arg.size()=torch.Size([3])
type(arg)=<class 'torch.nn.parameter.Parameter'>
arg.size()=torch.Size([3])
type(arg)=<class 'torch.Tensor'>
arg.size()=torch.Size([3])
type(arg)=<class 'torch.Tensor'>
arg.size()=torch.Size([3])
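To turn a dump like the one above back into shapes for a standalone repro, a small parser may help. This is a hypothetical helper, not part of torch_xla; it recovers shapes only, in dump order, and not dtypes:

```python
import re
from typing import List, Tuple

# Matches lines such as "arg.size()=torch.Size([28, 4, 480, 640])".
_SIZE_LINE = re.compile(r"arg\.size\(\)=torch\.Size\(\[([^\]]*)\]\)")


def parse_dumped_shapes(dump: str) -> List[Tuple[int, ...]]:
    """Extract each torch.Size from the dump, in order, as a tuple of ints."""
    shapes = []
    for match in _SIZE_LINE.finditer(dump):
        dims = match.group(1).strip()
        # torch.Size([]) denotes a 0-dim (scalar) tensor, i.e. the empty shape ().
        shapes.append(tuple(int(d) for d in dims.split(",")) if dims else ())
    return shapes
```

Each resulting tuple could then be passed to something like torch.randn(shape, device=xm.xla_device()) to build random inputs; because the dump carries no dtype information, integer or boolean inputs would need to be adjusted by hand.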
I tried a repro.py based on these shapes, but it seems to run without any issues.
Filed an issue with the XLA:GPU team; waiting for their feedback...
Does it work with xla-gpu-nightly? We couldn't repro the issue.
🐛 Bug
Running into the following error when using torch.compile(backend="openxla"):
Graph Module:
I see the following files in the XLA dump directory:
The module_0001.SyncTensorsGraph.68.before_optimizations.txt file is as follows:
This is with PyTorch 2.3 and a torch-xla version compatible with PyTorch 2.3. How can I debug this issue further?
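For anyone reproducing this, a dump directory like the one mentioned above can be produced through XLA's --xla_dump_to flag. A minimal sketch, assuming XLA_FLAGS is set before torch_xla is imported; /tmp/xla_dump and pre_optimization_hlo_files are illustrative choices, not part of the original setup:

```python
import os
from pathlib import Path

# Illustrative dump location. XLA_FLAGS must be in the environment before
# torch_xla initializes its runtime. setdefault leaves an existing XLA_FLAGS
# value untouched, so merge the flags by hand if you already set some.
DUMP_DIR = Path("/tmp/xla_dump")
os.environ.setdefault("XLA_FLAGS", f"--xla_dump_to={DUMP_DIR} --xla_dump_hlo_as_text")


def pre_optimization_hlo_files(dump_dir=DUMP_DIR):
    """List the HLO text dumps written before XLA's optimization passes, by name."""
    return sorted(Path(dump_dir).glob("*.before_optimizations.txt"))
```

The *.before_optimizations.txt files are the ones to attach when the failure may be in XLA's own passes or backend, since they show the HLO exactly as torch_xla handed it to the compiler.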