iree-org/iree

Simplified Concat Operation Produces Incorrect Numerics #18589

Closed zjgarvey closed 1 month ago

zjgarvey commented 1 month ago

### What happened?

This gist contains two very similar linalg IR reproducers. When compiled and run on CPU, they unexpectedly produce mismatching results.

The only difference between the two IRs is the final sequence of ops:

    %169 = linalg.generic {indexing_maps = [#map10, #map19, #map4], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_13, %168 : tensor<128xf32>, tensor<1x1x128xf32>) outs(%166 : tensor<1x1x128xf32>) {
    ^bb0(%in: f32, %in_87: f32, %out: f32):
      %171 = arith.addf %in, %in_87 : f32
      linalg.yield %171 : f32
    } -> tensor<1x1x128xf32>
    %cast_83 = tensor.cast %169 : tensor<1x1x128xf32> to tensor<?x1x128xf32>
    %collapsed_84 = tensor.collapse_shape %161 [[0], [1, 2], [3, 4]] : tensor<1x64x2x14x14xf32> into tensor<1x128x196xf32>
    %170 = tensor.empty() : tensor<1x196x128xf32>
    %transposed_85 = linalg.transpose ins(%collapsed_84 : tensor<1x128x196xf32>) outs(%170 : tensor<1x196x128xf32>) permutation = [0, 2, 1] 
    %concat_86 = tensor.concat dim(1) %cast_83, %transposed_85 : (tensor<?x1x128xf32>, tensor<1x196x128xf32>) -> tensor<1x197x128xf32>
    return %concat_86 : tensor<1x197x128xf32>
  }
}

Notice that the cast operation %cast_83 somewhat pointlessly makes the first dim dynamic before the result is passed to %concat_86. However, this IR produces correct numerics.

Alternatively, the simplified version below will give wildly incorrect numerics:

    %169 = linalg.generic {indexing_maps = [#map10, #map19, #map4], iterator_types = ["parallel", "parallel", "parallel"]} ins(%cst_13, %168 : tensor<128xf32>, tensor<1x1x128xf32>) outs(%166 : tensor<1x1x128xf32>) {
    ^bb0(%in: f32, %in_87: f32, %out: f32):
      %171 = arith.addf %in, %in_87 : f32
      linalg.yield %171 : f32
    } -> tensor<1x1x128xf32>
    %collapsed_84 = tensor.collapse_shape %161 [[0], [1, 2], [3, 4]] : tensor<1x64x2x14x14xf32> into tensor<1x128x196xf32>
    %170 = tensor.empty() : tensor<1x196x128xf32>
    %transposed_85 = linalg.transpose ins(%collapsed_84 : tensor<1x128x196xf32>) outs(%170 : tensor<1x196x128xf32>) permutation = [0, 2, 1] 
    %concat_86 = tensor.concat dim(1) %169, %transposed_85 : (tensor<1x1x128xf32>, tensor<1x196x128xf32>) -> tensor<1x197x128xf32>
    return %concat_86 : tensor<1x197x128xf32>
  }
}

### Steps to reproduce your issue

1. Download the two reproducers from this gist.
2. Run `iree-compile --iree-hal-target-backends=llvm-cpu correct_numerics.mlir -o correct_numerics.vmfb`.
3. Run `iree-compile --iree-hal-target-backends=llvm-cpu wrong_numerics.mlir -o wrong_numerics.vmfb`.
4. Generate inputs with something like:

   ```python
   import numpy
   import struct

   rng = numpy.random.default_rng(19)
   a = rng.random((1, 3, 224, 224)).astype(numpy.float32)
   with open("input.0.bin", "wb") as f:
       mylist = a.flatten().tolist()
       bytearr = struct.pack("%sf" % len(mylist), *mylist)
       f.write(bytearr)
   ```

5. Run `iree-run-module --module="correct_numerics.vmfb" --input="1x3x224x224xf32=@input.0.bin"`.
6. Run `iree-run-module --module="wrong_numerics.vmfb" --input="1x3x224x224xf32=@input.0.bin"`.
7. Compare the first few values of the two outputs (see the sketch below).
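
For step 7, a minimal comparison sketch, assuming steps 5 and 6 are re-run with `iree-run-module`'s `--output=@<file>.npy` flag so that each result lands in a numpy file (the file names here are mine):

```python
import numpy

# Assumes steps 5 and 6 were re-run with --output=@correct.npy and
# --output=@wrong.npy appended, so each result is saved as a .npy file.
correct = numpy.load("correct.npy")
wrong = numpy.load("wrong.npy")

# Both outputs have shape (1, 197, 128); inspect the first few values.
print("correct:", correct.flatten()[:8])
print("wrong:  ", wrong.flatten()[:8])
print("match:  ", numpy.allclose(correct, wrong))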

### What component(s) does this issue relate to?

Compiler, Runtime

### Version information

local build at commit ae6e5d323ecc63e421c79768087f63dc42490cd2

### Additional context

Reproducer generated from the model `pit_ti_224` in SharkTestSuite. 

Comparing with the output of the add node `%169`, it is clear that there is some issue with the compilation of the simplified IR in `wrong_numerics.mlir`. 

I tried to get a smaller reproducer for this issue but was unsuccessful. For example:

```mlir
module {
  ml_program.global private mutable @global_seed(dense<0> : tensor<i64>) : tensor<i64>
  func.func @torch_jit(%arg0: tensor<1x3x2xf32>, %arg1: tensor<1x1x2xf32>) -> tensor<1x4x2xf32> {
    %cast_83 = tensor.cast %arg1 : tensor<1x1x2xf32> to tensor<?x1x2xf32>
    %concat_86 = tensor.concat dim(1) %cast_83, %arg0 : (tensor<?x1x2xf32>, tensor<1x3x2xf32>) -> tensor<1x4x2xf32>
    return %concat_86 : tensor<1x4x2xf32>
  }
}
```

Comparing this IR against the same IR with the cast operation removed does not reproduce the issue.

I have not tried reproducing this issue on other backends/devices.

nirvedhmeshram commented 1 month ago

Adding this to compiler bugs until proven otherwise.

nirvedhmeshram commented 1 month ago

The fact that a smaller repro didn't cause an issue suggests this could be a stream allocation issue; I will start by looking there. Also, @zjgarvey, how do you know which of the two results is incorrect? With the input generator (thanks for providing that), both outputs are just random numbers.

zjgarvey commented 1 month ago

I was originally using the test suite and comparing with onnxruntime CPU result.

You can verify that the simplified IR is incorrect from IREE alone:

  1. Modify either IR to instead return the result of node %169 (the add node). Call this "add_result.mlir".
  2. Compile "add_result.mlir" to "add_result.vmfb".
  3. Run "add_result.vmfb" on the same input "input.0.bin" and inspect the outputs.

The first few values of this add result should match the first few values of the correct concat operation. By inspection, the unsimplified IR result matches the add result, whereas the simplified IR generates an output with completely different values.
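
A minimal sketch of that check, under the same assumption that each run's output was saved as a numpy file via `--output=@<file>.npy` (file names hypothetical): since %169 is the first operand of the dim-1 concat, row 0 of the concat output should equal the add result.

```python
import numpy

# Hypothetical file names; assumes each run was saved with --output=@<name>.npy.
add = numpy.load("add_result.npy")    # (1, 1, 128), result of node %169
correct = numpy.load("correct.npy")   # (1, 197, 128)
wrong = numpy.load("wrong.npy")       # (1, 197, 128)

# %169 (via %cast_83) is the first operand of the dim-1 concat, so row 0 of
# the concat output should equal the add result.
print("correct matches add:", numpy.allclose(correct[0, 0, :], add[0, 0, :]))
print("wrong matches add:  ", numpy.allclose(wrong[0, 0, :], add[0, 0, :]))
```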

nirvedhmeshram commented 1 month ago

This issue comes down to the fact that in the simplified IR we are able to fuse the add into the linalg.batch_mmt4d op, which we couldn't fuse in the unsimplified case, and the dispatch below gives wrong numerics:

      #map11 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
      func.func @torch_jit_dispatch_56_batch_mmt4d_1x1x32x64x1x4x1_f32(%4: tensor<1x1x64x1x1xf32>, %5: tensor<1x32x64x4x1xf32>, %6: tensor<1x1x32x1x4xf32>) -> tensor<1x1x32x1x4xf32>{
        %c0 = arith.constant 0 : index
        %cst = arith.constant 0.000000e+00 : f32
        %7 = tensor.empty() : tensor<1x1x32x1x4xf32>
        %8 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1x1x32x1x4xf32>) -> tensor<1x1x32x1x4xf32>
        %9 = linalg.batch_mmt4d ins(%4, %5 : tensor<1x1x64x1x1xf32>, tensor<1x32x64x4x1xf32>) outs(%8 : tensor<1x1x32x1x4xf32>) -> tensor<1x1x32x1x4xf32>
        %10 = linalg.generic {indexing_maps = [#map11, #map11, #map11], 
          iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]} 
          ins(%6, %9 : tensor<1x1x32x1x4xf32>, tensor<1x1x32x1x4xf32>) 
          outs(%7 : tensor<1x1x32x1x4xf32>) {
        ^bb0(%in: f32, %in_0: f32, %out: f32):
          %11 = arith.addf %in, %in_0 : f32
          linalg.yield %11 : f32
        } -> tensor<1x1x32x1x4xf32>
        return %10 : tensor<1x1x32x1x4xf32>
      }

Without the elementwise op it gives the correct numerics. Investigating now why that would be.
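
For reference, a minimal numpy model of what this dispatch should compute, assuming the usual linalg.batch_mmt4d indexing (acc[b, m, n, m0, n0] += lhs[b, m, k, m0, k0] * rhs[b, n, k, n0, k0], i.e. the RHS is transposed in both outer and inner dims) followed by the fused elementwise add; the function and variable names here are mine, not from the dispatch:

```python
import numpy

def batch_mmt4d_plus_add(lhs, rhs, bias):
    # lhs:  [B, M, K, m0, k0]  -> 1x1x64x1x1  (%4)
    # rhs:  [B, N, K, n0, k0]  -> 1x32x64x4x1 (%5)
    # bias: [B, M, N, m0, n0]  -> 1x1x32x1x4  (%6)
    # batch_mmt4d: acc[b,m,n,m0,n0] += sum over k,k0 of
    #   lhs[b,m,k,m0,k0] * rhs[b,n,k,n0,k0]
    acc = numpy.einsum("bmkxz,bnkyz->bmnxy", lhs, rhs)
    # fused elementwise add from the dispatch
    return bias + acc

rng = numpy.random.default_rng(0)
lhs = rng.random((1, 1, 64, 1, 1), dtype=numpy.float32)
rhs = rng.random((1, 32, 64, 4, 1), dtype=numpy.float32)
bias = rng.random((1, 1, 32, 1, 4), dtype=numpy.float32)
print(batch_mmt4d_plus_add(lhs, rhs, bias).shape)  # (1, 1, 32, 1, 4)
```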