google-ai-edge / ai-edge-torch

Supporting PyTorch models with the Google AI Edge TFLite runtime.
Apache License 2.0

Converting MaxPool2D with dynamic spatial dimensions crashes #270

Open sc-aharri opened 4 days ago

sc-aharri commented 4 days ago

Description of the bug:

When I convert a PyTorch model containing a MaxPool2d module with dynamic spatial dimensions, ai_edge_torch.convert crashes. This can be reproduced on my setup using the following minimal repro:

import ai_edge_torch
import torch

if __name__ == "__main__":
    print("creating dummy image and model which is just a maxpool op")
    model = torch.nn.MaxPool2d(2)
    sample_image = torch.zeros((1, 1, 320, 320))

    print("define dynamic sizing, specifying that the sizes must be divisible by 2")
    height = 2 * torch.export.Dim("height", min=1, max=320)
    width = 2 * torch.export.Dim("width", min=1, max=320)
    dynamic_shapes = ({2: height, 3: width},)

    print("converting model")
    exported_model = ai_edge_torch.convert(
        model,
        (sample_image,),
        dynamic_shapes=dynamic_shapes,
    )

The relevant stack trace can be found below.

Actual vs expected behavior:

The script crashes with the following callstack:

2024-09-30 13:05:37.114853: E external/xla/xla/status_macros.cc:56] INTERNAL: RET_CHECK failure (external/xla/xla/service/shape_inference.cc:3508) !inferred_shape.is_unbounded_dynamic() Reshaping with unbounded result shape is not supported.
*** Begin stack trace ***
        tsl::CurrentStackTrace()

        xla::status_macros::MakeErrorStream::Impl::GetStatus()
        xla::ShapeInference::InferReshapeShape(xla::Shape const&, absl::lts_20230802::Span<long const>, absl::lts_20230802::Span<long const>, long)

        xla::XlaBuilder::ReportErrorOrReturn(absl::lts_20230802::FunctionRef<absl::lts_20230802::StatusOr<xla::XlaOp> ()>)
        xla::XlaBuilder::Reshape(xla::XlaOp, absl::lts_20230802::Span<long const>, absl::lts_20230802::Span<long const>, long)

        xla::XlaBuilder::ReportErrorOrReturn(absl::lts_20230802::FunctionRef<absl::lts_20230802::StatusOr<xla::XlaOp> ()>)
        xla::XlaBuilder::Reshape(xla::XlaOp, absl::lts_20230802::Span<long const>, long)

        torch_xla::BuildMaxPoolNd(xla::XlaOp, long, absl::lts_20230802::Span<long const>, absl::lts_20230802::Span<long const>, absl::lts_20230802::Span<long const>, bool)
        torch_xla::MaxPoolNd::Lower(torch_xla::LoweringContext*) const
        torch_xla::LoweringContext::LowerNode(torch::lazy::Node const*)
        torch_xla::LoweringContext::GetOutputOp(torch::lazy::Output const&)
        torch_xla::LoweringContext::AddResult(torch::lazy::Output const&)
        torch_xla::DumpUtil::ToHlo(c10::ArrayRef<torch::lazy::Value>, torch::lazy::BackendDevice const&, torch_xla::EmitMode)
        torch_xla::XLAGraphExecutor::DumpHloComputation(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > > const&, torch_xla::EmitMode)

        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault
        PyEval_EvalCode

        _PyRun_SimpleFileObject
        _PyRun_AnyFileObject
        Py_RunMain
        Py_BytesMain
        __libc_start_main
        _start
*** End stack trace ***

Traceback (most recent call last):
  File "basic-repro.py", line 13, in <module>
    exported_model = ai_edge_torch.convert(
                     ^^^^^^^^^^^^^^^^^^^^^^
  File ".../ai_edge_torch/convert/converter.py", line 195, in convert
    return Converter().convert(
           ^^^^^^^^^^^^^^^^^^^^
  File ".../ai_edge_torch/convert/converter.py", line 134, in convert
    return conversion.convert_signatures(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../ai_edge_torch/convert/conversion.py", line 97, in convert_signatures
    shlo_bundles: list[stablehlo.StableHLOModelBundle] = [
                                                         ^
  File ".../ai_edge_torch/convert/conversion.py", line 98, in <listcomp>
    cutils.exported_program_to_stablehlo_bundle(exported, sig.flat_args)
  File ".../ai_edge_torch/convert/conversion_utils.py", line 133, in exported_program_to_stablehlo_bundle
    return stablehlo.exported_program_to_stablehlo(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../torch_xla/stablehlo.py", line 618, in exported_program_to_stablehlo
    bundle = _exported_program_to_stablehlo_bundle(exported_model, options)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../torch_xla/stablehlo.py", line 397, in _exported_program_to_stablehlo_bundle
    stablehlo_content = xm.get_stablehlo_bytecode(res)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../torch_xla/core/xla_model.py", line 1112, in get_stablehlo_bytecode
    return torch_xla._XLAC._get_stablehlo(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Error while lowering: [] aten::max_pool2d, num_outputs=2, xla_shape=(f32[1,1,160,160]{3,2,1,0}, u32[1,1,160,160]{3,2,1,0}), dynamic_dims: (), spatial_dim_count=2, kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=0
XLA builder error: INTERNAL: RET_CHECK failure (external/xla/xla/service/shape_inference.cc:3508) !inferred_shape.is_unbounded_dynamic() Reshaping with unbounded result shape is not supported.: 

It seems that is_unbounded_dynamic is returning true, since the RET_CHECK on !inferred_shape.is_unbounded_dynamic() fails. There are certainly dynamic dimensions, but I'm suspicious of the word unbounded, since I am defining bounds. Should I expect bounds to be propagated to torch_xla?

Any other information you'd like to share?

My setup is an Ubuntu Docker container with CPU-only versions of torch.

pkgoogle commented 4 days ago

Hi @sc-aharri, if you are able to staticize as much as you can, that'll probably be the easiest workaround for now. I was able to reproduce on nightly with the exact same code... at root it might be a torch_xla issue, assuming we are passing in the right values here: torch_xla_utils.py
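
One possible reading of "staticize", offered as a sketch rather than a confirmed fix: convert the model once at a fixed maximum resolution (without dynamic_shapes), pad each input up to that resolution at runtime, and crop the pooled output back down. For MaxPool2d(2) with even input sizes, no pooling window straddles the padding boundary, so the cropped result should match pooling the unpadded input. The pure-Python code below only checks that shape arithmetic. The helper names (max_pool_2x2, pad_to, crop, pooled_via_static_shape) are hypothetical and are not part of ai_edge_torch or torch_xla.

```python
# Hypothetical helpers illustrating the pad-then-crop idea.
# Nothing here calls ai_edge_torch or torch_xla.


def max_pool_2x2(image):
    """2x2 max pooling with stride 2 over a 2D list (even height and width)."""
    height, width = len(image), len(image[0])
    return [
        [
            max(image[r][c], image[r][c + 1], image[r + 1][c], image[r + 1][c + 1])
            for c in range(0, width, 2)
        ]
        for r in range(0, height, 2)
    ]


def pad_to(image, height, width, fill=0.0):
    """Pad a 2D list on the bottom and right up to (height, width)."""
    padded = [row + [fill] * (width - len(row)) for row in image]
    padded += [[fill] * width for _ in range(height - len(image))]
    return padded


def crop(image, height, width):
    """Keep the top-left (height, width) region of a 2D list."""
    return [row[:width] for row in image[:height]]


def pooled_via_static_shape(image, static_height, static_width):
    """Pool at a fixed (static) size, then crop to the size the input implies."""
    height, width = len(image), len(image[0])
    pooled = max_pool_2x2(pad_to(image, static_height, static_width))
    return crop(pooled, height // 2, width // 2)


if __name__ == "__main__":
    # A 4x6 input stands in for a smaller-than-maximum image.
    small = [[float(r * 6 + c) for c in range(6)] for r in range(4)]
    direct = max_pool_2x2(small)
    # The 8x8 size stands in for the fixed resolution of a statically converted model.
    via_static = pooled_via_static_shape(small, 8, 8)
    print(direct == via_static)
```

In practice the fixed-size model would be whatever ai_edge_torch.convert produces when called without dynamic_shapes. Whether pad-then-crop stays exact for a full network is an assumption to verify, since layers such as convolutions mix padded and real pixels near the boundary.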