pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Dynamic batch dimension export tutorial (following https://openxla.org/stablehlo/tutorials/pytorch-export) segfaults (torch-xla==2.4.0) #8200

Open optiluca opened 2 weeks ago

optiluca commented 2 weeks ago

🐛 Bug

I am trying to follow https://openxla.org/stablehlo/tutorials/pytorch-export verbatim.

I am using the official python:3.10 Docker image, inside which I have installed:

--extra-index-url https://download.pytorch.org/whl/cpu
torch==2.4.0
torchvision==0.19.0
torchaudio==2.4.0
torch-xla==2.4.0
tensorflow-cpu==2.16.2
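
For reference, `dynamic_export` comes from the tutorial's preceding "Export with dynamic batch dimension" steps. A sketch of those steps (model choice and shapes are my reading of the linked page, not verified here) looks roughly like:

```python
# Hypothetical repro sketch of the tutorial's dynamic-batch export steps;
# the resnet18 model and input shapes are assumptions taken from the
# linked OpenXLA tutorial. The final call is the one that aborts.
import torch
import torchvision
from torch.export import Dim, export
from torch_xla.stablehlo import exported_program_to_stablehlo

model = torchvision.models.resnet18()
sample_input = (torch.randn(4, 3, 224, 224),)

# Mark dimension 0 (batch) of the single positional input as dynamic.
batch = Dim("batch")
dynamic_shapes = ({0: batch},)

dynamic_export = export(model, sample_input, dynamic_shapes=dynamic_shapes)

# Per this report, the next line crashes with
# INVALID_ARGUMENT: Non-broadcast dimensions must not be dynamic.
dynamic_stablehlo = exported_program_to_stablehlo(dynamic_export)
```

Note this script cannot be run to completion with torch-xla==2.4.0, since the last line is exactly the call that core-dumps.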

This call from the tutorial (under Export with dynamic batch dimension):

dynamic_stablehlo = exported_program_to_stablehlo(dynamic_export)

Throws:

F0000 00:00:1727773477.564067    2760 debug_macros.h:20] Non-OK-status: status.status()
Status: INVALID_ARGUMENT: Non-broadcast dimensions must not be dynamic.
*** Begin stack trace ***
        tsl::CurrentStackTrace()
        xla::Shape const* ConsumeValue<xla::Shape const*>(absl::lts_20230802::StatusOr<xla::Shape const*>&&)
        torch_xla::ShapeHelper::ShapeOfXlaOp(xla::XlaOp)

        torch_xla::BuildMaxPoolNd(xla::XlaOp, long, absl::lts_20230802::Span<long const>, absl::lts_20230802::Span<long const>, absl::lts_20230802::Span<long const>, bool)
        torch_xla::MaxPoolNd::Lower(torch_xla::LoweringContext*) const
        torch_xla::LoweringContext::LowerNode(torch::lazy::Node const*)
        torch_xla::LoweringContext::GetOutputOp(torch::lazy::Output const&)
        torch_xla::LoweringContext::AddResult(torch::lazy::Output const&)
        torch_xla::DumpUtil::ToHlo(c10::ArrayRef<torch::lazy::Value>, torch::lazy::BackendDevice const&, torch_xla::EmitMode)
        torch_xla::XLAGraphExecutor::DumpHloComputation(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > > const&, torch_xla::EmitMode)

        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault

        PyEval_EvalCode

        _PyRun_SimpleFileObject
        _PyRun_AnyFileObject
        Py_RunMain
        Py_BytesMain

        __libc_start_main
        _start
*** End stack trace ***

*** Check failure stack trace: ***
    @     0x7effc17b15f9  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @     0x7effb9ce60e4  ConsumeValue<>()
    @     0x7effb9ce614e  torch_xla::ShapeHelper::ShapeOfXlaOp()
    @     0x7effb997bc03  torch_xla::(anonymous namespace)::ComputeMaxPoolIndices()
    @     0x7effb997cff1  torch_xla::BuildMaxPoolNd()
    @     0x7effb9c5fd76  torch_xla::MaxPoolNd::Lower()
    @     0x7effb9cdfcfd  torch_xla::LoweringContext::LowerNode()
    @     0x7effb9ce0683  torch_xla::LoweringContext::GetOutputOp()
    @     0x7effb9ce0981  torch_xla::LoweringContext::AddResult()
    @     0x7effb9975474  torch_xla::DumpUtil::ToHlo()
    @     0x7effb9ae31f6  torch_xla::XLAGraphExecutor::DumpHloComputation()
    @     0x7effb986a083  torch_xla::(anonymous namespace)::InitXlaModuleBindings()::{lambda()#67}::operator()()
    @     0x7effb988519f  pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
    @     0x7effb986d1be  pybind11::cpp_function::dispatcher()
    @     0x7f020cefc4fd  cfunction_call
Aborted (core dumped)

JackCaoG commented 2 weeks ago

@lsy323 can you take a look since you are on-call this week?

lbortolotti commented 4 days ago

Any updates on this? I've also bumped into this issue.