openxla / stablehlo

Backward compatible ML compute opset inspired by HLO/MHLO
Apache License 2.0

Request ConvertToSignlessPass in Stablehlo #2356

Closed: qingyunqu closed this issue 1 month ago

qingyunqu commented 1 month ago

Request description

When I use stablehlo-to-linalg in torch-mlir (https://github.com/llvm/torch-mlir/pull/3367), I found that it emits builtin.unrealized_conversion_cast for the following input:

module attributes {torch.debug_module_name = "Uint8ToFloat"} {
  func.func @forward(%arg0: tensor<3xui8>) -> tensor<3xf32> {
    %0 = stablehlo.convert %arg0 : (tensor<3xui8>) -> tensor<3xf32>
    return %0 : tensor<3xf32>
  }
}

Output of stablehlo-to-linalg:

#map = affine_map<(d0) -> (d0)>
module attributes {torch.debug_module_name = "Uint8ToFloat"} {
  func.func @forward(%arg0: tensor<3xui8>) -> tensor<3xf32> {
    %0 = builtin.unrealized_conversion_cast %arg0 : tensor<3xui8> to tensor<3xi8>
    %1 = tensor.empty() : tensor<3xf32>
    %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%0 : tensor<3xi8>) outs(%1 : tensor<3xf32>) {
    ^bb0(%in: i8, %out: f32):
      %3 = arith.uitofp %in : i8 to f32
      linalg.yield %3 : f32
    } -> tensor<3xf32>
    return %2 : tensor<3xf32>
  }
}
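Why the cast in this output is semantically harmless: arith and linalg operate on signless integers, so signedness is carried by the op (arith.uitofp reads its operand as unsigned) rather than by the type. The ui8 to i8 cast therefore only relabels the same 8 bits. The Python sketch below models this; the helper names are hypothetical and only mimic the semantics, they are not MLIR APIs.

```python
import struct

# Hypothetical helpers that model MLIR semantics in plain Python; not MLIR APIs.

def reinterpret_u8_as_i8(values):
    """Reinterpret unsigned bytes as signed bytes: same 8 bits, new reading.

    Models what the ui8 -> i8 unrealized_conversion_cast does to the data.
    """
    packed = struct.pack(f"{len(values)}B", *values)
    return list(struct.unpack(f"{len(values)}b", packed))

def uitofp_i8(bits):
    """Model arith.uitofp on an i8: read the bits as unsigned, convert to float."""
    return float(bits & 0xFF)

def sitofp_i8(bits):
    """Model arith.sitofp on an i8: read the bits as signed, convert to float."""
    return float(bits)

original = [0, 127, 128, 255]              # ui8 values
signless = reinterpret_u8_as_i8(original)
print(signless)                            # [0, 127, -128, -1]
print([uitofp_i8(b) for b in signless])    # [0.0, 127.0, 128.0, 255.0]
print([sitofp_i8(b) for b in signless])    # [0.0, 127.0, -128.0, -1.0]
```

Reading the reinterpreted bytes with unsigned semantics recovers the original values, while signed semantics does not, which is why the lowering has to pick arith.uitofp for a ui8 source.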

I think I need MHLO's ConvertToSignlessPass in StableHLO to eliminate the builtin.unrealized_conversion_cast.
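Conceptually, a to-signless pass rewrites every ui<N>/si<N> type to the signless i<N> of the same width, after which the cast above has identical source and result types and can be dropped. The following Python sketch models only that type mapping on MLIR type strings. It is an illustration under stated assumptions (hypothetical helpers, scalars and ranked tensors only), not the pass's actual implementation, which works on MLIR IR rather than strings.

```python
import re

# Hypothetical helpers for illustration only; the real pass operates on MLIR IR.

# Integer element types with explicit signedness, e.g. ui8 or si32.
_SIGNED_INT = re.compile(r"[us]i(\d+)")
# Ranked tensor type: tensor<DIMSxELEM>, each dim a number or '?'.
_TENSOR = re.compile(r"tensor<((?:(?:\d+|\?)x)*)([^>]+)>")

def elem_to_signless(elem: str) -> str:
    """Map ui<N>/si<N> to signless i<N>, keeping the width; leave other types as-is."""
    m = _SIGNED_INT.fullmatch(elem)
    return f"i{m.group(1)}" if m else elem

def type_to_signless(ty: str) -> str:
    """Convert a scalar or ranked-tensor type string to its signless form."""
    m = _TENSOR.fullmatch(ty)
    if m is None:
        return elem_to_signless(ty)
    dims, elem = m.groups()
    return f"tensor<{dims}{elem_to_signless(elem)}>"

def cast_is_identity_after_conversion(src: str, dst: str) -> bool:
    """A cast becomes a no-op once both sides map to the same signless type."""
    return type_to_signless(src) == type_to_signless(dst)
```

For the example above, type_to_signless("tensor<3xui8>") yields "tensor<3xi8>", so once the function argument is converted, the cast is an identity and nothing is left for it to do.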

GleasonK commented 1 month ago

Hello! If you have time to migrate that pass to StableHLO, we're happy to house it. Looking at the implementation, there is nothing specific to MHLO about that pass, so it can likely be shared by both StableHLO and MHLO.