openxla / stablehlo

Backward compatible ML compute opset inspired by HLO/MHLO
Apache License 2.0

Why is there no ResizeOp/UpsampleOp in the stablehlo dialect? #2604

Open FengJungle opened 1 week ago

FengJungle commented 1 week ago

Request description

I encountered a torch.aten.upsample op that failed to convert to stablehlo. I wonder why there is no resize/upsample op in the stablehlo dialect?

abhigunj commented 1 week ago

Hello @FengJungle, thanks for creating this ticket. Could you provide more details on the conversion use case you are attempting here?

FengJungle commented 1 week ago

Hi @abhigunj, thanks for your reply. I'm trying to compile Stable Diffusion with torch-mlir, converting a torch.nn.Module to stablehlo. In the "LowerToBackendContractPass" pass, I found a torch.aten.upsample_nearest2d.vec op.

This morning I read a discussion about extending the StableHLO opset: https://groups.google.com/a/openxla.org/g/openxla-discuss/c/kRe0B1mugZI. I think it will be helpful for my current work.

qihqi commented 6 days ago

Hi @FengJungle,

Would you post the repro script and stack trace? Thanks!

FengJungle commented 6 days ago

Hi @qihqi @abhigunj, please take a look:

import torch
from torch_mlir import torchscript

class TestModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.upsample = torch.nn.UpsamplingNearest2d(scale_factor=2)
        self.relu = torch.nn.ReLU()

    def forward(self, input):
        output = self.relu(input)
        output = self.upsample(output)
        return output

model = TestModel().eval()
input_shape = (1, 4, 8, 8)
data = torch.randn(input_shape, dtype=torch.float16)

out_file = "upsample.model.stablehlo.mlir"

with torch.no_grad():
    module = torchscript.compile(
        model, 
        data, 
        output_type="stablehlo", 
        use_tracing=True)
    with open(out_file, 'w', encoding='utf-8') as outf:
        outf.write(module.operation.get_asm(large_elements_limit=10))

(attached screenshot of the error output)

GleasonK commented 5 days ago

This looks like a coverage issue in torch_mlir. The main reason we don't standardize higher-level ops like this is that they tend to look slightly different across frameworks, which makes them difficult to standardize in a framework-independent way.
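
As one concrete flavor of that mismatch, the nearest-neighbor resize entry points in PyTorch and JAX do not even agree on how the output size is specified (scale factor vs. absolute target shape). A small illustration, independent of any particular lowering:

import torch
import jax.numpy as jnp
from jax import image as jax_image

x_torch = torch.randn(1, 4, 8, 8)
# PyTorch: NCHW input, output size given as a scale factor (or an absolute size).
y_torch = torch.nn.functional.interpolate(x_torch, scale_factor=2, mode="nearest")

# JAX: output size given as the full target shape, with no fixed layout convention.
x_jax = jnp.asarray(x_torch.numpy())
y_jax = jax_image.resize(x_jax, shape=(1, 4, 16, 16), method="nearest")

assert tuple(y_torch.shape) == tuple(y_jax.shape) == (1, 4, 16, 16)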

Also, I'd recommend taking a look at torch_xla2, which has pretty good op coverage, is super hackable when needed (all Python, leverages JAX), and supports this repro out of the box:

import torch
import torch_xla2.export

class TestModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.upsample = torch.nn.UpsamplingNearest2d(scale_factor=2)
        self.relu = torch.nn.ReLU()

    def forward(self, input):
        output = self.relu(input)
        output = self.upsample(output)
        return output

model = TestModel().eval()
input_shape = (1, 4, 8, 8)
data = (torch.randn(input_shape, dtype=torch.float16),)

with torch.no_grad():
    exported = torch.export.export(model, data)
    stablehlo = torch_xla2.export.exported_program_to_stablehlo(exported)
    print(stablehlo.mlir_module())
Script output:
module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<1x4x8x8xf16> {mhlo.layout_mode = "default"}) -> (tensor<1x4x16x16xf16> {jax.result_info = "[0]", mhlo.layout_mode = "default"}) {
    %c = stablehlo.constant dense<8> : tensor<i32>
    %c_0 = stablehlo.constant dense<0> : tensor<i32>
    %cst = stablehlo.constant dense<5.000000e-01> : tensor<f32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = call @relu(%arg0) : (tensor<1x4x8x8xf16>) -> tensor<1x4x8x8xf16>
    %1 = stablehlo.convert %0 : (tensor<1x4x8x8xf16>) -> tensor<1x4x8x8xf32>
    %2 = stablehlo.iota dim = 0 : tensor<16xf32>
    %3 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor<f32>) -> tensor<16xf32>
    %4 = stablehlo.add %2, %3 : tensor<16xf32>
    %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<16xf32>
    %6 = stablehlo.multiply %4, %5 : tensor<16xf32>
    %7 = stablehlo.convert %6 : (tensor<16xf32>) -> tensor<16xi32>
    %8 = stablehlo.broadcast_in_dim %7, dims = [0] : (tensor<16xi32>) -> tensor<16x1xi32>
    %9 = stablehlo.iota dim = 0 : tensor<16xf32>
    %10 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor<f32>) -> tensor<16xf32>
    %11 = stablehlo.add %9, %10 : tensor<16xf32>
    %12 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f32>) -> tensor<16xf32>
    %13 = stablehlo.multiply %11, %12 : tensor<16xf32>
    %14 = stablehlo.convert %13 : (tensor<16xf32>) -> tensor<16xi32>
    %15 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<16x1xi32>
    %16 = stablehlo.compare  LT, %8, %15,  SIGNED : (tensor<16x1xi32>, tensor<16x1xi32>) -> tensor<16x1xi1>
    %17 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<16x1xi32>
    %18 = stablehlo.add %8, %17 : tensor<16x1xi32>
    %19 = stablehlo.select %16, %18, %8 : tensor<16x1xi1>, tensor<16x1xi32>
    %20 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor<i32>) -> tensor<16xi32>
    %21 = stablehlo.compare  LT, %14, %20,  SIGNED : (tensor<16xi32>, tensor<16xi32>) -> tensor<16xi1>
    %22 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<16xi32>
    %23 = stablehlo.add %14, %22 : tensor<16xi32>
    %24 = stablehlo.select %21, %23, %14 : tensor<16xi1>, tensor<16xi32>
    %25 = stablehlo.broadcast_in_dim %19, dims = [0, 1] : (tensor<16x1xi32>) -> tensor<16x16xi32>
    %26 = stablehlo.broadcast_in_dim %24, dims = [1] : (tensor<16xi32>) -> tensor<16x16xi32>
    %27 = stablehlo.broadcast_in_dim %25, dims = [0, 1] : (tensor<16x16xi32>) -> tensor<16x16x1xi32>
    %28 = stablehlo.broadcast_in_dim %26, dims = [0, 1] : (tensor<16x16xi32>) -> tensor<16x16x1xi32>
    %29 = stablehlo.concatenate %27, %28, dim = 2 : (tensor<16x16x1xi32>, tensor<16x16x1xi32>) -> tensor<16x16x2xi32>
    %30 = "stablehlo.gather"(%1, %29) <{dimension_numbers = #stablehlo.gather<offset_dims = [0, 1], collapsed_slice_dims = [2, 3], start_index_map = [2, 3], index_vector_dim = 2>, slice_sizes = array<i64: 1, 4, 1, 1>}> : (tensor<1x4x8x8xf32>, tensor<16x16x2xi32>) -> tensor<1x4x16x16xf32>
    %31 = stablehlo.convert %30 : (tensor<1x4x16x16xf32>) -> tensor<1x4x16x16xf16>
    return %31 : tensor<1x4x16x16xf16>
  }
  func.func private @relu(%arg0: tensor<1x4x8x8xf16> {mhlo.layout_mode = "default"}) -> (tensor<1x4x8x8xf16> {mhlo.layout_mode = "default"}) {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f16>
    %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<f16>) -> tensor<1x4x8x8xf16>
    %1 = stablehlo.maximum %arg0, %0 : tensor<1x4x8x8xf16>
    return %1 : tensor<1x4x8x8xf16>
  }
}

FengJungle commented 5 days ago

Hi @GleasonK, many thanks! It seems I can manually replace torch.aten.upsample_nearest2d.vec with a combination of other stablehlo ops, following the torch_xla2 compile output you mentioned. That is rather fiddly, though...
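
As a rough illustration of that manual route, here is a sketch that rewrites the module on the PyTorch side before export. It assumes aten.index_select is covered by the torch-to-stablehlo path, and the helper name nearest_upsample_2d is only illustrative:

import torch

def nearest_upsample_2d(x: torch.Tensor, scale: int) -> torch.Tensor:
    # Nearest-neighbor upsampling written with index_select so the exported
    # graph avoids aten.upsample_nearest2d.vec entirely.
    n, c, h, w = x.shape
    row_idx = torch.arange(h * scale) // scale  # e.g. [0, 0, 1, 1, ...] for scale=2
    col_idx = torch.arange(w * scale) // scale
    x = torch.index_select(x, 2, row_idx)       # repeat rows
    x = torch.index_select(x, 3, col_idx)       # repeat columns
    return x

class TestModelRewritten(torch.nn.Module):
    def forward(self, input):
        output = torch.nn.functional.relu(input)
        output = nearest_upsample_2d(output, 2)
        return output

Whether this lowers cleanly would still need to be checked against the torch-mlir version in use.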