nod-ai / SHARK-ModelDev

Unified compiler/runtime for interfacing with PyTorch Dynamo.
Apache License 2.0

GRU Quality Issues #856

Closed. zjgarvey closed this issue 1 month ago.

zjgarvey commented 1 month ago
1. test_gru_batchwise fails in torch-to-linalg:
```mlir
module {
  func.func @test_gru_batchwise(%arg0: !torch.vtensor<[3,1,2],f32>, %arg1: !torch.vtensor<[1,18,2],f32>, %arg2: !torch.vtensor<[1,18,6],f32>) -> (!torch.vtensor<[3,1,1,6],f32>, !torch.vtensor<[3,1,6],f32>) attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
    %none = torch.constant.none
    %0:2 = torch.operator "onnx.GRU"(%arg0, %arg1, %arg2) {torch.onnx.hidden_size = 6 : si64, torch.onnx.layout = 1 : si64} : (!torch.vtensor<[3,1,2],f32>, !torch.vtensor<[1,18,2],f32>, !torch.vtensor<[1,18,6],f32>) -> (!torch.vtensor<[3,1,1,6],f32>, !torch.vtensor<[3,1,6],f32>)
    return %0#0, %0#1 : !torch.vtensor<[3,1,1,6],f32>, !torch.vtensor<[3,1,6],f32>
  }
}
```

To reproduce: torch-mlir-opt --convert-torch-onnx-to-torch --torch-lower-to-backend-contract --convert-torch-to-linalg

It looks like torch-onnx-to-torch is generating invalid IR (the resulting shapes don't make sense).
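
As a sanity check on the input module, the sketch below computes the output shapes the ONNX GRU specification prescribes for each `layout` value (the attribute exists from opset 14). The helper name `expected_gru_shapes` is hypothetical and not part of torch-mlir. If the spec is read correctly here, the declared result types of test_gru_batchwise are consistent with layout = 1, which would point at the lowering rather than the test input.

```python
def expected_gru_shapes(x_shape, w_shape, hidden_size, layout=0):
    """Return the (Y, Y_h) shapes given by the ONNX GRU spec (hypothetical helper)."""
    if layout == 0:
        seq_len, batch, input_size = x_shape      # X: [seq_length, batch_size, input_size]
    elif layout == 1:
        batch, seq_len, input_size = x_shape      # X: [batch_size, seq_length, input_size]
    else:
        raise ValueError(f"unsupported layout {layout}")
    # W packs the update, reset and hidden gates: [num_directions, 3*hidden_size, input_size]
    num_directions, gate_rows, w_input = w_shape
    if gate_rows != 3 * hidden_size or w_input != input_size:
        raise ValueError(f"W shape {w_shape} does not match hidden_size={hidden_size}")
    if layout == 0:
        y = (seq_len, num_directions, batch, hidden_size)
        y_h = (num_directions, batch, hidden_size)
    else:
        y = (batch, seq_len, num_directions, hidden_size)
        y_h = (batch, num_directions, hidden_size)
    return y, y_h


# Operand shapes from test_gru_batchwise (layout = 1, hidden_size = 6)
y, y_h = expected_gru_shapes((3, 1, 2), (1, 18, 2), hidden_size=6, layout=1)
print(y, y_h)  # (3, 1, 1, 6) (3, 1, 6)
```

These match the declared `!torch.vtensor<[3,1,1,6],f32>` and `!torch.vtensor<[3,1,6],f32>` results.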

2. test_gru_defaults fails in torch-onnx-to-torch:
```mlir
module {
  func.func @test_gru_defaults(%arg0: !torch.vtensor<[1,3,2],f32>, %arg1: !torch.vtensor<[1,15,2],f32>, %arg2: !torch.vtensor<[1,15,5],f32>) -> !torch.vtensor<[1,3,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
    %none = torch.constant.none
    %0:2 = torch.operator "onnx.GRU"(%arg0, %arg1, %arg2) {torch.onnx.hidden_size = 5 : si64} : (!torch.vtensor<[1,3,2],f32>, !torch.vtensor<[1,15,2],f32>, !torch.vtensor<[1,15,5],f32>) -> (!torch.none, !torch.vtensor<[1,3,5],f32>)
    return %0#1 : !torch.vtensor<[1,3,5],f32>
  }
}
```

This hits a match failure in torch-onnx-to-torch related to the number of output tensors (the op has two results, but the first one, Y, is typed !torch.none).

To reproduce: torch-mlir-opt --convert-torch-onnx-to-torch
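
In test_gru_defaults only Y_h is requested, and ONNX allows the Y output to be omitted. The sketch below is a minimal pure-Python reference of the forward GRU recurrence, assuming the ONNX defaults: sigmoid/tanh activations, linear_before_reset = 0, no bias, zero initial_h, a single forward direction, and layout = 0. It shows that Y_h can be produced without materializing Y. The function `reference_gru` and its `want_y` flag are hypothetical illustrations, not torch-mlir code.

```python
import math


def _sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def _matvec(mat, row_offset, rows, vec):
    """Multiply rows [row_offset, row_offset + rows) of mat by vec."""
    return [sum(mat[row_offset + i][k] * vec[k] for k in range(len(vec))) for i in range(rows)]


def reference_gru(X, W, R, hidden_size, want_y=True):
    """Forward ONNX GRU with default attributes (hypothetical sketch).

    X: [seq_length][batch_size][input_size]   (layout = 0)
    W: [1][3*hidden_size][input_size], gates packed as z, r, h
    R: [1][3*hidden_size][hidden_size]
    Returns (Y, Y_h); Y is None when want_y is False (output omitted).
    """
    w, r_mat, hs = W[0], R[0], hidden_size
    H = [[0.0] * hs for _ in range(len(X[0]))]  # initial_h defaults to zeros
    Y = [] if want_y else None
    for x_t in X:
        new_H = []
        for x, h_prev in zip(x_t, H):
            z = [_sigmoid(a + b) for a, b in zip(_matvec(w, 0, hs, x), _matvec(r_mat, 0, hs, h_prev))]
            r = [_sigmoid(a + b) for a, b in zip(_matvec(w, hs, hs, x), _matvec(r_mat, hs, hs, h_prev))]
            # linear_before_reset = 0: apply the reset gate before multiplying by Rh
            rh = [ri * hi for ri, hi in zip(r, h_prev)]
            h_tilde = [math.tanh(a + b)
                       for a, b in zip(_matvec(w, 2 * hs, hs, x), _matvec(r_mat, 2 * hs, hs, rh))]
            new_H.append([(1 - zi) * hti + zi * hpi for zi, hti, hpi in zip(z, h_tilde, h_prev)])
        H = new_H
        if want_y:
            Y.append([H])  # one time step: [num_directions=1][batch][hidden]
    return Y, [H]  # Y_h: [num_directions=1][batch][hidden]


# Operand shapes from test_gru_defaults; the weight values are illustrative only
X = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]]  # [1][3][2]
W = [[[0.1, 0.1] for _ in range(15)]]       # [1][15][2]
R = [[[0.1] * 5 for _ in range(15)]]        # [1][15][5]
Y, Y_h = reference_gru(X, W, R, hidden_size=5, want_y=False)
print(Y, len(Y_h), len(Y_h[0]), len(Y_h[0][0]))  # None 1 3 5
```

The Y_h shape [1][3][5] matches the single tensor result `!torch.vtensor<[1,3,5],f32>` in the test.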