nod-ai / SHARK-ModelDev

Unified compiler/runtime for interfacing with PyTorch Dynamo.
Apache License 2.0
94 stars 48 forks source link

crash: mlir::PatternApplicator::matchAndRewrite #867

Open pdhirajkumarprasad opened 2 weeks ago

pdhirajkumarprasad commented 2 weeks ago

For the given IR below, the compiler crashes in `mlir::PatternApplicator::matchAndRewrite`:

module {
  func.func @torch_jit(%arg0: !torch.vtensor<[1,3,224,224],f32>, %arg2: !torch.vtensor<[?,64,?,?],f32> , %arg3: !torch.vtensor<[?,?,?,?,?,?],f32> , %arg4: !torch.vtensor<[1],si64>         ) -> !torch.vtensor<[?,?,?,?],f32>    attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "1.12.1"} {
    %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<64x64x3x3xf32>} : () -> !torch.vtensor<[64,64,3,3],f32> 
    %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<64xf32>} : () -> !torch.vtensor<[64],f32> 
    %2 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<64xf32>} : () -> !torch.vtensor<[64],f32> 
    %3 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<64xf32>} : () -> !torch.vtensor<[64],f32> 
    %4 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<64xf32>} : () -> !torch.vtensor<[64],f32> 
    %5 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<64xf32>} : () -> !torch.vtensor<[64],f32> 
    %6 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<16x256x1x1xf32>} : () -> !torch.vtensor<[16,256,1,1],f32> 
    %7 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<16xf32>} : () -> !torch.vtensor<[16],f32> 
    %8 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<256x16x1x1xf32>} : () -> !torch.vtensor<[256,16,1,1],f32> 
    %9 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<256xf32>} : () -> !torch.vtensor<[256],f32> 
    %10 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<64x256x1x1xf32>} : () -> !torch.vtensor<[64,256,1,1],f32> 
    %11 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<64xf32>} : () -> !torch.vtensor<[64],f32> 
    %12 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<64xf32>} : () -> !torch.vtensor<[64],f32> 
    %13 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<64xf32>} : () -> !torch.vtensor<[64],f32> 
    %14 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<192xf32>} : () -> !torch.vtensor<[192],f32> 
    %322 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<64x3x3x3xf32>} : () -> !torch.vtensor<[64,3,3,3],f32> 
    %323 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<64xf32>} : () -> !torch.vtensor<[64],f32> 
    %324 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<256x64x1x1xf32>} : () -> !torch.vtensor<[256,64,1,1],f32> 
    %325 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<256xf32>} : () -> !torch.vtensor<[256],f32> 
    %326 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<256x1x3x3xf32>} : () -> !torch.vtensor<[256,1,3,3],f32> 
    %327 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<256xf32>} : () -> !torch.vtensor<[256],f32> 
    %368 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %369 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<4xsi64>} : () -> !torch.vtensor<[4],si64> 
    %370 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<si64>} : () -> !torch.vtensor<[],si64> 
    %371 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %372 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %373 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<0.0> : tensor<64x192xf32>} : () -> !torch.vtensor<[64,192],f32> 
    %374 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_onnx__Concat_6281> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %375 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_onnx__Concat_6282> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %376 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_onnx__Concat_6283> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %788 = torch.operator "onnx.Identity"(%372) : (!torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> 
    %789 = torch.operator "onnx.Identity"(%371) : (!torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> 
    %790 = torch.operator "onnx.Identity"(%371) : (!torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> 
    %791 = torch.operator "onnx.Identity"(%371) : (!torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> 
    %792 = torch.operator "onnx.ConstantOfShape"(%368) {torch.onnx.value = dense_resource<_> : tensor<1xsi64>} : (!torch.vtensor<[1],si64>) -> !torch.vtensor<[4],si64> 
    %793 = torch.operator "onnx.Concat"(%369, %792) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[8],si64> 
    %794 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__1> : tensor<2xsi64>} : () -> !torch.vtensor<[2],si64> 
    %795 = torch.operator "onnx.Reshape"(%793, %794) : (!torch.vtensor<[8],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,2],si64> 
    %796 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__2> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %797 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__3> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %798 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__4> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %799 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__5> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %800 = torch.operator "onnx.Slice"(%795, %797, %798, %796, %799) : (!torch.vtensor<[4,2],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,2],si64> 
    %801 = torch.operator "onnx.Transpose"(%800) {torch.onnx.perm = [1 : si64, 0 : si64]} : (!torch.vtensor<[4,2],si64>) -> !torch.vtensor<[2,4],si64> 
    %802 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__6> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %803 = torch.operator "onnx.Reshape"(%801, %802) : (!torch.vtensor<[2,4],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[8],si64> 
    %804 = torch.operator "onnx.Cast"(%803) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[8],si64>) -> !torch.vtensor<[8],si64> 
    %805 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__7> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %806 = torch.operator "onnx.Pad"(%arg0, %804, %805) {torch.onnx.mode = "constant"} : (!torch.vtensor<[1,3,224,224],f32>, !torch.vtensor<[8],si64>, !torch.vtensor<[],f32>) -> !torch.vtensor<[?,?,?,?],f32> 
    %807 = torch.operator "onnx.Conv"(%806, %322, %323) {torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 1 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[64,3,3,3],f32>, !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,64,?,?],f32> 
    %808 = torch.operator "onnx.Mul"(%807, %807) : (!torch.vtensor<[?,64,?,?],f32>, !torch.vtensor<[?,64,?,?],f32>) -> !torch.vtensor<[?,64,?,?],f32> 
    %809 = torch.operator "onnx.Mul"(%807, %808) : (!torch.vtensor<[?,64,?,?],f32>, !torch.vtensor<[?,64,?,?],f32>) -> !torch.vtensor<[?,64,?,?],f32> 
    %810 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__8> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %811 = torch.operator "onnx.Mul"(%810, %809) : (!torch.vtensor<[],f32>, !torch.vtensor<[?,64,?,?],f32>) -> !torch.vtensor<[?,64,?,?],f32> 
    %812 = torch.operator "onnx.Add"(%807, %811) : (!torch.vtensor<[?,64,?,?],f32>, !torch.vtensor<[?,64,?,?],f32>) -> !torch.vtensor<[?,64,?,?],f32> 
    %813 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__9> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %814 = torch.operator "onnx.Mul"(%813, %812) : (!torch.vtensor<[],f32>, !torch.vtensor<[?,64,?,?],f32>) -> !torch.vtensor<[?,64,?,?],f32> 
    %815 = torch.operator "onnx.Tanh"(%814) : (!torch.vtensor<[?,64,?,?],f32>) -> !torch.vtensor<[?,64,?,?],f32> 
    %816 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__10> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %817 = torch.operator "onnx.Add"(%816, %815) : (!torch.vtensor<[],f32>, !torch.vtensor<[?,64,?,?],f32>) -> !torch.vtensor<[?,64,?,?],f32> 
    %818 = torch.operator "onnx.Mul"(%807, %817) : (!torch.vtensor<[?,64,?,?],f32>, !torch.vtensor<[?,64,?,?],f32>) -> !torch.vtensor<[?,64,?,?],f32> 
    %819 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__11> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %820 = torch.operator "onnx.Mul"(%819, %818) : (!torch.vtensor<[],f32>, !torch.vtensor<[?,64,?,?],f32>) -> !torch.vtensor<[?,64,?,?],f32> 
    %821 = torch.operator "onnx.Conv"(%820, %0, %1) {torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 1 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [1 : si64, 1 : si64, 1 : si64, 1 : si64], torch.onnx.strides = [1 : si64, 1 : si64]} : (!torch.vtensor<[?,64,?,?],f32>, !torch.vtensor<[64,64,3,3],f32>, !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,64,?,?],f32> 
    %822 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__12> : tensor<8xsi64>} : () -> !torch.vtensor<[8],si64> 
    %823 = torch.operator "onnx.Pad"(%821, %822) {torch.onnx.mode = "constant"} : (!torch.vtensor<[?,64,?,?],f32>, !torch.vtensor<[8],si64>) -> !torch.vtensor<[?,64,?,?],f32> 
    %824 = torch.operator "onnx.AveragePool"(%823) {torch.onnx.ceil_mode = 0 : si64, torch.onnx.kernel_shape = [2 : si64, 2 : si64], torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[?,64,?,?],f32>) -> !torch.vtensor<[?,64,?,?],f32> 
    %825 = torch.operator "onnx.BatchNormalization"(%821, %2, %3, %4, %5) {torch.onnx.epsilon = 1.000000e-03 : f32, torch.onnx.momentum = 0.899999976 : f32} : (!torch.vtensor<[?,64,?,?],f32>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>, !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,64,?,?],f32> 
    %826 = torch.operator "onnx.Conv"(%825, %324, %325) {torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 1 : si64, torch.onnx.kernel_shape = [1 : si64, 1 : si64], torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [1 : si64, 1 : si64]} : (!torch.vtensor<[?,64,?,?],f32>, !torch.vtensor<[256,64,1,1],f32>, !torch.vtensor<[256],f32>) -> !torch.vtensor<[?,256,?,?],f32> 
    %827 = torch.operator "onnx.Mul"(%826, %826) : (!torch.vtensor<[?,256,?,?],f32>, !torch.vtensor<[?,256,?,?],f32>) -> !torch.vtensor<[?,256,?,?],f32> 
    %828 = torch.operator "onnx.Mul"(%826, %827) : (!torch.vtensor<[?,256,?,?],f32>, !torch.vtensor<[?,256,?,?],f32>) -> !torch.vtensor<[?,256,?,?],f32> 
    %829 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__13> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %830 = torch.operator "onnx.Mul"(%829, %828) : (!torch.vtensor<[],f32>, !torch.vtensor<[?,256,?,?],f32>) -> !torch.vtensor<[?,256,?,?],f32> 
    %831 = torch.operator "onnx.Add"(%826, %830) : (!torch.vtensor<[?,256,?,?],f32>, !torch.vtensor<[?,256,?,?],f32>) -> !torch.vtensor<[?,256,?,?],f32> 
    %832 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__14> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %833 = torch.operator "onnx.Mul"(%832, %831) : (!torch.vtensor<[],f32>, !torch.vtensor<[?,256,?,?],f32>) -> !torch.vtensor<[?,256,?,?],f32> 
    %834 = torch.operator "onnx.Tanh"(%833) : (!torch.vtensor<[?,256,?,?],f32>) -> !torch.vtensor<[?,256,?,?],f32> 
    %835 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__15> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %836 = torch.operator "onnx.Add"(%835, %834) : (!torch.vtensor<[],f32>, !torch.vtensor<[?,256,?,?],f32>) -> !torch.vtensor<[?,256,?,?],f32> 
    %837 = torch.operator "onnx.Mul"(%826, %836) : (!torch.vtensor<[?,256,?,?],f32>, !torch.vtensor<[?,256,?,?],f32>) -> !torch.vtensor<[?,256,?,?],f32> 
    %838 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__16> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %839 = torch.operator "onnx.Mul"(%838, %837) : (!torch.vtensor<[],f32>, !torch.vtensor<[?,256,?,?],f32>) -> !torch.vtensor<[?,256,?,?],f32> 
    %840 = torch.operator "onnx.Shape"(%839) : (!torch.vtensor<[?,256,?,?],f32>) -> !torch.vtensor<[4],si64> 
    %841 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__17> : tensor<si64>} : () -> !torch.vtensor<[],si64> 
    %842 = torch.operator "onnx.Gather"(%840, %841) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %843 = torch.operator "onnx.Shape"(%839) : (!torch.vtensor<[?,256,?,?],f32>) -> !torch.vtensor<[4],si64> 
    %844 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__18> : tensor<si64>} : () -> !torch.vtensor<[],si64> 
    %845 = torch.operator "onnx.Gather"(%843, %844) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %846 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__19> : tensor<si64>} : () -> !torch.vtensor<[],si64> 
    %847 = torch.operator "onnx.Sub"(%846, %842) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %848 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__20> : tensor<si64>} : () -> !torch.vtensor<[],si64> 
    %849 = torch.operator "onnx.Sub"(%848, %845) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %850 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__21> : tensor<si64>} : () -> !torch.vtensor<[],si64> 
    %851 = torch.operator "onnx.Div"(%849, %850) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %852 = torch.operator "onnx.Cast"(%851) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %853 = torch.operator "onnx.Cast"(%852) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %854 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__22> : tensor<si64>} : () -> !torch.vtensor<[],si64> 
    %855 = torch.operator "onnx.Div"(%849, %854) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %856 = torch.operator "onnx.Cast"(%855) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %857 = torch.operator "onnx.Cast"(%856) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %858 = torch.operator "onnx.Sub"(%849, %857) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %859 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__23> : tensor<si64>} : () -> !torch.vtensor<[],si64> 
    %860 = torch.operator "onnx.Div"(%847, %859) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %861 = torch.operator "onnx.Cast"(%860) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %862 = torch.operator "onnx.Cast"(%861) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %863 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__24> : tensor<si64>} : () -> !torch.vtensor<[],si64> 
    %864 = torch.operator "onnx.Div"(%847, %863) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %865 = torch.operator "onnx.Cast"(%864) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %866 = torch.operator "onnx.Cast"(%865) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %867 = torch.operator "onnx.Sub"(%847, %866) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %868 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__25> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %869 = torch.operator "onnx.Unsqueeze"(%853, %868) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> 
    %870 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__26> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %871 = torch.operator "onnx.Unsqueeze"(%858, %870) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> 
    %872 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__27> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %873 = torch.operator "onnx.Unsqueeze"(%862, %872) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> 
    %874 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__28> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %875 = torch.operator "onnx.Unsqueeze"(%867, %874) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> 
    %876 = torch.operator "onnx.Concat"(%869, %871, %873, %875) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],si64> 
    %877 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__29> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %878 = torch.operator "onnx.Unsqueeze"(%853, %877) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> 
    %879 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__30> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %880 = torch.operator "onnx.Unsqueeze"(%858, %879) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> 
    %881 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__31> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %882 = torch.operator "onnx.Unsqueeze"(%862, %881) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> 
    %883 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__32> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %884 = torch.operator "onnx.Unsqueeze"(%867, %883) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> 
    %885 = torch.operator "onnx.Concat"(%878, %880, %882, %884) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],si64> 
    %886 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__33> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %888 = torch.operator "onnx.Gather"(%arg4, %886) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> 
    %889 = torch.operator "onnx.Sub"(%370, %888) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> 
    %890 = torch.operator "onnx.Cast"(%885) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[4],si64>) -> !torch.vtensor<[4],si64> 
    %891 = torch.operator "onnx.ConstantOfShape"(%889) {torch.onnx.value = dense_resource<__34> : tensor<1xsi64>} : (!torch.vtensor<[1],si64>) -> !torch.vtensor<[4],si64> 
    %892 = torch.operator "onnx.Concat"(%890, %891) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[8],si64> 
    %893 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__35> : tensor<2xsi64>} : () -> !torch.vtensor<[2],si64> 
    %894 = torch.operator "onnx.Reshape"(%892, %893) : (!torch.vtensor<[8],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,2],si64> 
    %895 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__36> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %896 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__37> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %897 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__38> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %898 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__39> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %899 = torch.operator "onnx.Slice"(%894, %896, %897, %895, %898) : (!torch.vtensor<[4,2],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,2],si64> 
    %900 = torch.operator "onnx.Transpose"(%899) {torch.onnx.perm = [1 : si64, 0 : si64]} : (!torch.vtensor<[4,2],si64>) -> !torch.vtensor<[2,4],si64> 
    %901 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__40> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %902 = torch.operator "onnx.Reshape"(%900, %901) : (!torch.vtensor<[2,4],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[8],si64> 
    %903 = torch.operator "onnx.Cast"(%902) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[8],si64>) -> !torch.vtensor<[8],si64> 
    %904 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__41> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %905 = torch.operator "onnx.Pad"(%839, %903, %904) {torch.onnx.mode = "constant"} : (!torch.vtensor<[?,256,?,?],f32>, !torch.vtensor<[8],si64>, !torch.vtensor<[],f32>) -> !torch.vtensor<[?,?,?,?],f32> 
    %906 = torch.operator "onnx.Conv"(%905, %326, %327) {torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 256 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[256,1,3,3],f32>, !torch.vtensor<[256],f32>) -> !torch.vtensor<[?,256,?,?],f32> 
    %907 = torch.operator "onnx.Mul"(%906, %906) : (!torch.vtensor<[?,256,?,?],f32>, !torch.vtensor<[?,256,?,?],f32>) -> !torch.vtensor<[?,256,?,?],f32> 
    %908 = torch.operator "onnx.Mul"(%906, %907) : (!torch.vtensor<[?,256,?,?],f32>, !torch.vtensor<[?,256,?,?],f32>) -> !torch.vtensor<[?,256,?,?],f32> 
    %909 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__42> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %910 = torch.operator "onnx.Mul"(%909, %908) : (!torch.vtensor<[],f32>, !torch.vtensor<[?,256,?,?],f32>) -> !torch.vtensor<[?,256,?,?],f32> 
    %911 = torch.operator "onnx.Add"(%906, %910) : (!torch.vtensor<[?,256,?,?],f32>, !torch.vtensor<[?,256,?,?],f32>) -> !torch.vtensor<[?,256,?,?],f32> 
    %912 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__43> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %913 = torch.operator "onnx.Mul"(%912, %911) : (!torch.vtensor<[],f32>, !torch.vtensor<[?,256,?,?],f32>) -> !torch.vtensor<[?,256,?,?],f32> 
    %914 = torch.operator "onnx.Tanh"(%913) : (!torch.vtensor<[?,256,?,?],f32>) -> !torch.vtensor<[?,256,?,?],f32> 
    %915 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__44> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %916 = torch.operator "onnx.Add"(%915, %914) : (!torch.vtensor<[],f32>, !torch.vtensor<[?,256,?,?],f32>) -> !torch.vtensor<[?,256,?,?],f32> 
    %917 = torch.operator "onnx.Mul"(%906, %916) : (!torch.vtensor<[?,256,?,?],f32>, !torch.vtensor<[?,256,?,?],f32>) -> !torch.vtensor<[?,256,?,?],f32> 
    %918 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__45> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %919 = torch.operator "onnx.Mul"(%918, %917) : (!torch.vtensor<[],f32>, !torch.vtensor<[?,256,?,?],f32>) -> !torch.vtensor<[?,256,?,?],f32> 
    %920 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<[2, 3]> : tensor<2xsi64>} : () -> !torch.vtensor<[2],si64> 
    %921 = torch.operator "onnx.ReduceMean"(%919, %920) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[?,256,?,?],f32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[?,256,1,1],f32> 
    %922 = torch.operator "onnx.Conv"(%921, %6, %7) {torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 1 : si64, torch.onnx.kernel_shape = [1 : si64, 1 : si64], torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [1 : si64, 1 : si64]} : (!torch.vtensor<[?,256,1,1],f32>, !torch.vtensor<[16,256,1,1],f32>, !torch.vtensor<[16],f32>) -> !torch.vtensor<[?,16,1,1],f32> 
    %923 = torch.operator "onnx.Sigmoid"(%922) : (!torch.vtensor<[?,16,1,1],f32>) -> !torch.vtensor<[?,16,1,1],f32> 
    %924 = torch.operator "onnx.Mul"(%922, %923) : (!torch.vtensor<[?,16,1,1],f32>, !torch.vtensor<[?,16,1,1],f32>) -> !torch.vtensor<[?,16,1,1],f32> 
    %925 = torch.operator "onnx.Conv"(%924, %8, %9) {torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 1 : si64, torch.onnx.kernel_shape = [1 : si64, 1 : si64], torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [1 : si64, 1 : si64]} : (!torch.vtensor<[?,16,1,1],f32>, !torch.vtensor<[256,16,1,1],f32>, !torch.vtensor<[256],f32>) -> !torch.vtensor<[?,256,1,1],f32> 
    %926 = torch.operator "onnx.Sigmoid"(%925) : (!torch.vtensor<[?,256,1,1],f32>) -> !torch.vtensor<[?,256,1,1],f32> 
    %927 = torch.operator "onnx.Mul"(%919, %926) : (!torch.vtensor<[?,256,?,?],f32>, !torch.vtensor<[?,256,1,1],f32>) -> !torch.vtensor<[?,256,?,?],f32> 
    %928 = torch.operator "onnx.Conv"(%927, %10, %11) {torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 1 : si64, torch.onnx.kernel_shape = [1 : si64, 1 : si64], torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [1 : si64, 1 : si64]} : (!torch.vtensor<[?,256,?,?],f32>, !torch.vtensor<[64,256,1,1],f32>, !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,64,?,?],f32> 
    %929 = torch.operator "onnx.Add"(%928, %824) : (!torch.vtensor<[?,64,?,?],f32>, !torch.vtensor<[?,64,?,?],f32>) -> !torch.vtensor<[?,64,?,?],f32> 
    %930 = torch.operator "onnx.Transpose"(%929) {torch.onnx.perm = [0 : si64, 2 : si64, 3 : si64, 1 : si64]} : (!torch.vtensor<[?,64,?,?],f32>) -> !torch.vtensor<[?,?,?,64],f32> 
    %931 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %932 = torch.operator "onnx.ReduceMean"(%930, %931) : (!torch.vtensor<[?,?,?,64],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?,?,?,1],f32> 
    %933 = torch.operator "onnx.Sub"(%930, %932) : (!torch.vtensor<[?,?,?,64],f32>, !torch.vtensor<[?,?,?,1],f32>) -> !torch.vtensor<[?,?,?,64],f32> 
    %934 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__46> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %935 = torch.operator "onnx.Pow"(%933, %934) : (!torch.vtensor<[?,?,?,64],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[?,?,?,64],f32> 
    %936 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<-1> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %937 = torch.operator "onnx.ReduceMean"(%935, %936) : (!torch.vtensor<[?,?,?,64],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?,?,?,1],f32> 
    %938 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__47> : tensor<f32>} : () -> !torch.vtensor<[],f32> 
    %939 = torch.operator "onnx.Add"(%937, %938) : (!torch.vtensor<[?,?,?,1],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[?,?,?,1],f32> 
    %940 = torch.operator "onnx.Sqrt"(%939) : (!torch.vtensor<[?,?,?,1],f32>) -> !torch.vtensor<[?,?,?,1],f32> 
    %941 = torch.operator "onnx.Div"(%933, %940) : (!torch.vtensor<[?,?,?,64],f32>, !torch.vtensor<[?,?,?,1],f32>) -> !torch.vtensor<[?,?,?,64],f32> 
    %942 = torch.operator "onnx.Mul"(%941, %12) : (!torch.vtensor<[?,?,?,64],f32>, !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?,?,64],f32> 
    %943 = torch.operator "onnx.Add"(%942, %13) : (!torch.vtensor<[?,?,?,64],f32>, !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,?,?,64],f32> 
    %944 = torch.operator "onnx.Shape"(%943) : (!torch.vtensor<[?,?,?,64],f32>) -> !torch.vtensor<[4],si64> 
    %945 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__48> : tensor<si64>} : () -> !torch.vtensor<[],si64> 
    %946 = torch.operator "onnx.Gather"(%944, %945) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %947 = torch.operator "onnx.Shape"(%943) : (!torch.vtensor<[?,?,?,64],f32>) -> !torch.vtensor<[4],si64> 
    %948 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__49> : tensor<si64>} : () -> !torch.vtensor<[],si64> 
    %949 = torch.operator "onnx.Gather"(%947, %948) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %950 = torch.operator "onnx.Shape"(%943) : (!torch.vtensor<[?,?,?,64],f32>) -> !torch.vtensor<[4],si64> 
    %951 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__50> : tensor<si64>} : () -> !torch.vtensor<[],si64> 
    %952 = torch.operator "onnx.Gather"(%950, %951) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %953 = torch.operator "onnx.Shape"(%943) : (!torch.vtensor<[?,?,?,64],f32>) -> !torch.vtensor<[4],si64> 
    %954 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__51> : tensor<si64>} : () -> !torch.vtensor<[],si64> 
    %955 = torch.operator "onnx.Gather"(%953, %954) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %956 = torch.operator "onnx.Shape"(%943) : (!torch.vtensor<[?,?,?,64],f32>) -> !torch.vtensor<[4],si64> 
    %957 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__52> : tensor<si64>} : () -> !torch.vtensor<[],si64> 
    %958 = torch.operator "onnx.Gather"(%956, %957) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %959 = torch.operator "onnx.Shape"(%943) : (!torch.vtensor<[?,?,?,64],f32>) -> !torch.vtensor<[4],si64> 
    %960 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__53> : tensor<si64>} : () -> !torch.vtensor<[],si64> 
    %961 = torch.operator "onnx.Gather"(%959, %960) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %962 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__54> : tensor<si64>} : () -> !torch.vtensor<[],si64> 
    %963 = torch.operator "onnx.Div"(%955, %962) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %964 = torch.operator "onnx.Cast"(%963) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %965 = torch.operator "onnx.Cast"(%964) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %966 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__55> : tensor<si64>} : () -> !torch.vtensor<[],si64> 
    %967 = torch.operator "onnx.Div"(%958, %966) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %968 = torch.operator "onnx.Cast"(%967) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %969 = torch.operator "onnx.Cast"(%968) {torch.onnx.to = 7 : si64} : (!torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %970 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__56> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %971 = torch.operator "onnx.Unsqueeze"(%952, %970) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> 
    %972 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__57> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %973 = torch.operator "onnx.Unsqueeze"(%965, %972) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> 
    %974 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__58> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %975 = torch.operator "onnx.Unsqueeze"(%969, %974) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> 
    %976 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__59> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %977 = torch.operator "onnx.Unsqueeze"(%961, %976) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> 
    %978 = torch.operator "onnx.Concat"(%971, %973, %371, %975, %791, %977) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[6],si64> 
    %979 = torch.operator "onnx.Reshape"(%943, %978) : (!torch.vtensor<[?,?,?,64],f32>, !torch.vtensor<[6],si64>) -> !torch.vtensor<[?,?,?,?,?,?],f32> 
    %980 = torch.operator "onnx.Transpose"(%979) {torch.onnx.perm = [0 : si64, 1 : si64, 3 : si64, 2 : si64, 4 : si64, 5 : si64]} : (!torch.vtensor<[?,?,?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?,?,?],f32> 
    %981 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__60> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %982 = torch.operator "onnx.Unsqueeze"(%961, %981) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> 
    %983 = torch.operator "onnx.Concat"(%372, %790, %789, %982) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],si64> 
    %984 = torch.operator "onnx.Reshape"(%980, %983) : (!torch.vtensor<[?,?,?,?,?,?],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> 
    %985 = torch.operator "onnx.Shape"(%984) : (!torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[4],si64> 
    %986 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__61> : tensor<si64>} : () -> !torch.vtensor<[],si64> 
    %987 = torch.operator "onnx.Gather"(%985, %986) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[4],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[],si64> 
    %997 = torch.operator "onnx.MatMul"(%984, %373) : (!torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[64,192],f32>) -> !torch.vtensor<[?,?,?,192],f32> 
    %998 = torch.operator "onnx.Add"(%14, %997) : (!torch.vtensor<[192],f32>, !torch.vtensor<[?,?,?,192],f32>) -> !torch.vtensor<[?,?,?,192],f32> 
    %999 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__65> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %1000 = torch.operator "onnx.Unsqueeze"(%987, %999) : (!torch.vtensor<[],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1],si64> 
    %1001 = torch.operator "onnx.Concat"(%1000, %788, %374, %375, %376) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[5],si64> 
    %1002 = torch.operator "onnx.Reshape"(%998, %1001) : (!torch.vtensor<[?,?,?,192],f32>, !torch.vtensor<[5],si64>) -> !torch.vtensor<[?,?,?,?,?],f32> 
    %1003 = torch.operator "onnx.Transpose"(%1002) {torch.onnx.perm = [0 : si64, 3 : si64, 2 : si64, 1 : si64, 4 : si64]} : (!torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?,?],f32> 
    %1004 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__66> : tensor<3xsi64>} : () -> !torch.vtensor<[3],si64> 
    %1005:3 = torch.operator "onnx.Split"(%1003, %1004) {torch.onnx.axis = 2 : si64} : (!torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[3],si64>) -> (!torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[?,?,?,?,?],f32>) 
    %1006 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<__67> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> 
    %1007 = torch.operator "onnx.Squeeze"(%1005#0, %1006) : (!torch.vtensor<[?,?,?,?,?],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[?,?,?,?],f32> 
    return %1007 : !torch.vtensor<[?,?,?,?],f32>
  }
}

