NVIDIA / TensorRT-Incubator

Experimental projects related to TensorRT
72 stars 12 forks source link

[Bug] failed to derive upper bound #269

Status: Open · yizhuoz004 opened this issue 5 days ago

yizhuoz004 commented 5 days ago
// Minimal reproducer for "failed to derive upper bound" (issue #269).
// @main creates a zero-filled tensor whose shape [1, 144, 7, 7] is passed to
// @Fill through a dynamically sized shape operand (tensor<?xi32>). It then
// cubic-resizes that tensor to the constant output shape [1, 144, 256, 256].
// NOTE(review): the reported error names a `tensor.cast` from
// tensor<1x144x7x7xf32> to tensor<?x?x?x?xf32>. No such op appears in this IR,
// so it is presumably introduced by the compiler once the constant shape is
// propagated into @Fill. Confirm against the pass pipeline.
module @outs_t549_2 {
  // Entry point. Takes no arguments and returns the resized tensor with a
  // fully dynamic result type.
  func.func @main() -> tensor<?x?x?x?xf32> {
    // Shape of the tensor to fill: [1, 144, 7, 7].
    %c = stablehlo.constant dense<[1, 144, 7, 7]> : tensor<4xi32>
    // Erase the static extent so @Fill receives a tensor<?xi32> shape operand.
    %0 = stablehlo.convert %c : (tensor<4xi32>) -> tensor<?xi32>
    %1 = call @Fill(%0) : (tensor<?xi32>) -> tensor<?x?x?x?xf32>
    // Requested output shape of the resize: [1, 144, 256, 256].
    %c_0 = stablehlo.constant dense<[1, 144, 256, 256]> : tensor<4xi32>
    // Cubic resize: half-pixel coordinate transformation, cubic coefficient -0.75.
    %2 = tensorrt.resize_cubic {coordinateTransformation = #tensorrt.resize_coordinate_transformation<kHALF_PIXEL>, cubicCoeff = -7.500000e-01 : f32, selectorForSinglePixel = #tensorrt.resize_selector<kFORMULA>} %1, %c_0 : (tensor<?x?x?x?xf32>, tensor<4xi32>) -> tensor<?x?x?x?xf32>
    return %2 : tensor<?x?x?x?xf32>
  }
  // Broadcasts the scalar 0.0 to the shape held in %arg0.
  // %arg0 is converted to tensor<4xi32>, so it is expected to contain exactly
  // four extents (the result is rank 4).
  func.func private @Fill(%arg0: tensor<?xi32>) -> tensor<?x?x?x?xf32> {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.convert %arg0 : (tensor<?xi32>) -> tensor<4xi32>
    %1 = stablehlo.dynamic_broadcast_in_dim %cst, %0, dims = [] : (tensor<f32>, tensor<4xi32>) -> tensor<?x?x?x?xf32>
    return %1 : tensor<?x?x?x?xf32>
  }
}
error: failed to derive upper bound for %2 = "tensor.cast"(%1) : (tensor<1x144x7x7xf32>) -> tensor<?x?x?x?xf32>
yizhuoz004 commented 4 days ago
// Second reproducer for the same "failed to derive upper bound" error.
// @main does three things:
// - Applies a 1x1 convolution (stride 1, no padding, 256 output channels)
//   plus a per-channel bias to each of the four inputs.
// - Nearest-upsamples the result for %arg3 by 2x in dims 2 and 3.
// - Adds that upsampled result to the result for %arg2.
// Weights and biases are synthesized in-graph from dynamic_iota and routed
// through helper functions. As a result, every intermediate shape is dynamic
// (tensor<?...>) even though each extent ultimately derives from constants.
module @ins_x0_x1_x2_x3_outs_t235_t214_t193_t112_t80_t81_t82_t83_9 {
  // Convolution layout is [b, f, 0, 1] for inputs and outputs and
  // [o, i, 0, 1] for filters (see the dim_numbers on each convolution).
  // Returns, in order:
  //   - the four conv+bias results (%35, %71, %173, %143), where %173 is the
  //     %arg2 branch plus the 2x-upsampled %arg3 branch;
  //   - four constant tensors whose contents are elided in this paste.
  func.func @main(%arg0: tensor<1x144x256x256xf32>, %arg1: tensor<1x288x128x128xf32>, %arg2: tensor<1x576x64x64xf32>, %arg3: tensor<1x1152x32x32xf32>) -> (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<1x256x256x256xf32>, tensor<1x256x128x128xf32>, tensor<1x256x64x64xf32>, tensor<1x256x32x32xf32>) {
    // ---- Branch 0 (%arg0): filter of 36864 = 256*144 values, computed as
    // iota * 1.0 + 0.0, reshaped to [256, 144, 1, 1].
    %c = stablehlo.constant dense<36864> : tensor<1xi32>
    %0 = stablehlo.dynamic_iota %c, dim = 0 : (tensor<1xi32>) -> tensor<?xf32>
    %1 = stablehlo.convert %c : (tensor<1xi32>) -> tensor<?xi32>
    %2 = call @Fill(%1) : (tensor<?xi32>) -> tensor<?xf32>
    %3 = call @BinaryElementwise(%0, %2) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
    %4 = stablehlo.convert %c : (tensor<1xi32>) -> tensor<?xi32>
    %5 = call @Fill_1(%4) : (tensor<?xi32>) -> tensor<?xf32>
    %6 = call @BinaryElementwise_1(%3, %5) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
    %c_0 = stablehlo.constant dense<[256, 144, 1, 1]> : tensor<4xi32>
    %7 = stablehlo.convert %c_0 : (tensor<4xi32>) -> tensor<?xi32>
    %8 = call @Reshape(%6, %7) : (tensor<?xf32>, tensor<?xi32>) -> tensor<?x?x?x?xf32>
    %9 = stablehlo.convolution(%arg0, %8) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x144x256x256xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
    // Bias vector of 256 values (iota * 1.0 + 0.0), reshaped to rank 1.
    // %c_1 (= 256) is reused by the bias of every later branch.
    %c_1 = stablehlo.constant dense<256> : tensor<1xi32>
    %10 = stablehlo.dynamic_iota %c_1, dim = 0 : (tensor<1xi32>) -> tensor<?xf32>
    %11 = stablehlo.convert %c_1 : (tensor<1xi32>) -> tensor<?xi32>
    %12 = call @Fill(%11) : (tensor<?xi32>) -> tensor<?xf32>
    %13 = call @BinaryElementwise(%10, %12) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
    %14 = stablehlo.convert %c_1 : (tensor<1xi32>) -> tensor<?xi32>
    %15 = call @Fill_1(%14) : (tensor<?xi32>) -> tensor<?xf32>
    %16 = call @BinaryElementwise_1(%13, %15) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
    %17 = stablehlo.convert %c_1 : (tensor<1xi32>) -> tensor<?xi32>
    %18 = call @Reshape_1(%16, %17) : (tensor<?xf32>, tensor<?xi32>) -> tensor<?xf32>
    // Build the bias shape [1, dim0(filter), 1, 1]. Dim 0 of the filter is
    // read back at runtime: get_dimension_size x4 -> shape vector -> @Slice
    // of element [0, 1) -> scalar -> tensor<1xi32>.
    %c_2 = stablehlo.constant dense<1> : tensor<1xi32>
    %19 = stablehlo.get_dimension_size %8, dim = 0 : (tensor<?x?x?x?xf32>) -> tensor<i32>
    %20 = stablehlo.reshape %19 : (tensor<i32>) -> tensor<1xi32>
    %21 = stablehlo.get_dimension_size %8, dim = 1 : (tensor<?x?x?x?xf32>) -> tensor<i32>
    %22 = stablehlo.reshape %21 : (tensor<i32>) -> tensor<1xi32>
    %23 = stablehlo.get_dimension_size %8, dim = 2 : (tensor<?x?x?x?xf32>) -> tensor<i32>
    %24 = stablehlo.reshape %23 : (tensor<i32>) -> tensor<1xi32>
    %25 = stablehlo.get_dimension_size %8, dim = 3 : (tensor<?x?x?x?xf32>) -> tensor<i32>
    %26 = stablehlo.reshape %25 : (tensor<i32>) -> tensor<1xi32>
    %27 = stablehlo.concatenate %20, %22, %24, %26, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
    %c_3 = stablehlo.constant dense<0> : tensor<i32>
    %c_4 = stablehlo.constant dense<1> : tensor<i32>
    %28 = stablehlo.convert %27 : (tensor<4xi32>) -> tensor<?xi32>
    %29 = call @Slice(%28, %c_3, %c_4, %c_4) : (tensor<?xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
    %c_5 = stablehlo.constant dense<> : tensor<0xi32>
    %30 = stablehlo.dynamic_reshape %29, %c_5 : (tensor<?xi32>, tensor<0xi32>) -> tensor<i32>
    %31 = stablehlo.dynamic_broadcast_in_dim %30, %c_2, dims = [] : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
    %c_6 = stablehlo.constant dense<1> : tensor<2xi32>
    %32 = stablehlo.concatenate %c_2, %31, %c_6, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<4xi32>
    %33 = stablehlo.convert %32 : (tensor<4xi32>) -> tensor<?xi32>
    %34 = call @Reshape_2(%18, %33) : (tensor<?xf32>, tensor<?xi32>) -> tensor<?x?x?x?xf32>
    // conv + bias (broadcast over dims 0, 2, 3).
    %35 = call @BinaryElementwise_2(%9, %34) : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
    // ---- Branch 1 (%arg1): filter of 73728 = 256*288 values -> [256, 288, 1, 1].
    %c_7 = stablehlo.constant dense<73728> : tensor<1xi32>
    %36 = stablehlo.dynamic_iota %c_7, dim = 0 : (tensor<1xi32>) -> tensor<?xf32>
    %37 = stablehlo.convert %c_7 : (tensor<1xi32>) -> tensor<?xi32>
    %38 = call @Fill(%37) : (tensor<?xi32>) -> tensor<?xf32>
    %39 = call @BinaryElementwise(%36, %38) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
    %40 = stablehlo.convert %c_7 : (tensor<1xi32>) -> tensor<?xi32>
    %41 = call @Fill_1(%40) : (tensor<?xi32>) -> tensor<?xf32>
    %42 = call @BinaryElementwise_1(%39, %41) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
    %c_8 = stablehlo.constant dense<[256, 288, 1, 1]> : tensor<4xi32>
    %43 = stablehlo.convert %c_8 : (tensor<4xi32>) -> tensor<?xi32>
    %44 = call @Reshape(%42, %43) : (tensor<?xf32>, tensor<?xi32>) -> tensor<?x?x?x?xf32>
    %45 = stablehlo.convolution(%arg1, %44) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x288x128x128xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
    // Bias vector of 256 values, same construction as branch 0.
    %46 = stablehlo.dynamic_iota %c_1, dim = 0 : (tensor<1xi32>) -> tensor<?xf32>
    %47 = stablehlo.convert %c_1 : (tensor<1xi32>) -> tensor<?xi32>
    %48 = call @Fill(%47) : (tensor<?xi32>) -> tensor<?xf32>
    %49 = call @BinaryElementwise(%46, %48) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
    %50 = stablehlo.convert %c_1 : (tensor<1xi32>) -> tensor<?xi32>
    %51 = call @Fill_1(%50) : (tensor<?xi32>) -> tensor<?xf32>
    %52 = call @BinaryElementwise_1(%49, %51) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
    %53 = stablehlo.convert %c_1 : (tensor<1xi32>) -> tensor<?xi32>
    %54 = call @Reshape_1(%52, %53) : (tensor<?xf32>, tensor<?xi32>) -> tensor<?xf32>
    // Bias shape [1, dim0(%44), 1, 1], read back at runtime as in branch 0.
    %55 = stablehlo.get_dimension_size %44, dim = 0 : (tensor<?x?x?x?xf32>) -> tensor<i32>
    %56 = stablehlo.reshape %55 : (tensor<i32>) -> tensor<1xi32>
    %57 = stablehlo.get_dimension_size %44, dim = 1 : (tensor<?x?x?x?xf32>) -> tensor<i32>
    %58 = stablehlo.reshape %57 : (tensor<i32>) -> tensor<1xi32>
    %59 = stablehlo.get_dimension_size %44, dim = 2 : (tensor<?x?x?x?xf32>) -> tensor<i32>
    %60 = stablehlo.reshape %59 : (tensor<i32>) -> tensor<1xi32>
    %61 = stablehlo.get_dimension_size %44, dim = 3 : (tensor<?x?x?x?xf32>) -> tensor<i32>
    %62 = stablehlo.reshape %61 : (tensor<i32>) -> tensor<1xi32>
    %63 = stablehlo.concatenate %56, %58, %60, %62, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
    %64 = stablehlo.convert %63 : (tensor<4xi32>) -> tensor<?xi32>
    %65 = call @Slice(%64, %c_3, %c_4, %c_4) : (tensor<?xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
    %66 = stablehlo.dynamic_reshape %65, %c_5 : (tensor<?xi32>, tensor<0xi32>) -> tensor<i32>
    %67 = stablehlo.dynamic_broadcast_in_dim %66, %c_2, dims = [] : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
    %68 = stablehlo.concatenate %c_2, %67, %c_6, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<4xi32>
    %69 = stablehlo.convert %68 : (tensor<4xi32>) -> tensor<?xi32>
    %70 = call @Reshape_2(%54, %69) : (tensor<?xf32>, tensor<?xi32>) -> tensor<?x?x?x?xf32>
    %71 = call @BinaryElementwise_2(%45, %70) : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
    // ---- Branch 2 (%arg2): filter of 147456 = 256*576 values -> [256, 576, 1, 1].
    %c_9 = stablehlo.constant dense<147456> : tensor<1xi32>
    %72 = stablehlo.dynamic_iota %c_9, dim = 0 : (tensor<1xi32>) -> tensor<?xf32>
    %73 = stablehlo.convert %c_9 : (tensor<1xi32>) -> tensor<?xi32>
    %74 = call @Fill(%73) : (tensor<?xi32>) -> tensor<?xf32>
    %75 = call @BinaryElementwise(%72, %74) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
    %76 = stablehlo.convert %c_9 : (tensor<1xi32>) -> tensor<?xi32>
    %77 = call @Fill_1(%76) : (tensor<?xi32>) -> tensor<?xf32>
    %78 = call @BinaryElementwise_1(%75, %77) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
    %c_10 = stablehlo.constant dense<[256, 576, 1, 1]> : tensor<4xi32>
    %79 = stablehlo.convert %c_10 : (tensor<4xi32>) -> tensor<?xi32>
    %80 = call @Reshape(%78, %79) : (tensor<?xf32>, tensor<?xi32>) -> tensor<?x?x?x?xf32>
    %81 = stablehlo.convolution(%arg2, %80) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x576x64x64xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
    // Bias vector of 256 values, same construction as branch 0.
    %82 = stablehlo.dynamic_iota %c_1, dim = 0 : (tensor<1xi32>) -> tensor<?xf32>
    %83 = stablehlo.convert %c_1 : (tensor<1xi32>) -> tensor<?xi32>
    %84 = call @Fill(%83) : (tensor<?xi32>) -> tensor<?xf32>
    %85 = call @BinaryElementwise(%82, %84) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
    %86 = stablehlo.convert %c_1 : (tensor<1xi32>) -> tensor<?xi32>
    %87 = call @Fill_1(%86) : (tensor<?xi32>) -> tensor<?xf32>
    %88 = call @BinaryElementwise_1(%85, %87) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
    %89 = stablehlo.convert %c_1 : (tensor<1xi32>) -> tensor<?xi32>
    %90 = call @Reshape_1(%88, %89) : (tensor<?xf32>, tensor<?xi32>) -> tensor<?xf32>
    // Bias shape [1, dim0(%80), 1, 1], read back at runtime as in branch 0.
    %91 = stablehlo.get_dimension_size %80, dim = 0 : (tensor<?x?x?x?xf32>) -> tensor<i32>
    %92 = stablehlo.reshape %91 : (tensor<i32>) -> tensor<1xi32>
    %93 = stablehlo.get_dimension_size %80, dim = 1 : (tensor<?x?x?x?xf32>) -> tensor<i32>
    %94 = stablehlo.reshape %93 : (tensor<i32>) -> tensor<1xi32>
    %95 = stablehlo.get_dimension_size %80, dim = 2 : (tensor<?x?x?x?xf32>) -> tensor<i32>
    %96 = stablehlo.reshape %95 : (tensor<i32>) -> tensor<1xi32>
    %97 = stablehlo.get_dimension_size %80, dim = 3 : (tensor<?x?x?x?xf32>) -> tensor<i32>
    %98 = stablehlo.reshape %97 : (tensor<i32>) -> tensor<1xi32>
    %99 = stablehlo.concatenate %92, %94, %96, %98, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
    %100 = stablehlo.convert %99 : (tensor<4xi32>) -> tensor<?xi32>
    %101 = call @Slice(%100, %c_3, %c_4, %c_4) : (tensor<?xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
    %102 = stablehlo.dynamic_reshape %101, %c_5 : (tensor<?xi32>, tensor<0xi32>) -> tensor<i32>
    %103 = stablehlo.dynamic_broadcast_in_dim %102, %c_2, dims = [] : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
    %104 = stablehlo.concatenate %c_2, %103, %c_6, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<4xi32>
    %105 = stablehlo.convert %104 : (tensor<4xi32>) -> tensor<?xi32>
    %106 = call @Reshape_2(%90, %105) : (tensor<?xf32>, tensor<?xi32>) -> tensor<?x?x?x?xf32>
    %107 = call @BinaryElementwise_2(%81, %106) : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
    // ---- Branch 3 (%arg3): filter of 294912 = 256*1152 values -> [256, 1152, 1, 1].
    %c_11 = stablehlo.constant dense<294912> : tensor<1xi32>
    %108 = stablehlo.dynamic_iota %c_11, dim = 0 : (tensor<1xi32>) -> tensor<?xf32>
    %109 = stablehlo.convert %c_11 : (tensor<1xi32>) -> tensor<?xi32>
    %110 = call @Fill(%109) : (tensor<?xi32>) -> tensor<?xf32>
    %111 = call @BinaryElementwise(%108, %110) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
    %112 = stablehlo.convert %c_11 : (tensor<1xi32>) -> tensor<?xi32>
    %113 = call @Fill_1(%112) : (tensor<?xi32>) -> tensor<?xf32>
    %114 = call @BinaryElementwise_1(%111, %113) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
    %c_12 = stablehlo.constant dense<[256, 1152, 1, 1]> : tensor<4xi32>
    %115 = stablehlo.convert %c_12 : (tensor<4xi32>) -> tensor<?xi32>
    %116 = call @Reshape(%114, %115) : (tensor<?xf32>, tensor<?xi32>) -> tensor<?x?x?x?xf32>
    %117 = stablehlo.convolution(%arg3, %116) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x1152x32x32xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
    // Bias vector of 256 values, same construction as branch 0.
    %118 = stablehlo.dynamic_iota %c_1, dim = 0 : (tensor<1xi32>) -> tensor<?xf32>
    %119 = stablehlo.convert %c_1 : (tensor<1xi32>) -> tensor<?xi32>
    %120 = call @Fill(%119) : (tensor<?xi32>) -> tensor<?xf32>
    %121 = call @BinaryElementwise(%118, %120) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
    %122 = stablehlo.convert %c_1 : (tensor<1xi32>) -> tensor<?xi32>
    %123 = call @Fill_1(%122) : (tensor<?xi32>) -> tensor<?xf32>
    %124 = call @BinaryElementwise_1(%121, %123) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
    %125 = stablehlo.convert %c_1 : (tensor<1xi32>) -> tensor<?xi32>
    %126 = call @Reshape_1(%124, %125) : (tensor<?xf32>, tensor<?xi32>) -> tensor<?xf32>
    // Bias shape [1, dim0(%116), 1, 1], read back at runtime as in branch 0.
    %127 = stablehlo.get_dimension_size %116, dim = 0 : (tensor<?x?x?x?xf32>) -> tensor<i32>
    %128 = stablehlo.reshape %127 : (tensor<i32>) -> tensor<1xi32>
    %129 = stablehlo.get_dimension_size %116, dim = 1 : (tensor<?x?x?x?xf32>) -> tensor<i32>
    %130 = stablehlo.reshape %129 : (tensor<i32>) -> tensor<1xi32>
    %131 = stablehlo.get_dimension_size %116, dim = 2 : (tensor<?x?x?x?xf32>) -> tensor<i32>
    %132 = stablehlo.reshape %131 : (tensor<i32>) -> tensor<1xi32>
    %133 = stablehlo.get_dimension_size %116, dim = 3 : (tensor<?x?x?x?xf32>) -> tensor<i32>
    %134 = stablehlo.reshape %133 : (tensor<i32>) -> tensor<1xi32>
    %135 = stablehlo.concatenate %128, %130, %132, %134, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
    %136 = stablehlo.convert %135 : (tensor<4xi32>) -> tensor<?xi32>
    %137 = call @Slice(%136, %c_3, %c_4, %c_4) : (tensor<?xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
    %138 = stablehlo.dynamic_reshape %137, %c_5 : (tensor<?xi32>, tensor<0xi32>) -> tensor<i32>
    %139 = stablehlo.dynamic_broadcast_in_dim %138, %c_2, dims = [] : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
    %140 = stablehlo.concatenate %c_2, %139, %c_6, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<2xi32>) -> tensor<4xi32>
    %141 = stablehlo.convert %140 : (tensor<4xi32>) -> tensor<?xi32>
    %142 = call @Reshape_2(%126, %141) : (tensor<?xf32>, tensor<?xi32>) -> tensor<?x?x?x?xf32>
    %143 = call @BinaryElementwise_2(%117, %142) : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
    // ---- Upsample branch 3: read back the shape of %143 and build the
    // target shape [d0, d1, d2 * 2, d3 * 2] by slicing out each extent.
    %144 = stablehlo.get_dimension_size %143, dim = 0 : (tensor<?x?x?x?xf32>) -> tensor<i32>
    %145 = stablehlo.reshape %144 : (tensor<i32>) -> tensor<1xi32>
    %146 = stablehlo.get_dimension_size %143, dim = 1 : (tensor<?x?x?x?xf32>) -> tensor<i32>
    %147 = stablehlo.reshape %146 : (tensor<i32>) -> tensor<1xi32>
    %148 = stablehlo.get_dimension_size %143, dim = 2 : (tensor<?x?x?x?xf32>) -> tensor<i32>
    %149 = stablehlo.reshape %148 : (tensor<i32>) -> tensor<1xi32>
    %150 = stablehlo.get_dimension_size %143, dim = 3 : (tensor<?x?x?x?xf32>) -> tensor<i32>
    %151 = stablehlo.reshape %150 : (tensor<i32>) -> tensor<1xi32>
    %152 = stablehlo.concatenate %145, %147, %149, %151, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
    // d0 (element [0, 1)), kept as-is.
    %153 = stablehlo.convert %152 : (tensor<4xi32>) -> tensor<?xi32>
    %154 = call @Slice(%153, %c_3, %c_4, %c_4) : (tensor<?xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
    %155 = stablehlo.dynamic_reshape %154, %c_5 : (tensor<?xi32>, tensor<0xi32>) -> tensor<i32>
    %156 = stablehlo.dynamic_broadcast_in_dim %155, %c_2, dims = [] : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
    // d1 (element [1, 2)), kept as-is.
    %c_13 = stablehlo.constant dense<2> : tensor<i32>
    %157 = stablehlo.convert %152 : (tensor<4xi32>) -> tensor<?xi32>
    %158 = call @Slice(%157, %c_4, %c_13, %c_4) : (tensor<?xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
    %159 = stablehlo.dynamic_reshape %158, %c_5 : (tensor<?xi32>, tensor<0xi32>) -> tensor<i32>
    %160 = stablehlo.dynamic_broadcast_in_dim %159, %c_2, dims = [] : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
    // d2 (element [2, 3)), multiplied by %c_13 = 2.
    %c_14 = stablehlo.constant dense<3> : tensor<i32>
    %161 = stablehlo.convert %152 : (tensor<4xi32>) -> tensor<?xi32>
    %162 = call @Slice(%161, %c_13, %c_14, %c_4) : (tensor<?xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
    %163 = stablehlo.dynamic_reshape %162, %c_5 : (tensor<?xi32>, tensor<0xi32>) -> tensor<i32>
    %164 = call @BinaryElementwise_3(%163, %c_13) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    %165 = stablehlo.dynamic_broadcast_in_dim %164, %c_2, dims = [] : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
    // d3 (element [3, 4)), multiplied by %c_13 = 2.
    %c_15 = stablehlo.constant dense<4> : tensor<i32>
    %166 = stablehlo.convert %152 : (tensor<4xi32>) -> tensor<?xi32>
    %167 = call @Slice(%166, %c_14, %c_15, %c_4) : (tensor<?xi32>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
    %168 = stablehlo.dynamic_reshape %167, %c_5 : (tensor<?xi32>, tensor<0xi32>) -> tensor<i32>
    %169 = call @BinaryElementwise_3(%168, %c_13) : (tensor<i32>, tensor<i32>) -> tensor<i32>
    %170 = stablehlo.dynamic_broadcast_in_dim %169, %c_2, dims = [] : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
    %171 = stablehlo.concatenate %156, %160, %165, %170, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
    // Nearest-neighbour resize (asymmetric coordinates, floor rounding) to %171.
    %172 = tensorrt.resize_nearest {coordinateTransformation = #tensorrt.resize_coordinate_transformation<kASYMMETRIC>, nearestRounding = #tensorrt.resize_round_mode<kFLOOR>, selectorForSinglePixel = #tensorrt.resize_selector<kFORMULA>} %143, %171 : (tensor<?x?x?x?xf32>, tensor<4xi32>) -> tensor<?x?x?x?xf32>
    // Branch 2 result plus upsampled branch 3 result.
    %173 = call @BinaryElementwise_2(%107, %172) : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
    // Constant results; their contents are elided in this paste.
    %cst = stablehlo.constant dense_resource<__elided__> : tensor<1x256x256x256xf32>
    %cst_16 = stablehlo.constant dense_resource<__elided__> : tensor<1x256x128x128xf32>
    %cst_17 = stablehlo.constant dense_resource<__elided__> : tensor<1x256x64x64xf32>
    %cst_18 = stablehlo.constant dense_resource<__elided__> : tensor<1x256x32x32xf32>
    return %35, %71, %173, %143, %cst, %cst_16, %cst_17, %cst_18 : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>, tensor<1x256x256x256xf32>, tensor<1x256x128x128xf32>, tensor<1x256x64x64xf32>, tensor<1x256x32x32xf32>
  }
  // Broadcasts the scalar 1.0 to a rank-1 tensor. Its length is the single
  // element of %arg0, which is converted to tensor<1xi32>.
  func.func private @Fill(%arg0: tensor<?xi32>) -> tensor<?xf32> {
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %0 = stablehlo.convert %arg0 : (tensor<?xi32>) -> tensor<1xi32>
    %1 = stablehlo.dynamic_broadcast_in_dim %cst, %0, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
    return %1 : tensor<?xf32>
  }
  // Rank-1 broadcasting multiply. The common length is %arg1's length when
  // %arg0 has length 1, and %arg0's length otherwise. Both operands are
  // broadcast to that length and then multiplied elementwise.
  func.func private @BinaryElementwise(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
    %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
    %c = stablehlo.constant dense<1> : tensor<1xi32>
    %1 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32>
    %2 = stablehlo.compare  EQ, %1, %c : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
    %3 = stablehlo.get_dimension_size %arg1, dim = 0 : (tensor<?xf32>) -> tensor<i32>
    %4 = stablehlo.reshape %3 : (tensor<i32>) -> tensor<1xi32>
    // Pick %arg1's length where %arg0's length is 1.
    %5 = stablehlo.select %2, %4, %1 : tensor<1xi1>, tensor<1xi32>
    %6 = stablehlo.dynamic_broadcast_in_dim %arg0, %5, dims = [0] : (tensor<?xf32>, tensor<1xi32>) -> tensor<?xf32>
    %7 = stablehlo.dynamic_broadcast_in_dim %arg1, %5, dims = [0] : (tensor<?xf32>, tensor<1xi32>) -> tensor<?xf32>
    %8 = stablehlo.multiply %6, %7 : tensor<?xf32>
    return %8 : tensor<?xf32>
  }
  // Broadcasts the scalar 0.0 to a rank-1 tensor. Its length is the single
  // element of %arg0, which is converted to tensor<1xi32>.
  func.func private @Fill_1(%arg0: tensor<?xi32>) -> tensor<?xf32> {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.convert %arg0 : (tensor<?xi32>) -> tensor<1xi32>
    %1 = stablehlo.dynamic_broadcast_in_dim %cst, %0, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
    return %1 : tensor<?xf32>
  }
  // Rank-1 broadcasting add. Uses the same length-selection rule as
  // @BinaryElementwise (%arg1's length where %arg0's length is 1).
  func.func private @BinaryElementwise_1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
    %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?xf32>) -> tensor<i32>
    %c = stablehlo.constant dense<1> : tensor<1xi32>
    %1 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32>
    %2 = stablehlo.compare  EQ, %1, %c : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
    %3 = stablehlo.get_dimension_size %arg1, dim = 0 : (tensor<?xf32>) -> tensor<i32>
    %4 = stablehlo.reshape %3 : (tensor<i32>) -> tensor<1xi32>
    %5 = stablehlo.select %2, %4, %1 : tensor<1xi1>, tensor<1xi32>
    %6 = stablehlo.dynamic_broadcast_in_dim %arg0, %5, dims = [0] : (tensor<?xf32>, tensor<1xi32>) -> tensor<?xf32>
    %7 = stablehlo.dynamic_broadcast_in_dim %arg1, %5, dims = [0] : (tensor<?xf32>, tensor<1xi32>) -> tensor<?xf32>
    %8 = stablehlo.add %6, %7 : tensor<?xf32>
    return %8 : tensor<?xf32>
  }
  // Reshapes a rank-1 tensor to rank 4 using the shape in %arg1, which is
  // converted to tensor<4xi32>.
  func.func private @Reshape(%arg0: tensor<?xf32>, %arg1: tensor<?xi32>) -> tensor<?x?x?x?xf32> {
    %0 = stablehlo.convert %arg1 : (tensor<?xi32>) -> tensor<4xi32>
    %1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<?xf32>, tensor<4xi32>) -> tensor<?x?x?x?xf32>
    return %1 : tensor<?x?x?x?xf32>
  }
  // Reshapes a rank-1 tensor to rank 1 using the shape in %arg1, which is
  // converted to tensor<1xi32>.
  func.func private @Reshape_1(%arg0: tensor<?xf32>, %arg1: tensor<?xi32>) -> tensor<?xf32> {
    %0 = stablehlo.convert %arg1 : (tensor<?xi32>) -> tensor<1xi32>
    %1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<?xf32>, tensor<1xi32>) -> tensor<?xf32>
    return %1 : tensor<?xf32>
  }
  // Slices %arg0 over [min(%arg1, %arg2), %arg2) with stride %arg3.
  // The start index is clamped so that it never exceeds the limit.
  func.func private @Slice(%arg0: tensor<?xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>) -> tensor<?xi32> {
    %c = stablehlo.constant dense<1> : tensor<1xi32>
    %0 = stablehlo.reshape %arg1 : (tensor<i32>) -> tensor<1xi32>
    %1 = stablehlo.reshape %arg2 : (tensor<i32>) -> tensor<1xi32>
    // start <= limit ? start : limit
    %2 = stablehlo.compare  LE, %0, %1 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
    %3 = stablehlo.select %2, %0, %1 : tensor<1xi1>, tensor<1xi32>
    %4 = stablehlo.reshape %arg3 : (tensor<i32>) -> tensor<1xi32>
    %5 = stablehlo.real_dynamic_slice %arg0, %3, %1, %4 : (tensor<?xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
    return %5 : tensor<?xi32>
  }
  // Reshapes a rank-1 tensor to rank 4 using the shape in %arg1, which is
  // converted to tensor<4xi32>. Same body as @Reshape.
  func.func private @Reshape_2(%arg0: tensor<?xf32>, %arg1: tensor<?xi32>) -> tensor<?x?x?x?xf32> {
    %0 = stablehlo.convert %arg1 : (tensor<?xi32>) -> tensor<4xi32>
    %1 = stablehlo.dynamic_reshape %arg0, %0 : (tensor<?xf32>, tensor<4xi32>) -> tensor<?x?x?x?xf32>
    return %1 : tensor<?x?x?x?xf32>
  }
  // Rank-4 broadcasting add. For each dimension, the target extent is
  // %arg1's extent where %arg0's extent is 1, and %arg0's extent otherwise.
  // Both operands are broadcast to that shape and then added.
  func.func private @BinaryElementwise_2(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> {
    // Shape vector of %arg0.
    %0 = stablehlo.get_dimension_size %arg0, dim = 0 : (tensor<?x?x?x?xf32>) -> tensor<i32>
    %c = stablehlo.constant dense<1> : tensor<1xi32>
    %1 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32>
    %2 = stablehlo.get_dimension_size %arg0, dim = 1 : (tensor<?x?x?x?xf32>) -> tensor<i32>
    %3 = stablehlo.reshape %2 : (tensor<i32>) -> tensor<1xi32>
    %4 = stablehlo.get_dimension_size %arg0, dim = 2 : (tensor<?x?x?x?xf32>) -> tensor<i32>
    %5 = stablehlo.reshape %4 : (tensor<i32>) -> tensor<1xi32>
    %6 = stablehlo.get_dimension_size %arg0, dim = 3 : (tensor<?x?x?x?xf32>) -> tensor<i32>
    %7 = stablehlo.reshape %6 : (tensor<i32>) -> tensor<1xi32>
    %8 = stablehlo.concatenate %1, %3, %5, %7, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
    %c_0 = stablehlo.constant dense<1> : tensor<4xi32>
    %9 = stablehlo.compare  EQ, %8, %c_0 : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
    // Shape vector of %arg1.
    %10 = stablehlo.get_dimension_size %arg1, dim = 0 : (tensor<?x?x?x?xf32>) -> tensor<i32>
    %11 = stablehlo.reshape %10 : (tensor<i32>) -> tensor<1xi32>
    %12 = stablehlo.get_dimension_size %arg1, dim = 1 : (tensor<?x?x?x?xf32>) -> tensor<i32>
    %13 = stablehlo.reshape %12 : (tensor<i32>) -> tensor<1xi32>
    %14 = stablehlo.get_dimension_size %arg1, dim = 2 : (tensor<?x?x?x?xf32>) -> tensor<i32>
    %15 = stablehlo.reshape %14 : (tensor<i32>) -> tensor<1xi32>
    %16 = stablehlo.get_dimension_size %arg1, dim = 3 : (tensor<?x?x?x?xf32>) -> tensor<i32>
    %17 = stablehlo.reshape %16 : (tensor<i32>) -> tensor<1xi32>
    %18 = stablehlo.concatenate %11, %13, %15, %17, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32>
    // Per-dimension choice of the broadcast target extent.
    %19 = stablehlo.select %9, %18, %8 : tensor<4xi1>, tensor<4xi32>
    %20 = stablehlo.dynamic_broadcast_in_dim %arg0, %19, dims = [0, 1, 2, 3] : (tensor<?x?x?x?xf32>, tensor<4xi32>) -> tensor<?x?x?x?xf32>
    %21 = stablehlo.dynamic_broadcast_in_dim %arg1, %19, dims = [0, 1, 2, 3] : (tensor<?x?x?x?xf32>, tensor<4xi32>) -> tensor<?x?x?x?xf32>
    %22 = stablehlo.add %20, %21 : tensor<?x?x?x?xf32>
    return %22 : tensor<?x?x?x?xf32>
  }
  // Scalar multiply: returns %arg0 * %arg1. The broadcasts target the empty
  // shape [], so they leave both scalars unchanged.
  func.func private @BinaryElementwise_3(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
    %c = stablehlo.constant dense<> : tensor<0xi32>
    %c_0 = stablehlo.constant dense<> : tensor<0xi32>
    %0 = stablehlo.compare  EQ, %c, %c_0 : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi1>
    %1 = stablehlo.select %0, %c, %c : tensor<0xi1>, tensor<0xi32>
    %2 = stablehlo.dynamic_broadcast_in_dim %arg0, %1, dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
    %3 = stablehlo.dynamic_broadcast_in_dim %arg1, %1, dims = [] : (tensor<i32>, tensor<0xi32>) -> tensor<i32>
    %4 = stablehlo.multiply %2, %3 : tensor<i32>
    return %4 : tensor<i32>
  }
}