iree-org / iree

A retargetable MLIR-based machine learning compiler and runtime toolkit.
http://iree.dev/
Apache License 2.0
2.58k stars 577 forks source link

[numerics]: element at index 0 (0.332534) does not match the expected (0.308342); for LSTM ops #18441

Open pdhirajkumarprasad opened 3 weeks ago

pdhirajkumarprasad commented 3 weeks ago

What happened?

For the following IR:

// Reproducer for the numeric mismatch: a single onnx.GRU (hidden_size = 200,
// linear_before_reset = 1) whose weight and bias operands are re-packed from
// constant tensors with onnx.Slice / onnx.Concat before the GRU call.
// NOTE(review): the dense_resource blobs referenced below (_encoder.weight,
// _rnn.*) are not defined in this snippet; presumably they are in the attached
// model.torch_onnx.mlir file. Confirm before running this snippet on its own.
module {
  // %arg0 is never referenced in the body. %arg1 only contributes its first
  // slice along axis 0 (the initial hidden state). %arg3 is the sequence fed to
  // the GRU. Results are the two outputs of the onnx.GRU op (%69#0, %69#1).
  func.func @"torch-jit-export"(%arg0: !torch.vtensor<[35,1],si64>, %arg1: !torch.vtensor<[2,1,200],f32>, %arg3:  !torch.vtensor<[35,1,200],f32>) -> (!torch.vtensor<[35,1,1,200],f32>, !torch.vtensor<[1,1,200],f32>) attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "1.3"} {
    // NOTE(review): %0 (_encoder.weight) is never used in this function.
    %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_encoder.weight> : tensor<33278x200xf32>} : () -> !torch.vtensor<[33278,200],f32> 
    %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_rnn.bias_hh_l0> : tensor<600xf32>} : () -> !torch.vtensor<[600],f32> 
    %2 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_rnn.bias_ih_l0> : tensor<600xf32>} : () -> !torch.vtensor<[600],f32> 
    %3 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_rnn.weight_hh_l0> : tensor<600x200xf32>} : () -> !torch.vtensor<[600,200],f32> 
    %4 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_rnn.weight_ih_l0> : tensor<600x200xf32>} : () -> !torch.vtensor<[600,200],f32> 
    // Passed to the GRU as the (absent) sequence_lens operand.
    %none = torch.constant.none
    // Re-pack the input weights %4 (600x200) along axis 0 as rows [200,400),
    // then [0,200), then [400,600). This swaps the first two 200-row blocks,
    // presumably converting PyTorch's (r, z, n) gate order into ONNX's
    // (z, r, h) order. Confirm.
    %6 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<200> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %7 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<400> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %8 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %9 = torch.operator "onnx.Slice"(%4, %6, %7, %8) : (!torch.vtensor<[600,200],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[200,200],f32> 
    %10 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %11 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<200> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %12 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %13 = torch.operator "onnx.Slice"(%4, %10, %11, %12) : (!torch.vtensor<[600,200],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[200,200],f32> 
    %14 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<400> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %15 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<600> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %16 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %17 = torch.operator "onnx.Slice"(%4, %14, %15, %16) : (!torch.vtensor<[600,200],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[200,200],f32> 
    %18 = torch.operator "onnx.Concat"(%9, %13, %17) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[200,200],f32>, !torch.vtensor<[200,200],f32>, !torch.vtensor<[200,200],f32>) -> !torch.vtensor<[600,200],f32> 
    // Same row re-packing ([200,400), [0,200), [400,600)) for the recurrent
    // weights %3.
    %19 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<200> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %20 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<400> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %21 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %22 = torch.operator "onnx.Slice"(%3, %19, %20, %21) : (!torch.vtensor<[600,200],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[200,200],f32> 
    %23 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %24 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<200> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %25 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %26 = torch.operator "onnx.Slice"(%3, %23, %24, %25) : (!torch.vtensor<[600,200],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[200,200],f32> 
    %27 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<400> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %28 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<600> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %29 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %30 = torch.operator "onnx.Slice"(%3, %27, %28, %29) : (!torch.vtensor<[600,200],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[200,200],f32> 
    %31 = torch.operator "onnx.Concat"(%22, %26, %30) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[200,200],f32>, !torch.vtensor<[200,200],f32>, !torch.vtensor<[200,200],f32>) -> !torch.vtensor<[600,200],f32> 
    // Same block re-packing for the 600-element input bias %2.
    %32 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<200> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %33 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<400> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %34 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %35 = torch.operator "onnx.Slice"(%2, %32, %33, %34) : (!torch.vtensor<[600],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[200],f32> 
    %36 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %37 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<200> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %38 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %39 = torch.operator "onnx.Slice"(%2, %36, %37, %38) : (!torch.vtensor<[600],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[200],f32> 
    %40 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<400> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %41 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<600> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %42 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %43 = torch.operator "onnx.Slice"(%2, %40, %41, %42) : (!torch.vtensor<[600],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[200],f32> 
    %44 = torch.operator "onnx.Concat"(%35, %39, %43) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[200],f32>, !torch.vtensor<[200],f32>, !torch.vtensor<[200],f32>) -> !torch.vtensor<[600],f32> 
    // Same block re-packing for the 600-element recurrent bias %1.
    %45 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<200> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %46 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<400> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %47 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %48 = torch.operator "onnx.Slice"(%1, %45, %46, %47) : (!torch.vtensor<[600],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[200],f32> 
    %49 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %50 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<200> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %51 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %52 = torch.operator "onnx.Slice"(%1, %49, %50, %51) : (!torch.vtensor<[600],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[200],f32> 
    %53 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<400> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %54 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<600> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %55 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %56 = torch.operator "onnx.Slice"(%1, %53, %54, %55) : (!torch.vtensor<[600],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[200],f32> 
    %57 = torch.operator "onnx.Concat"(%48, %52, %56) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[200],f32>, !torch.vtensor<[200],f32>, !torch.vtensor<[200],f32>) -> !torch.vtensor<[600],f32> 
    // Combined bias: re-packed input bias followed by re-packed recurrent bias
    // (1200 elements).
    %58 = torch.operator "onnx.Concat"(%44, %57) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[600],f32>, !torch.vtensor<[600],f32>) -> !torch.vtensor<[1200],f32> 
    // Add a leading size-1 axis (axes = 0) to the re-packed W, R and B,
    // giving the [1,600,200], [1,600,200] and [1,1200] GRU operands.
    %59 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %60 = torch.operator "onnx.Unsqueeze"(%18, %59) : (!torch.vtensor<[600,200],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,600,200],f32> 
    %61 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %62 = torch.operator "onnx.Unsqueeze"(%31, %61) : (!torch.vtensor<[600,200],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,600,200],f32> 
    %63 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %64 = torch.operator "onnx.Unsqueeze"(%58, %63) : (!torch.vtensor<[1200],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,1200],f32> 
    // Initial hidden state: index range [0,1) of %arg1 along axis 0, giving
    // [1,1,200]. The second slice of %arg1 is not used here.
    %65 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %66 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %67 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %68 = torch.operator "onnx.Slice"(%arg1, %65, %66, %67) : (!torch.vtensor<[2,1,200],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,1,200],f32> 
    // The GRU under test. Per the ONNX GRU spec (verify), the operands are
    // X = %arg3, W = %60, R = %62, B = %64, sequence_lens = none and
    // initial_h = %68.
    %69:2 = torch.operator "onnx.GRU"(%arg3, %60, %62, %64, %none, %68) {torch.onnx.hidden_size = 200 : si64, torch.onnx.layout = 0 : si64, torch.onnx.linear_before_reset = 1 : si64} : (!torch.vtensor<[35,1,200],f32>, !torch.vtensor<[1,600,200],f32>, !torch.vtensor<[1,600,200],f32>, !torch.vtensor<[1,1200],f32>, !torch.none, !torch.vtensor<[1,1,200],f32>) -> (!torch.vtensor<[35,1,1,200],f32>, !torch.vtensor<[1,1,200],f32>) 
    return %69#0, %69#1 : !torch.vtensor<[35,1,1,200],f32>, !torch.vtensor<[1,1,200],f32>
  }
}

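We are seeing a numeric mismatch. The recurrent operator works fine as a standalone op, but when it is combined with other operators (as above) we see this numeric mismatch. Note: the issue title refers to LSTM, but the IR above actually uses `onnx.GRU` (three 200-row gate blocks, 600 = 3 × 200).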

Files needed: golden_output.0.bin.txt golden_output.1.bin.txt input.0.bin.txt input.1.bin.txt input.2.bin.txt model.torch_onnx.mlir.txt

Steps to reproduce your issue

Commands to reproduce the issue:

iree-compile model.torch_onnx.mlir --iree-hal-target-backends=llvm-cpu -o out.vmfb --iree-input-demote-i64-to-i32
iree-run-module --module=out.vmfb --device="local-task" --input="35x1xi64=@input.0.bin" --input="2x1x200xf32=@input.1.bin" --input="35x1x200xf32=@input.2.bin" --expected_output="35x1x1x200xf32=@golden_output.0.bin" --expected_output="1x1x200xf32=@golden_output.1.bin"

What component(s) does this issue relate to?

Runtime

Version information

No response

Additional context

No response

MaheshRavishankar commented 3 weeks ago

@pashu123 can you help triage?

pashu123 commented 3 weeks ago

@pdhirajkumarprasad Did you run the same program on torch-mlir's reference backend (ref-backend)?