{-#
  dialect_resources: {
    builtin: {
      _onnx__Concat_6281: "0x080000000300000000000000",
      _onnx__Concat_6282: "0x080000000200000000000000",
      _onnx__Concat_6283: "0x080000002000000000000000",
      _: "0x080000000000000000000000",
      __1: "0x08000000FFFFFFFFFFFFFFFF0200000000000000",
      __2: "0x080000000000000000000000",
      __3: "0x08000000FFFFFFFFFFFFFFFF",
      __4: "0x080000000100000000000080",
      __5: "0x08000000FFFFFFFFFFFFFFFF",
      __6: "0x08000000FFFFFFFFFFFFFFFF",
      __7: "0x0800000000000000",
      __8: "0x080000001327373D",
      __9: "0x080000002A424C3F",
      __10: "0x080000000000803F",
      __11: "0x080000000000003F",
      __12: "0x0800000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
      __13: "0x080000001327373D",
      __14: "0x080000002A424C3F",
      __15: "0x080000000000803F",
      __16: "0x080000000000003F",
      __17: "0x080000000200000000000000",
      __18: "0x080000000300000000000000",
      __19: "0x080000007100000000000000",
      __20: "0x080000007100000000000000",
      __21: "0x080000000200000000000000",
      __22: "0x080000000200000000000000",
      __23: "0x080000000200000000000000",
      __24: "0x080000000200000000000000",
      __25: "0x080000000000000000000000",
      __26: "0x080000000000000000000000",
      __27: "0x080000000000000000000000",
      __28: "0x080000000000000000000000",
      __29: "0x080000000000000000000000",
      __30: "0x080000000000000000000000",
      __31: "0x080000000000000000000000",
      __32: "0x080000000000000000000000",
      __33: "0x080000000000000000000000",
      __34: "0x080000000000000000000000",
      __35: "0x08000000FFFFFFFFFFFFFFFF0200000000000000",
      __36: "0x080000000000000000000000",
      __37: "0x08000000FFFFFFFFFFFFFFFF",
      __38: "0x080000000100000000000080",
      __39: "0x08000000FFFFFFFFFFFFFFFF",
      __40: "0x08000000FFFFFFFFFFFFFFFF",
      __41: "0x0800000000000000",
      __42: "0x080000001327373D",
      __43: "0x080000002A424C3F",
      __44: "0x080000000000803F",
      __45: "0x080000000000003F",
      __46: "0x0800000000000040",
      __47: "0x08000000ACC52737",
      __48: "0x080000000100000000000000",
      __49: "0x080000000200000000000000",
      __50: "0x080000000000000000000000",
      __51: "0x080000000100000000000000",
      __52: "0x080000000200000000000000",
      __53: "0x080000000300000000000000",
      __54: "0x080000000700000000000000",
      __55: "0x080000000700000000000000",
      __56: "0x080000000000000000000000",
      __57: "0x080000000000000000000000",
      __58: "0x080000000000000000000000",
      __59: "0x080000000000000000000000",
      __60: "0x080000000000000000000000",
      __61: "0x080000000000000000000000",
      __62: "0x080000000000000000000000",
      __63: "0x080000000100000000000000",
      __64: "0x080000000200000000000000",
      __65: "0x080000000000000000000000",
      __66: "0x08000000010000000000000001000000000000000100000000000000",
      __67: "0x080000000200000000000000"
    }
  }
#-}

command : iree-compile --iree-hal-target-backends=llvm-cpu model.torch_onnx.mlir

zjgarvey commented 2 weeks ago

We really need to figure out how to untangle pads from torch exported models.

vivekkhandelwal1 commented 1 week ago

This issue is fixed by https://github.com/llvm/llvm-project/pull/113551.

zjgarvey commented 1 week ago

Testing the failing models with pad sizes folding patch alone: https://github.com/llvm/torch-mlir/pull/3813

Passing Summary

TOTAL TESTS = 41

| Stage | # Passing | % of Total | % of Attempted |
|---|---|---|---|
| Setup | 41 | 100.0% | 100.0% |
| IREE Compilation | 23 | 56.1% | 56.1% |
| Gold Inference | 23 | 56.1% | 100.0% |
| IREE Inference Invocation | 22 | 53.7% | 95.7% |
| Inference Comparison (PASS) | 17 | 41.5% | 77.3% |

Fail Summary

TOTAL TESTS = 41

| Stage | # Failed at Stage | % of Total |
|---|---|---|
| Setup | 0 | 0.0% |
| IREE Compilation | 18 | 43.9% |
| Gold Inference | 0 | 0.0% |
| IREE Inference Invocation | 1 | 2.4% |
| Inference Comparison | 5 | 12.2% |

Test Run Detail

Test was run with the following arguments: Namespace(device='local-task', backend='llvm-cpu', iree_compile_args=None, mode='cl-onnx-iree', torchtolinalg=True, stages=None, skip_stages=None, benchmark=False, load_inputs=False, groups='all', test_filter=None, testsfile='sample.txt', tolerance=None, verbose=True, rundirectory='test-run', no_artifacts=False, cleanup='0', report=True, report_file='sample.md', get_metadata=True)

| Test | Exit Status | Mean Benchmark Time (ms) | Notes |
|---|---|---|---|
| edgenext_base | Numerics | None | |
| edgenext_small | Numerics | None | |
| edgenext_small_rw | PASS | None | |
| edgenext_x_small | Numerics | None | |
| edgenext_xx_small | Numerics | None | |
| maxvit_base_tf_224.in1k | compilation | None | |
| maxvit_base_tf_384.in1k | compilation | None | |
| maxvit_base_tf_384.in21k_ft_in1k | compilation | None | |
| maxvit_base_tf_512.in1k | compilation | None | |
| maxvit_base_tf_512.in21k_ft_in1k | compilation | None | |
| maxvit_large_tf_224.in1k | compilation | None | |
| maxvit_large_tf_384.in1k | compilation | None | |
| maxvit_large_tf_384.in21k_ft_in1k | compilation | None | |
| maxvit_large_tf_512.in1k | compilation | None | |
| maxvit_large_tf_512.in21k_ft_in1k | compilation | None | |
| maxvit_small_tf_224.in1k | compilation | None | |
| maxvit_small_tf_384.in1k | compilation | None | |
| maxvit_small_tf_512.in1k | compilation | None | |
| maxvit_tiny_tf_224.in1k | compilation | None | |
| maxvit_tiny_tf_384.in1k | compilation | None | |
| maxvit_tiny_tf_512.in1k | compilation | None | |
| maxvit_xlarge_tf_384.in21k_ft_in1k | compilation | None | |
| maxvit_xlarge_tf_512.in21k_ft_in1k | compilation | None | |
| model--codegen-350M-mono--Salesforce | Numerics | None | |
| model--CodeGen-350M-Multi--xhyi | compiled_inference | None | |
| model--deberta-italian-question-answering--osiria | PASS | None | |
| model--deberta-v3-base-qa-en--LLukas22 | PASS | None | |
| model--deberta-v3-base-squad2--deepset | PASS | None | |
| model--deberta-v3-base-squad2--navteca | PASS | None | |
| model--deberta-v3-basesst2all-train--SetFit | PASS | None | |
| model--deberta-v3-large-squad2--deepset | PASS | None | |
| model--deberta-v3-large-squad2--sjrhuschlee | PASS | None | |
| model--deberta-v3-xsmall-squad2--nlpconnect | PASS | None | |
| model--deberta_squadnewsqa--sophiebottani | PASS | None | |
| model--mdeberta-v3-base-squad2--sjrhuschlee | PASS | None | |
| model--microsoft-deberta-v3-large_ner_conll2003--Gladiator | PASS | None | |
| model--microsoft_deberta-base_squad--Palak | PASS | None | |
| model--microsoft_deberta-large_squad--Palak | PASS | None | |
| model--outputs--ankitkupadhyay | PASS | None | |
| model--reward-model-deberta-v3-large-v2--OpenAssistant | PASS | None | |
| tnt_s_patch16_224 | PASS | None | |
zjgarvey commented 1 week ago

These are the models with redundant tests removed:

edgenext_base edgenext_small edgenext_small_rw edgenext_x_small edgenext_xx_small tnt_s_patch16_224 model--CodeGen-350M-Multi--xhyi model--deberta-italian-question-answering--osiria model--deberta-v3-base-qa-en--LLukas22 model--deberta-v3-basesst2all-train--SetFit model--deberta-v3-large-squad2--deepset model--deberta-v3-xsmall-squad2--nlpconnect model--mdeberta-v3-base-squad2--sjrhuschlee model--microsoft-deberta-v3-large_ner_conll2003--Gladiator model--microsoft_deberta-base_squad--Palak model--microsoft_deberta-large_squad--Palak model--outputs--ankitkupadhyay model--reward-model-deberta-v3-large-v2--OpenAssistant maxvit_base_tf_224.in1k maxvit_base_tf_384.in1k maxvit_base_tf_512.in1k maxvit_large_tf_224.in1k maxvit_large_tf_384.in1k maxvit_large_tf_512.in1k maxvit_small_tf_224.in1k maxvit_small_tf_384.in1k maxvit_small_tf_512.in1k maxvit_tiny_tf_224.in1k maxvit_tiny_tf_384.in1k maxvit_tiny_tf_512.in1k maxvit_xlarge_tf_384.in21k_ft_in1k maxvit_xlarge_tf_512.in21k_ft_in1k