nod-ai / SHARK-Turbine

Unified compiler/runtime for interfacing with PyTorch Dynamo.
Apache License 2.0

torch.aten.avg_pool2d to linalg #643

Closed. AmosLewis closed this 2 months ago.

AmosLewis commented 5 months ago

Found this failure while adding support for the Inception_v4_vaiq_int8 model: https://github.com/nod-ai/SHARK-TestSuite/issues/190

Inception_v4_vaiq_int8.default.onnx.torch.mlir:1053:12: error: failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal
    %816 = torch.aten.avg_pool2d %656, %813, %815, %814, %false_224, %false_225, %none_226 : !torch.vtensor<[32,384,25,25],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[32,384,25,25],f32>

The corresponding onnx.AveragePool op in the imported model:

%446 = torch.operator "onnx.AveragePool"(%403) {torch.onnx.ceil_mode = 0 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [1 : si64, 1 : si64, 1 : si64, 1 : si64], torch.onnx.strides = [1 : si64, 1 : si64]} : (!torch.vtensor<[32,384,25,25],f32>) -> !torch.vtensor<[32,384,25,25],f32>
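For reference, aten::avg_pool2d takes (self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override), so in the failing op %false_224 is ceil_mode and %false_225 is count_include_pad: this pool excludes padding from the divisor. The ONNX op sets no count_include_pad attribute, and the ONNX default is 0. The sketch below uses a hypothetical onnx_avgpool_to_aten_args helper (not the actual OnnxToTorch pattern) to show how the ONNX attributes line up with those operands, and checks that a 3x3 window with stride 1 and pad 1 keeps the 25x25 spatial size.

```python
# Hypothetical helper, for illustration only: maps ONNX AveragePool
# attributes onto the operand order of torch.aten.avg_pool2d
# (kernel_size, stride, padding, ceil_mode, count_include_pad,
# divisor_override). This is not the torch-mlir implementation.
def onnx_avgpool_to_aten_args(kernel_shape, strides, pads,
                              ceil_mode=0, count_include_pad=0):
    n = len(kernel_shape)
    # ONNX pads are laid out as [x1_begin, x2_begin, ..., x1_end, x2_end].
    begins, ends = list(pads[:n]), list(pads[n:])
    if begins != ends:
        # aten.avg_pool2d takes one symmetric pad per spatial dim;
        # this sketch simply rejects asymmetric pads.
        raise ValueError(f"asymmetric pads not handled here: {pads}")
    return {
        "kernel_size": list(kernel_shape),
        "stride": list(strides),
        "padding": begins,
        "ceil_mode": bool(ceil_mode),
        # ONNX defaults count_include_pad to 0, i.e. False.
        "count_include_pad": bool(count_include_pad),
        "divisor_override": None,
    }


def pooled_size(size, kernel, stride, pad):
    # Output length along one spatial dim with ceil_mode=False:
    # floor((L + 2p - k) / s) + 1
    return (size + 2 * pad - kernel) // stride + 1


args = onnx_avgpool_to_aten_args([3, 3], [1, 1], [1, 1, 1, 1], ceil_mode=0)
print(args["padding"], args["count_include_pad"])  # [1, 1] False
print(pooled_size(25, 3, 1, 1))  # 25
```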

Previous related patches:
- torch-to-linalg:
  - [MLIR][TORCH] Add E2E support for aten.avg_pool2d op
  - [Stablehlo] Add support for AvgPool1dOp
  - [RFC] general support for Adaptive Pooling Ops
- onnx-to-torch:
  - [MLIR][ONNX] Add OnnxToTorch support for AveragePool op

count_include_pad=True/False explained
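count_include_pad only changes the divisor for windows that overlap the padding. A minimal pure-Python sketch, using a hypothetical avg_pool2d helper restricted to one channel, stride 1, and symmetric zero padding (not the PyTorch or torch-mlir implementation): with an all-ones input, count_include_pad=True lets the padded zeros dilute the border averages, while count_include_pad=False divides only by the in-bounds elements.

```python
# Hypothetical pure-Python 2D average pooling: single channel, stride 1,
# symmetric zero padding. Illustrates count_include_pad semantics only.
def avg_pool2d(x, k, pad, count_include_pad):
    h, w = len(x), len(x[0])
    out_h, out_w = h + 2 * pad - k + 1, w + 2 * pad - k + 1
    out = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            total, valid = 0.0, 0
            for di in range(k):
                for dj in range(k):
                    r, c = i + di - pad, j + dj - pad
                    if 0 <= r < h and 0 <= c < w:
                        total += x[r][c]
                        valid += 1
            # True: divide by the full window (padded zeros count);
            # False: divide only by the number of in-bounds elements.
            row.append(total / (k * k if count_include_pad else valid))
        out.append(row)
    return out


ones = [[1.0] * 3 for _ in range(3)]
print(avg_pool2d(ones, k=3, pad=1, count_include_pad=True)[0][0])   # 0.444... (4/9)
print(avg_pool2d(ones, k=3, pad=1, count_include_pad=False)[0][0])  # 1.0
```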

AmosLewis commented 5 months ago

Tried to add a new e2e test for this case:

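import torch

from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export


class AvgPool2dFloatStaticModule(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.ap2d = torch.nn.AvgPool2d(kernel_size=[3, 3],
                                       stride=[1, 1],
                                       # ONNX-style 4-value pads; torch expects 1 or 2 values
                                       padding=[1, 1, 1, 1],
                                       ceil_mode=False,
                                       count_include_pad=False,
                                       divisor_override=None)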

    @export
    @annotate_args([
        None,
        ([32, 384, 25, 25], torch.float32, True),
    ])
    def forward(self, x):
        return self.ap2d(x)

@register_test_case(module_factory=lambda: AvgPool2dFloatStaticModule())
def AvgPool2dFloatStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(32, 384, 25, 25, low=-1))

Run:

python -m e2e_testing.main --config=linalg --filter AvgPool2dFloatStaticModule -v

Got:

TORCH_VERSION_FOR_COMPARISON = 2.4.0.dev20240416
FAIL - "AvgPool2dFloatStaticModule_basic"

Unexpected outcome summary: (linalg)

****** Failed tests - 1 tests
    FAIL - "AvgPool2dFloatStaticModule_basic"
        Compilation error: Traceback (most recent call last):
          File "/home/chi/src/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/framework.py", line 295, in compile_and_run_test
            golden_trace = generate_golden_trace(test)
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
          File "/home/chi/src/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/framework.py", line 289, in generate_golden_trace
            test.program_invoker(tracer, TestUtils())
          File "/home/chi/src/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/test_suite/pooling.py", line 867, in AvgPool2dFloatStaticModule_basic
            module.forward(tu.rand(32, 384, 25, 25, low=-1))
          File "/home/chi/src/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/framework.py", line 269, in __call__
            output = self.__wrapped__(*args, **kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
          File "/home/chi/src/torch-mlir/build/tools/torch-mlir/python_packages/torch_mlir/torch_mlir_e2e_test/test_suite/pooling.py", line 862, in forward
            return self.ap2d(x)
                   ^^^^^^^^^^^^
          File "/home/chi/src/torch-mlir/mlir_venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
            return self._call_impl(*args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
          File "/home/chi/src/torch-mlir/mlir_venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
            return forward_call(*args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
          File "/home/chi/src/torch-mlir/mlir_venv/lib/python3.11/site-packages/torch/nn/modules/pooling.py", line 641, in forward
            return F.avg_pool2d(input, self.kernel_size, self.stride,
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        RuntimeError: avg_pool2d: padding must either be a single int, or a tuple of two ints

Summary:
    Failed: 1
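The traceback shows the failure happens in generate_golden_trace, i.e. while running the module in PyTorch eager mode, before torch-mlir compiles anything. torch.nn.AvgPool2d takes padding as one value per spatial dimension, applied on both sides, rather than the 4-value ONNX pads layout. A minimal sketch of that rule, using a hypothetical normalize_avg_pool2d_padding helper that mirrors the error message (it is not PyTorch's code):

```python
# Hypothetical mirror of the padding rule behind the RuntimeError above:
# avg_pool2d accepts a single int (same pad for H and W) or a sequence of
# one or two ints (pad_h, pad_w). Anything else is rejected.
def normalize_avg_pool2d_padding(padding):
    if isinstance(padding, int):
        return (padding, padding)
    padding = tuple(padding)
    if len(padding) == 1:
        return (padding[0], padding[0])
    if len(padding) == 2:
        return padding
    raise RuntimeError(
        "avg_pool2d: padding must either be a single int, or a tuple of two ints"
    )


print(normalize_avg_pool2d_padding([1, 1]))  # (1, 1)
```

Under that rule, padding=[1, 1] (or padding=1) expresses the same symmetric padding as the ONNX pads [1, 1, 1, 1], so the test presumably needs that change before it can exercise the linalg lowering.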
AmosLewis commented 3 months ago

https://github.com/llvm/torch-mlir/issues/3428