NVIDIA / TensorRT-Incubator

Experimental projects related to TensorRT
81 stars 12 forks source link

Segfault in mlir-tensorrt #208

Closed: parthchadha closed this issue 1 week ago

parthchadha commented 2 months ago
module @outs_t27_4 {
  // Takes no arguments and returns a 1-D f32 tensor whose length is computed at
  // runtime from the scalar constants below. With these constants the result
  // evaluates to 1.0 + 0.5 * iota(8) = [1.0, 1.5, 2.0, ..., 4.5].
  //
  // NOTE(review): this IR is the crash reproducer attached to the issue. The
  // repeated compare/select/broadcast chains are redundant given constant
  // inputs (presumably emitted verbatim by a frontend), but keep the op
  // sequence unchanged so it still exercises the original failure.
  func.func @main() -> tensor<?xf32> {
    // Scalar constants 0.0, 1.0 and 5.0.
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %cst_1 = stablehlo.constant dense<5.000000e+00> : tensor<f32>
    // Two empty i32 shape tensors (shape [0]); used as the target shape for
    // rank-0 dynamic_broadcast_in_dim ops below.
    %c = stablehlo.constant dense<> : tensor<0xi32>
    %c_2 = stablehlo.constant dense<> : tensor<0xi32>
    // Compare of two empty tensors; both select branches are %c, so %1 is
    // just the empty shape %c.
    %0 = stablehlo.compare  EQ, %c, %c_2 : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi1>
    %1 = stablehlo.select %0, %c, %c : tensor<0xi1>, tensor<0xi32>
    // Rank-0 -> rank-0 broadcasts of 1.0 and 5.0, then %4 = 1.0 - 5.0 = -4.0.
    %2 = stablehlo.dynamic_broadcast_in_dim %cst_0, %1, dims = [] : (tensor<f32>, tensor<0xi32>) -> tensor<f32>
    %3 = stablehlo.dynamic_broadcast_in_dim %cst_1, %1, dims = [] : (tensor<f32>, tensor<0xi32>) -> tensor<f32>
    %4 = stablehlo.subtract %2, %3 : tensor<f32>
    // Scalar 0.5, used as the divisor here and as the multiplier later.
    %cst_3 = stablehlo.constant dense<5.000000e-01> : tensor<f32>
    // Same empty-shape compare/select pattern as above; %6 is the empty shape.
    %5 = stablehlo.compare  EQ, %c, %c_2 : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi1>
    %6 = stablehlo.select %5, %c, %c : tensor<0xi1>, tensor<0xi32>
    // %10 = floor(%4 / 0.5) = floor(-8.0) = -8.0.
    %7 = stablehlo.dynamic_broadcast_in_dim %4, %6, dims = [] : (tensor<f32>, tensor<0xi32>) -> tensor<f32>
    %8 = stablehlo.dynamic_broadcast_in_dim %cst_3, %6, dims = [] : (tensor<f32>, tensor<0xi32>) -> tensor<f32>
    %9 = stablehlo.divide %7, %8 : tensor<f32>
    %10 = stablehlo.floor %9 : tensor<f32>
    // Same empty-shape compare/select pattern; %12 is the empty shape.
    %11 = stablehlo.compare  EQ, %c, %c_2 : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi1>
    %12 = stablehlo.select %11, %c, %c : tensor<0xi1>, tensor<0xi32>
    // %16 = i32(0.0 - %10) = 8. Since 0 - floor((1.0 - 5.0) / 0.5) equals
    // ceil((5.0 - 1.0) / 0.5), this is the element count of the result.
    %13 = stablehlo.dynamic_broadcast_in_dim %cst, %12, dims = [] : (tensor<f32>, tensor<0xi32>) -> tensor<f32>
    %14 = stablehlo.dynamic_broadcast_in_dim %10, %12, dims = [] : (tensor<f32>, tensor<0xi32>) -> tensor<f32>
    %15 = stablehlo.subtract %13, %14 : tensor<f32>
    %16 = stablehlo.convert %15 : (tensor<f32>) -> tensor<i32>
    // %17 = [%16], a 1-element shape tensor; %18 = f32 iota of that length
    // along dim 0, i.e. [0.0, 1.0, ..., 7.0].
    %c_4 = stablehlo.constant dense<1> : tensor<1xi32>
    %17 = stablehlo.dynamic_broadcast_in_dim %16, %c_4, dims = [] : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
    %18 = stablehlo.dynamic_iota %17, dim = 0 : (tensor<1xi32>) -> tensor<?xf32>
    // %20 = 0.5 broadcast to length %16.
    %19 = stablehlo.dynamic_broadcast_in_dim %16, %c_4, dims = [] : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
    %20 = stablehlo.dynamic_broadcast_in_dim %cst_3, %19, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
    // Broadcast-shape computation for the multiply: if %18 has length 1, use
    // %20's length, otherwise %18's length (%26).
    %21 = stablehlo.get_dimension_size %18, dim = 0 : (tensor<?xf32>) -> tensor<i32>
    %22 = stablehlo.reshape %21 : (tensor<i32>) -> tensor<1xi32>
    %23 = stablehlo.compare  EQ, %22, %c_4 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
    %24 = stablehlo.get_dimension_size %20, dim = 0 : (tensor<?xf32>) -> tensor<i32>
    %25 = stablehlo.reshape %24 : (tensor<i32>) -> tensor<1xi32>
    %26 = stablehlo.select %23, %25, %22 : tensor<1xi1>, tensor<1xi32>
    // Broadcast both operands to %26, then %29 = %18 * 0.5.
    %27 = stablehlo.dynamic_broadcast_in_dim %18, %26, dims = [0] : (tensor<?xf32>, tensor<1xi32>) -> tensor<?xf32>
    %28 = stablehlo.dynamic_broadcast_in_dim %20, %26, dims = [0] : (tensor<?xf32>, tensor<1xi32>) -> tensor<?xf32>
    %29 = stablehlo.multiply %27, %28 : tensor<?xf32>
    // %31 = 1.0 broadcast to length %16.
    %30 = stablehlo.dynamic_broadcast_in_dim %16, %c_4, dims = [] : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
    %31 = stablehlo.dynamic_broadcast_in_dim %cst_0, %30, dims = [] : (tensor<f32>, tensor<1xi32>) -> tensor<?xf32>
    // Same broadcast-shape computation for the add: if %29 has length 1, use
    // %31's length, otherwise %29's length (%37).
    %32 = stablehlo.get_dimension_size %29, dim = 0 : (tensor<?xf32>) -> tensor<i32>
    %33 = stablehlo.reshape %32 : (tensor<i32>) -> tensor<1xi32>
    %34 = stablehlo.compare  EQ, %33, %c_4 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi1>
    %35 = stablehlo.get_dimension_size %31, dim = 0 : (tensor<?xf32>) -> tensor<i32>
    %36 = stablehlo.reshape %35 : (tensor<i32>) -> tensor<1xi32>
    %37 = stablehlo.select %34, %36, %33 : tensor<1xi1>, tensor<1xi32>
    // Broadcast both operands to %37, then %40 = %29 + 1.0 = 1.0 + 0.5 * iota.
    %38 = stablehlo.dynamic_broadcast_in_dim %29, %37, dims = [0] : (tensor<?xf32>, tensor<1xi32>) -> tensor<?xf32>
    %39 = stablehlo.dynamic_broadcast_in_dim %31, %37, dims = [0] : (tensor<?xf32>, tensor<1xi32>) -> tensor<?xf32>
    %40 = stablehlo.add %38, %39 : tensor<?xf32>
    return %40 : tensor<?xf32>
  }
}
christopherbate commented 2 months ago

We have a fix internally; it will be pushed out in the next sync (tomorrow morning).

christopherbate commented 1 week ago

There is no longer an error on the main branch.