microsoft / onnxscript

ONNX Script enables developers to naturally author ONNX functions and models using a subset of Python.
https://onnxscript.ai/
MIT License

Unsupported FX Nodes: {'call_function': ['aten.roll.default', 'aten.var.correction']} #1173

Closed (luisfmnunes closed 9 months ago)

luisfmnunes commented 10 months ago

Hello,

First of all, sorry for this post; I'm still kind of lost on how ONNX opset 18 works and how TorchDynamo exports the model to an ONNX protobuf. I trained my model and am now trying to export it to ONNX. Using torch.export.export, I can generate an ExportedProgram with the following signature:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, arg0_1: f32[1, 1, 3, 3], arg1_1: f32[1, 1, 3, 3], arg2_1: f32[1, 1, 3, 3], arg3_1: f32[1, 1, 3, 3], arg4_1: f32[1, 1, 9, 9], arg5_1: f32[1, 1, 9, 9], arg6_1: f32[64, 3, 7, 7], arg7_1: f32[64], arg8_1: f32[64], arg9_1: f32[64, 64, 3, 3], arg10_1: f32[64], arg11_1: f32[64], arg12_1: f32[64, 64, 3, 3], arg13_1: f32[64], arg14_1: f32[64], arg15_1: f32[64, 64, 3, 3], arg16_1: f32[64], arg17_1: f32[64], arg18_1: f32[64, 64, 3, 3], arg19_1: f32[64], arg20_1: f32[64], arg21_1: f32[128, 64, 3, 3], arg22_1: f32[128], arg23_1: f32[128], arg24_1: f32[128, 128, 3, 3], arg25_1: f32[128], arg26_1: f32[128], arg27_1: f32[128, 64, 1, 1], arg28_1: f32[128], arg29_1: f32[128], arg30_1: f32[128, 128, 3, 3], arg31_1: f32[128], arg32_1: f32[128], arg33_1: f32[128, 128, 3, 3], arg34_1: f32[128], arg35_1: f32[128], arg36_1: f32[256, 128, 3, 3], arg37_1: f32[256], arg38_1: f32[256], arg39_1: f32[256, 256, 3, 3], arg40_1: f32[256], arg41_1: f32[256], arg42_1: f32[256, 128, 1, 1], arg43_1: f32[256], arg44_1: f32[256], arg45_1: f32[256, 256, 3, 3], arg46_1: f32[256], arg47_1: f32[256], arg48_1: f32[256, 256, 3, 3], arg49_1: f32[256], arg50_1: f32[256], arg51_1: f32[512, 256, 3, 3], arg52_1: f32[512], arg53_1: f32[512], arg54_1: f32[512, 512, 3, 3], arg55_1: f32[512], arg56_1: f32[512], arg57_1: f32[512, 256, 1, 1], arg58_1: f32[512], arg59_1: f32[512], arg60_1: f32[512, 512, 3, 3], arg61_1: f32[512], arg62_1: f32[512], arg63_1: f32[512, 512, 3, 3], arg64_1: f32[512], arg65_1: f32[512], arg66_1: f32[512, 256, 2, 2], arg67_1: f32[256], arg68_1: f32[256, 512, 3, 3], arg69_1: f32[256], arg70_1: f32[256], arg71_1: f32[256, 256, 3, 3], arg72_1: f32[256], arg73_1: f32[256], arg74_1: f32[256, 128, 2, 2], arg75_1: f32[128], arg76_1: f32[128, 256, 3, 3], arg77_1: f32[128], arg78_1: f32[128], arg79_1: f32[128, 128, 3, 3], arg80_1: f32[128], arg81_1: f32[128], arg82_1: f32[128, 64, 2, 2], arg83_1: f32[64], arg84_1: f32[64, 128, 3, 3], arg85_1: f32[64], arg86_1: f32[64], arg87_1: 
f32[64, 64, 3, 3], arg88_1: f32[64], arg89_1: f32[64], arg90_1: f32[134, 64, 1, 1], arg91_1: f32[134], arg92_1: f32[64], arg93_1: f32[64], arg94_1: i64[], arg95_1: f32[64], arg96_1: f32[64], arg97_1: i64[], arg98_1: f32[64], arg99_1: f32[64], arg100_1: i64[], arg101_1: f32[64], arg102_1: f32[64], arg103_1: i64[], arg104_1: f32[64], arg105_1: f32[64], arg106_1: i64[], arg107_1: f32[128], arg108_1: f32[128], arg109_1: i64[], arg110_1: f32[128], arg111_1: f32[128], arg112_1: i64[], arg113_1: f32[128], arg114_1: f32[128], arg115_1: i64[], arg116_1: f32[128], arg117_1: f32[128], arg118_1: i64[], arg119_1: f32[128], arg120_1: f32[128], arg121_1: i64[], arg122_1: f32[256], arg123_1: f32[256], arg124_1: i64[], arg125_1: f32[256], arg126_1: f32[256], arg127_1: i64[], arg128_1: f32[256], arg129_1: f32[256], arg130_1: i64[], arg131_1: f32[256], arg132_1: f32[256], arg133_1: i64[], arg134_1: f32[256], arg135_1: f32[256], arg136_1: i64[], arg137_1: f32[512], arg138_1: f32[512], arg139_1: i64[], arg140_1: f32[512], arg141_1: f32[512], arg142_1: i64[], arg143_1: f32[512], arg144_1: f32[512], arg145_1: i64[], arg146_1: f32[512], arg147_1: f32[512], arg148_1: i64[], arg149_1: f32[512], arg150_1: f32[512], arg151_1: i64[], arg152_1: f32[256], arg153_1: f32[256], arg154_1: i64[], arg155_1: f32[256], arg156_1: f32[256], arg157_1: i64[], arg158_1: f32[128], arg159_1: f32[128], arg160_1: i64[], arg161_1: f32[128], arg162_1: f32[128], arg163_1: i64[], arg164_1: f32[64], arg165_1: f32[64], arg166_1: i64[], arg167_1: f32[64], arg168_1: f32[64], arg169_1: i64[], arg170_1: f32[1, 1, 512, 512]):
            # 
            arange: i64[512] = torch.ops.aten.arange.start_step(0, 512, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
            lt: b8[512] = torch.ops.aten.lt.Scalar(arange, 256.0)
            _to_copy: f32[512] = torch.ops.aten._to_copy.default(arange, dtype = torch.float32)
            mul: f32[512] = torch.ops.aten.mul.Tensor(_to_copy, 0.0019569471624266144);  _to_copy = None
            add: f32[512] = torch.ops.aten.add.Tensor(mul, -0.5);  mul = None
            sub: i64[512] = torch.ops.aten.sub.Tensor(511, arange);  arange = None
            _to_copy_1: f32[512] = torch.ops.aten._to_copy.default(sub, dtype = torch.float32);  sub = None
            mul_1: f32[512] = torch.ops.aten.mul.Tensor(_to_copy_1, 0.0019569471624266144);  _to_copy_1 = None
            sub_1: f32[512] = torch.ops.aten.sub.Tensor(0.5, mul_1);  mul_1 = None
            where: f32[512] = torch.ops.aten.where.self(lt, add, sub_1);  lt = add = sub_1 = None
            arange_1: i64[512] = torch.ops.aten.arange.start_step(0, 512, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
            lt_1: b8[512] = torch.ops.aten.lt.Scalar(arange_1, 256.0)
            _to_copy_2: f32[512] = torch.ops.aten._to_copy.default(arange_1, dtype = torch.float32)
            mul_2: f32[512] = torch.ops.aten.mul.Tensor(_to_copy_2, 0.0019569471624266144);  _to_copy_2 = None
            add_1: f32[512] = torch.ops.aten.add.Tensor(mul_2, -0.5);  mul_2 = None
            sub_2: i64[512] = torch.ops.aten.sub.Tensor(511, arange_1);  arange_1 = None
            _to_copy_3: f32[512] = torch.ops.aten._to_copy.default(sub_2, dtype = torch.float32);  sub_2 = None
            mul_3: f32[512] = torch.ops.aten.mul.Tensor(_to_copy_3, 0.0019569471624266144);  _to_copy_3 = None
            sub_3: f32[512] = torch.ops.aten.sub.Tensor(0.5, mul_3);  mul_3 = None
            where_1: f32[512] = torch.ops.aten.where.self(lt_1, add_1, sub_3);  lt_1 = add_1 = sub_3 = None
            view: f32[512, 1] = torch.ops.aten.view.default(where, [-1, 1]);  where = None
            expand: f32[512, 512] = torch.ops.aten.expand.default(view, [512, 512]);  view = None
            view_1: f32[1, 512] = torch.ops.aten.view.default(where_1, [1, -1]);  where_1 = None
            expand_1: f32[512, 512] = torch.ops.aten.expand.default(view_1, [512, 512]);  view_1 = None
            pow_1: f32[512, 512] = torch.ops.aten.pow.Tensor_Scalar(expand_1, 2);  expand_1 = None
            pow_2: f32[512, 512] = torch.ops.aten.pow.Tensor_Scalar(expand, 2);  expand = None
            add_2: f32[512, 512] = torch.ops.aten.add.Tensor(pow_1, pow_2);  pow_1 = pow_2 = None
            sqrt: f32[512, 512] = torch.ops.aten.sqrt.default(add_2);  add_2 = None
            add_3: f32[512, 512] = torch.ops.aten.add.Tensor(sqrt, 1e-06);  sqrt = None
            mul_4: f32[512, 512] = torch.ops.aten.mul.Tensor(add_3, 6.283185307179586);  add_3 = None
            mul_5: f32[512, 512] = torch.ops.aten.mul.Tensor(mul_4, 2.5);  mul_4 = None
            pow_3: f32[512, 512] = torch.ops.aten.pow.Tensor_Scalar(mul_5, 4);  mul_5 = None
            add_4: f32[512, 512] = torch.ops.aten.add.Tensor(pow_3, 1);  pow_3 = None
            reciprocal: f32[512, 512] = torch.ops.aten.reciprocal.default(add_4);  add_4 = None
            mul_6: f32[512, 512] = torch.ops.aten.mul.Tensor(reciprocal, 1.0);  reciprocal = None
            unsqueeze: f32[1, 512, 512] = torch.ops.aten.unsqueeze.default(mul_6, 0);  mul_6 = None
            unsqueeze_1: f32[1, 1, 512, 512] = torch.ops.aten.unsqueeze.default(unsqueeze, 1);  unsqueeze = None
            convolution: f32[1, 1, 512, 512] = torch.ops.aten.convolution.default(arg170_1, arg0_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1)
            convolution_1: f32[1, 1, 512, 512] = torch.ops.aten.convolution.default(arg170_1, arg1_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1)
            pow_4: f32[1, 1, 512, 512] = torch.ops.aten.pow.Tensor_Scalar(convolution, 2);  convolution = None
            pow_5: f32[1, 1, 512, 512] = torch.ops.aten.pow.Tensor_Scalar(convolution_1, 2);  convolution_1 = None
            add_5: f32[1, 1, 512, 512] = torch.ops.aten.add.Tensor(pow_4, pow_5);  pow_4 = pow_5 = None
            sqrt_1: f32[1, 1, 512, 512] = torch.ops.aten.sqrt.default(add_5);  add_5 = None
            add_6: f32[1, 1, 512, 512] = torch.ops.aten.add.Tensor(sqrt_1, 1e-06);  sqrt_1 = None
            _to_copy_4: c64[1, 1, 512, 512] = torch.ops.aten._to_copy.default(add_6, dtype = torch.complex64);  add_6 = None
            _fft_c2c: c64[1, 1, 512, 512] = torch.ops.aten._fft_c2c.default(_to_copy_4, [2, 3], 0, True);  _to_copy_4 = None
            roll: c64[1, 1, 512, 512] = torch.ops.aten.roll.default(_fft_c2c, [256, 256], [2, 3]);  _fft_c2c = None
            mul_7: c64[1, 1, 512, 512] = torch.ops.aten.mul.Tensor(roll, unsqueeze_1);  roll = None
            roll_1: c64[1, 1, 512, 512] = torch.ops.aten.roll.default(mul_7, [256, 256], [2, 3]);  mul_7 = None
            _fft_c2c_1: c64[1, 1, 512, 512] = torch.ops.aten._fft_c2c.default(roll_1, [2, 3], 2, False);  roll_1 = None
            view_as_real: f32[1, 1, 512, 512, 2] = torch.ops.aten.view_as_real.default(_fft_c2c_1);  _fft_c2c_1 = None
            select: f32[1, 1, 512, 512] = torch.ops.aten.select.int(view_as_real, 4, 0);  view_as_real = None
            _to_copy_5: c64[1, 1, 512, 512] = torch.ops.aten._to_copy.default(arg170_1, dtype = torch.complex64)
            _fft_c2c_2: c64[1, 1, 512, 512] = torch.ops.aten._fft_c2c.default(_to_copy_5, [2, 3], 0, True);  _to_copy_5 = None
            roll_2: c64[1, 1, 512, 512] = torch.ops.aten.roll.default(_fft_c2c_2, [256, 256], [2, 3]);  _fft_c2c_2 = None
            mul_8: c64[1, 1, 512, 512] = torch.ops.aten.mul.Tensor(roll_2, unsqueeze_1);  roll_2 = None
            roll_3: c64[1, 1, 512, 512] = torch.ops.aten.roll.default(mul_8, [256, 256], [2, 3]);  mul_8 = None
            _fft_c2c_3: c64[1, 1, 512, 512] = torch.ops.aten._fft_c2c.default(roll_3, [2, 3], 2, False);  roll_3 = None
            view_as_real_1: f32[1, 1, 512, 512, 2] = torch.ops.aten.view_as_real.default(_fft_c2c_3);  _fft_c2c_3 = None
            select_1: f32[1, 1, 512, 512] = torch.ops.aten.select.int(view_as_real_1, 4, 0);  view_as_real_1 = None
            convolution_2: f32[1, 1, 512, 512] = torch.ops.aten.convolution.default(select_1, arg0_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg0_1 = None
            convolution_3: f32[1, 1, 512, 512] = torch.ops.aten.convolution.default(select_1, arg1_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg1_1 = None
            pow_6: f32[1, 1, 512, 512] = torch.ops.aten.pow.Tensor_Scalar(convolution_2, 2);  convolution_2 = None
            pow_7: f32[1, 1, 512, 512] = torch.ops.aten.pow.Tensor_Scalar(convolution_3, 2);  convolution_3 = None
            add_7: f32[1, 1, 512, 512] = torch.ops.aten.add.Tensor(pow_6, pow_7);  pow_6 = pow_7 = None
            sqrt_2: f32[1, 1, 512, 512] = torch.ops.aten.sqrt.default(add_7);  add_7 = None
            add_8: f32[1, 1, 512, 512] = torch.ops.aten.add.Tensor(sqrt_2, 1e-06);  sqrt_2 = None
            _to_copy_6: c64[1, 1, 512, 512] = torch.ops.aten._to_copy.default(add_8, dtype = torch.complex64);  add_8 = None
            _fft_c2c_4: c64[1, 1, 512, 512] = torch.ops.aten._fft_c2c.default(_to_copy_6, [2, 3], 0, True);  _to_copy_6 = None
            roll_4: c64[1, 1, 512, 512] = torch.ops.aten.roll.default(_fft_c2c_4, [256, 256], [2, 3]);  _fft_c2c_4 = None
            mul_9: c64[1, 1, 512, 512] = torch.ops.aten.mul.Tensor(roll_4, unsqueeze_1);  roll_4 = unsqueeze_1 = None
            roll_5: c64[1, 1, 512, 512] = torch.ops.aten.roll.default(mul_9, [256, 256], [2, 3]);  mul_9 = None
            _fft_c2c_5: c64[1, 1, 512, 512] = torch.ops.aten._fft_c2c.default(roll_5, [2, 3], 2, False);  roll_5 = None
            view_as_real_2: f32[1, 1, 512, 512, 2] = torch.ops.aten.view_as_real.default(_fft_c2c_5);  _fft_c2c_5 = None
            select_2: f32[1, 1, 512, 512] = torch.ops.aten.select.int(view_as_real_2, 4, 0);  view_as_real_2 = None
            sub_4: f32[1, 1, 512, 512] = torch.ops.aten.sub.Tensor(select, select_2);  select_2 = None
            abs_1: f32[1, 1, 512, 512] = torch.ops.aten.abs.default(select);  select = None
            gt: b8[1, 1, 512, 512] = torch.ops.aten.gt.Scalar(abs_1, 1)
            clamp_min: f32[1, 1, 512, 512] = torch.ops.aten.clamp_min.default(abs_1, 1e-06);  abs_1 = None
            div: f32[1, 1, 512, 512] = torch.ops.aten.div.Tensor(sub_4, clamp_min);  clamp_min = None
            full_like: f32[1, 1, 512, 512] = torch.ops.aten.full_like.default(sub_4, 0, pin_memory = False, memory_format = torch.preserve_format);  sub_4 = None
            where_2: f32[1, 1, 512, 512] = torch.ops.aten.where.self(gt, div, full_like);  gt = div = full_like = None
            sub_5: f32[1, 1, 512, 512] = torch.ops.aten.sub.Tensor(where_2, 0.3);  where_2 = None
            div_1: f32[1, 1, 512, 512] = torch.ops.aten.div.Tensor(sub_5, 0.39999999999999997);  sub_5 = None
            clamp: f32[1, 1, 512, 512] = torch.ops.aten.clamp.default(div_1, 0, 1);  div_1 = None
            mul_10: f32[1, 1, 512, 512] = torch.ops.aten.mul.Tensor(clamp, select_1);  select_1 = None
            sub_6: f32[1, 1, 512, 512] = torch.ops.aten.sub.Tensor(1, clamp);  clamp = None
            mul_11: f32[1, 1, 512, 512] = torch.ops.aten.mul.Tensor(sub_6, arg170_1);  sub_6 = None
            add_9: f32[1, 1, 512, 512] = torch.ops.aten.add.Tensor(mul_10, mul_11);  mul_10 = mul_11 = None
            sub_7: f32[1, 1, 512, 512] = torch.ops.aten.sub.Tensor(arg170_1, add_9);  arg170_1 = add_9 = None
            add_10: f32[1, 1, 512, 512] = torch.ops.aten.add.Tensor(sub_7, 20);  sub_7 = None
            mul_12: f32[1, 1, 512, 512] = torch.ops.aten.mul.Tensor(add_10, 255);  add_10 = None
            div_2: f32[1, 1, 512, 512] = torch.ops.aten.div.Tensor(mul_12, 40);  mul_12 = None
            clamp_1: f32[1, 1, 512, 512] = torch.ops.aten.clamp.default(div_2, 0, 255);  div_2 = None
            mean: f32[1, 1, 1, 1] = torch.ops.aten.mean.dim(clamp_1, [1, 2, 3], True)
            var: f32[1, 1, 1, 1] = torch.ops.aten.var.correction(clamp_1, [1, 2, 3], correction = 1, keepdim = True)
            sub_8: f32[1, 1, 512, 512] = torch.ops.aten.sub.Tensor(clamp_1, mean)
            pow_8: f32[1, 1, 512, 512] = torch.ops.aten.pow.Tensor_Scalar(sub_8, 2);  sub_8 = None
            mul_13: f32[1, 1, 512, 512] = torch.ops.aten.mul.Tensor(pow_8, 1);  pow_8 = None
            clamp_min_1: f32[1, 1, 1, 1] = torch.ops.aten.clamp_min.default(var, 1e-06);  var = None
            div_3: f32[1, 1, 512, 512] = torch.ops.aten.div.Tensor(mul_13, clamp_min_1);  mul_13 = clamp_min_1 = None
            sqrt_3: f32[1, 1, 512, 512] = torch.ops.aten.sqrt.default(div_3);  div_3 = None
            gt_1: b8[1, 1, 512, 512] = torch.ops.aten.gt.Tensor(clamp_1, mean);  mean = None
            add_11: f32[1, 1, 512, 512] = torch.ops.aten.add.Tensor(sqrt_3, 0)
            sub_9: f32[1, 1, 512, 512] = torch.ops.aten.sub.Tensor(0, sqrt_3);  sqrt_3 = None
            where_3: f32[1, 1, 512, 512] = torch.ops.aten.where.self(gt_1, add_11, sub_9);  gt_1 = add_11 = sub_9 = None
            convolution_4: f32[1, 1, 512, 512] = torch.ops.aten.convolution.default(where_3, arg2_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg2_1 = None
            convolution_5: f32[1, 1, 512, 512] = torch.ops.aten.convolution.default(where_3, arg3_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg3_1 = None
            pow_9: f32[1, 1, 512, 512] = torch.ops.aten.pow.Tensor_Scalar(convolution_4, 2)
            convolution_6: f32[1, 1, 512, 512] = torch.ops.aten.convolution.default(pow_9, arg4_1, None, [1, 1], [4, 4], [1, 1], False, [0, 0], 1);  pow_9 = None
            pow_10: f32[1, 1, 512, 512] = torch.ops.aten.pow.Tensor_Scalar(convolution_5, 2)
            convolution_7: f32[1, 1, 512, 512] = torch.ops.aten.convolution.default(pow_10, arg4_1, None, [1, 1], [4, 4], [1, 1], False, [0, 0], 1);  pow_10 = None
            neg: f32[1, 1, 512, 512] = torch.ops.aten.neg.default(convolution_4);  convolution_4 = None
            mul_14: f32[1, 1, 512, 512] = torch.ops.aten.mul.Tensor(neg, convolution_5);  neg = convolution_5 = None
            convolution_8: f32[1, 1, 512, 512] = torch.ops.aten.convolution.default(mul_14, arg4_1, None, [1, 1], [4, 4], [1, 1], False, [0, 0], 1);  mul_14 = arg4_1 = None
            mul_15: f32[1, 1, 512, 512] = torch.ops.aten.mul.Tensor(convolution_8, 2);  convolution_8 = None
            convolution_9: f32[1, 1, 512, 512] = torch.ops.aten.convolution.default(mul_15, arg5_1, None, [1, 1], [4, 4], [1, 1], False, [0, 0], 1);  mul_15 = None
            sub_10: f32[1, 1, 512, 512] = torch.ops.aten.sub.Tensor(convolution_6, convolution_7);  convolution_6 = convolution_7 = None
            convolution_10: f32[1, 1, 512, 512] = torch.ops.aten.convolution.default(sub_10, arg5_1, None, [1, 1], [4, 4], [1, 1], False, [0, 0], 1);  sub_10 = arg5_1 = None
            cat: f32[1, 3, 512, 512] = torch.ops.aten.cat.default([where_3, convolution_9, convolution_10], 1);  where_3 = convolution_9 = convolution_10 = None
            convolution_11: f32[1, 64, 256, 256] = torch.ops.aten.convolution.default(cat, arg6_1, None, [2, 2], [3, 3], [1, 1], False, [0, 0], 1);  cat = arg6_1 = None
            _native_batch_norm_legit_no_training = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_11, arg7_1, arg8_1, arg92_1, arg93_1, 0.1, 1e-05);  convolution_11 = arg7_1 = arg8_1 = arg92_1 = arg93_1 = None
            getitem: f32[1, 64, 256, 256] = _native_batch_norm_legit_no_training[0];  _native_batch_norm_legit_no_training = None
            relu: f32[1, 64, 256, 256] = torch.ops.aten.relu.default(getitem);  getitem = None
            max_pool2d_with_indices = torch.ops.aten.max_pool2d_with_indices.default(relu, [3, 3], [2, 2], [1, 1]);  relu = None
            getitem_1: f32[1, 64, 128, 128] = max_pool2d_with_indices[0];  max_pool2d_with_indices = None
            convolution_12: f32[1, 64, 128, 128] = torch.ops.aten.convolution.default(getitem_1, arg9_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg9_1 = None
            _native_batch_norm_legit_no_training_1 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_12, arg10_1, arg11_1, arg95_1, arg96_1, 0.1, 1e-05);  convolution_12 = arg10_1 = arg11_1 = arg95_1 = arg96_1 = None
            getitem_2: f32[1, 64, 128, 128] = _native_batch_norm_legit_no_training_1[0];  _native_batch_norm_legit_no_training_1 = None
            relu_1: f32[1, 64, 128, 128] = torch.ops.aten.relu.default(getitem_2);  getitem_2 = None
            convolution_13: f32[1, 64, 128, 128] = torch.ops.aten.convolution.default(relu_1, arg12_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  relu_1 = arg12_1 = None
            _native_batch_norm_legit_no_training_2 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_13, arg13_1, arg14_1, arg98_1, arg99_1, 0.1, 1e-05);  convolution_13 = arg13_1 = arg14_1 = arg98_1 = arg99_1 = None
            getitem_3: f32[1, 64, 128, 128] = _native_batch_norm_legit_no_training_2[0];  _native_batch_norm_legit_no_training_2 = None
            add_12: f32[1, 64, 128, 128] = torch.ops.aten.add.Tensor(getitem_3, getitem_1);  getitem_3 = getitem_1 = None
            relu_2: f32[1, 64, 128, 128] = torch.ops.aten.relu.default(add_12);  add_12 = None
            convolution_14: f32[1, 64, 128, 128] = torch.ops.aten.convolution.default(relu_2, arg15_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg15_1 = None
            _native_batch_norm_legit_no_training_3 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_14, arg16_1, arg17_1, arg101_1, arg102_1, 0.1, 1e-05);  convolution_14 = arg16_1 = arg17_1 = arg101_1 = arg102_1 = None
            getitem_4: f32[1, 64, 128, 128] = _native_batch_norm_legit_no_training_3[0];  _native_batch_norm_legit_no_training_3 = None
            relu_3: f32[1, 64, 128, 128] = torch.ops.aten.relu.default(getitem_4);  getitem_4 = None
            convolution_15: f32[1, 64, 128, 128] = torch.ops.aten.convolution.default(relu_3, arg18_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  relu_3 = arg18_1 = None
            _native_batch_norm_legit_no_training_4 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_15, arg19_1, arg20_1, arg104_1, arg105_1, 0.1, 1e-05);  convolution_15 = arg19_1 = arg20_1 = arg104_1 = arg105_1 = None
            getitem_5: f32[1, 64, 128, 128] = _native_batch_norm_legit_no_training_4[0];  _native_batch_norm_legit_no_training_4 = None
            add_13: f32[1, 64, 128, 128] = torch.ops.aten.add.Tensor(getitem_5, relu_2);  getitem_5 = relu_2 = None
            relu_4: f32[1, 64, 128, 128] = torch.ops.aten.relu.default(add_13);  add_13 = None
            convolution_16: f32[1, 128, 64, 64] = torch.ops.aten.convolution.default(relu_4, arg21_1, None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1);  arg21_1 = None
            _native_batch_norm_legit_no_training_5 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_16, arg22_1, arg23_1, arg107_1, arg108_1, 0.1, 1e-05);  convolution_16 = arg22_1 = arg23_1 = arg107_1 = arg108_1 = None
            getitem_6: f32[1, 128, 64, 64] = _native_batch_norm_legit_no_training_5[0];  _native_batch_norm_legit_no_training_5 = None
            relu_5: f32[1, 128, 64, 64] = torch.ops.aten.relu.default(getitem_6);  getitem_6 = None
            convolution_17: f32[1, 128, 64, 64] = torch.ops.aten.convolution.default(relu_5, arg24_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  relu_5 = arg24_1 = None
            _native_batch_norm_legit_no_training_6 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_17, arg25_1, arg26_1, arg110_1, arg111_1, 0.1, 1e-05);  convolution_17 = arg25_1 = arg26_1 = arg110_1 = arg111_1 = None
            getitem_7: f32[1, 128, 64, 64] = _native_batch_norm_legit_no_training_6[0];  _native_batch_norm_legit_no_training_6 = None
            convolution_18: f32[1, 128, 64, 64] = torch.ops.aten.convolution.default(relu_4, arg27_1, None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1);  arg27_1 = None
            _native_batch_norm_legit_no_training_7 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_18, arg28_1, arg29_1, arg113_1, arg114_1, 0.1, 1e-05);  convolution_18 = arg28_1 = arg29_1 = arg113_1 = arg114_1 = None
            getitem_8: f32[1, 128, 64, 64] = _native_batch_norm_legit_no_training_7[0];  _native_batch_norm_legit_no_training_7 = None
            add_14: f32[1, 128, 64, 64] = torch.ops.aten.add.Tensor(getitem_7, getitem_8);  getitem_7 = getitem_8 = None
            relu_6: f32[1, 128, 64, 64] = torch.ops.aten.relu.default(add_14);  add_14 = None
            convolution_19: f32[1, 128, 64, 64] = torch.ops.aten.convolution.default(relu_6, arg30_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg30_1 = None
            _native_batch_norm_legit_no_training_8 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_19, arg31_1, arg32_1, arg116_1, arg117_1, 0.1, 1e-05);  convolution_19 = arg31_1 = arg32_1 = arg116_1 = arg117_1 = None
            getitem_9: f32[1, 128, 64, 64] = _native_batch_norm_legit_no_training_8[0];  _native_batch_norm_legit_no_training_8 = None
            relu_7: f32[1, 128, 64, 64] = torch.ops.aten.relu.default(getitem_9);  getitem_9 = None
            convolution_20: f32[1, 128, 64, 64] = torch.ops.aten.convolution.default(relu_7, arg33_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  relu_7 = arg33_1 = None
            _native_batch_norm_legit_no_training_9 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_20, arg34_1, arg35_1, arg119_1, arg120_1, 0.1, 1e-05);  convolution_20 = arg34_1 = arg35_1 = arg119_1 = arg120_1 = None
            getitem_10: f32[1, 128, 64, 64] = _native_batch_norm_legit_no_training_9[0];  _native_batch_norm_legit_no_training_9 = None
            add_15: f32[1, 128, 64, 64] = torch.ops.aten.add.Tensor(getitem_10, relu_6);  getitem_10 = relu_6 = None
            relu_8: f32[1, 128, 64, 64] = torch.ops.aten.relu.default(add_15);  add_15 = None
            convolution_21: f32[1, 256, 32, 32] = torch.ops.aten.convolution.default(relu_8, arg36_1, None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1);  arg36_1 = None
            _native_batch_norm_legit_no_training_10 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_21, arg37_1, arg38_1, arg122_1, arg123_1, 0.1, 1e-05);  convolution_21 = arg37_1 = arg38_1 = arg122_1 = arg123_1 = None
            getitem_11: f32[1, 256, 32, 32] = _native_batch_norm_legit_no_training_10[0];  _native_batch_norm_legit_no_training_10 = None
            relu_9: f32[1, 256, 32, 32] = torch.ops.aten.relu.default(getitem_11);  getitem_11 = None
            convolution_22: f32[1, 256, 32, 32] = torch.ops.aten.convolution.default(relu_9, arg39_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  relu_9 = arg39_1 = None
            _native_batch_norm_legit_no_training_11 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_22, arg40_1, arg41_1, arg125_1, arg126_1, 0.1, 1e-05);  convolution_22 = arg40_1 = arg41_1 = arg125_1 = arg126_1 = None
            getitem_12: f32[1, 256, 32, 32] = _native_batch_norm_legit_no_training_11[0];  _native_batch_norm_legit_no_training_11 = None
            convolution_23: f32[1, 256, 32, 32] = torch.ops.aten.convolution.default(relu_8, arg42_1, None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1);  arg42_1 = None
            _native_batch_norm_legit_no_training_12 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_23, arg43_1, arg44_1, arg128_1, arg129_1, 0.1, 1e-05);  convolution_23 = arg43_1 = arg44_1 = arg128_1 = arg129_1 = None
            getitem_13: f32[1, 256, 32, 32] = _native_batch_norm_legit_no_training_12[0];  _native_batch_norm_legit_no_training_12 = None
            add_16: f32[1, 256, 32, 32] = torch.ops.aten.add.Tensor(getitem_12, getitem_13);  getitem_12 = getitem_13 = None
            relu_10: f32[1, 256, 32, 32] = torch.ops.aten.relu.default(add_16);  add_16 = None
            convolution_24: f32[1, 256, 32, 32] = torch.ops.aten.convolution.default(relu_10, arg45_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg45_1 = None
            _native_batch_norm_legit_no_training_13 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_24, arg46_1, arg47_1, arg131_1, arg132_1, 0.1, 1e-05);  convolution_24 = arg46_1 = arg47_1 = arg131_1 = arg132_1 = None
            getitem_14: f32[1, 256, 32, 32] = _native_batch_norm_legit_no_training_13[0];  _native_batch_norm_legit_no_training_13 = None
            relu_11: f32[1, 256, 32, 32] = torch.ops.aten.relu.default(getitem_14);  getitem_14 = None
            convolution_25: f32[1, 256, 32, 32] = torch.ops.aten.convolution.default(relu_11, arg48_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  relu_11 = arg48_1 = None
            _native_batch_norm_legit_no_training_14 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_25, arg49_1, arg50_1, arg134_1, arg135_1, 0.1, 1e-05);  convolution_25 = arg49_1 = arg50_1 = arg134_1 = arg135_1 = None
            getitem_15: f32[1, 256, 32, 32] = _native_batch_norm_legit_no_training_14[0];  _native_batch_norm_legit_no_training_14 = None
            add_17: f32[1, 256, 32, 32] = torch.ops.aten.add.Tensor(getitem_15, relu_10);  getitem_15 = relu_10 = None
            relu_12: f32[1, 256, 32, 32] = torch.ops.aten.relu.default(add_17);  add_17 = None
            convolution_26: f32[1, 512, 16, 16] = torch.ops.aten.convolution.default(relu_12, arg51_1, None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1);  arg51_1 = None
            _native_batch_norm_legit_no_training_15 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_26, arg52_1, arg53_1, arg137_1, arg138_1, 0.1, 1e-05);  convolution_26 = arg52_1 = arg53_1 = arg137_1 = arg138_1 = None
            getitem_16: f32[1, 512, 16, 16] = _native_batch_norm_legit_no_training_15[0];  _native_batch_norm_legit_no_training_15 = None
            relu_13: f32[1, 512, 16, 16] = torch.ops.aten.relu.default(getitem_16);  getitem_16 = None
            convolution_27: f32[1, 512, 16, 16] = torch.ops.aten.convolution.default(relu_13, arg54_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  relu_13 = arg54_1 = None
            _native_batch_norm_legit_no_training_16 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_27, arg55_1, arg56_1, arg140_1, arg141_1, 0.1, 1e-05);  convolution_27 = arg55_1 = arg56_1 = arg140_1 = arg141_1 = None
            getitem_17: f32[1, 512, 16, 16] = _native_batch_norm_legit_no_training_16[0];  _native_batch_norm_legit_no_training_16 = None
            convolution_28: f32[1, 512, 16, 16] = torch.ops.aten.convolution.default(relu_12, arg57_1, None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1);  arg57_1 = None
            _native_batch_norm_legit_no_training_17 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_28, arg58_1, arg59_1, arg143_1, arg144_1, 0.1, 1e-05);  convolution_28 = arg58_1 = arg59_1 = arg143_1 = arg144_1 = None
            getitem_18: f32[1, 512, 16, 16] = _native_batch_norm_legit_no_training_17[0];  _native_batch_norm_legit_no_training_17 = None
            add_18: f32[1, 512, 16, 16] = torch.ops.aten.add.Tensor(getitem_17, getitem_18);  getitem_17 = getitem_18 = None
            relu_14: f32[1, 512, 16, 16] = torch.ops.aten.relu.default(add_18);  add_18 = None
            convolution_29: f32[1, 512, 16, 16] = torch.ops.aten.convolution.default(relu_14, arg60_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  arg60_1 = None
            _native_batch_norm_legit_no_training_18 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_29, arg61_1, arg62_1, arg146_1, arg147_1, 0.1, 1e-05);  convolution_29 = arg61_1 = arg62_1 = arg146_1 = arg147_1 = None
            getitem_19: f32[1, 512, 16, 16] = _native_batch_norm_legit_no_training_18[0];  _native_batch_norm_legit_no_training_18 = None
            relu_15: f32[1, 512, 16, 16] = torch.ops.aten.relu.default(getitem_19);  getitem_19 = None
            convolution_30: f32[1, 512, 16, 16] = torch.ops.aten.convolution.default(relu_15, arg63_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  relu_15 = arg63_1 = None
            _native_batch_norm_legit_no_training_19 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_30, arg64_1, arg65_1, arg149_1, arg150_1, 0.1, 1e-05);  convolution_30 = arg64_1 = arg65_1 = arg149_1 = arg150_1 = None
            getitem_20: f32[1, 512, 16, 16] = _native_batch_norm_legit_no_training_19[0];  _native_batch_norm_legit_no_training_19 = None
            add_19: f32[1, 512, 16, 16] = torch.ops.aten.add.Tensor(getitem_20, relu_14);  getitem_20 = relu_14 = None
            relu_16: f32[1, 512, 16, 16] = torch.ops.aten.relu.default(add_19);  add_19 = None
            convolution_31: f32[1, 256, 32, 32] = torch.ops.aten.convolution.default(relu_16, arg66_1, arg67_1, [2, 2], [0, 0], [1, 1], True, [0, 0], 1);  relu_16 = arg66_1 = arg67_1 = None
            cat_1: f32[1, 512, 32, 32] = torch.ops.aten.cat.default([relu_12, convolution_31], 1);  relu_12 = convolution_31 = None
            convolution_32: f32[1, 256, 32, 32] = torch.ops.aten.convolution.default(cat_1, arg68_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  cat_1 = arg68_1 = None
            _native_batch_norm_legit_no_training_20 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_32, arg69_1, arg70_1, arg152_1, arg153_1, 0.1, 1e-05);  convolution_32 = arg69_1 = arg70_1 = arg152_1 = arg153_1 = None
            getitem_21: f32[1, 256, 32, 32] = _native_batch_norm_legit_no_training_20[0];  _native_batch_norm_legit_no_training_20 = None
            leaky_relu: f32[1, 256, 32, 32] = torch.ops.aten.leaky_relu.default(getitem_21);  getitem_21 = None
            convolution_33: f32[1, 256, 32, 32] = torch.ops.aten.convolution.default(leaky_relu, arg71_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  leaky_relu = arg71_1 = None
            _native_batch_norm_legit_no_training_21 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_33, arg72_1, arg73_1, arg155_1, arg156_1, 0.1, 1e-05);  convolution_33 = arg72_1 = arg73_1 = arg155_1 = arg156_1 = None
            getitem_22: f32[1, 256, 32, 32] = _native_batch_norm_legit_no_training_21[0];  _native_batch_norm_legit_no_training_21 = None
            leaky_relu_1: f32[1, 256, 32, 32] = torch.ops.aten.leaky_relu.default(getitem_22);  getitem_22 = None
            convolution_34: f32[1, 128, 64, 64] = torch.ops.aten.convolution.default(leaky_relu_1, arg74_1, arg75_1, [2, 2], [0, 0], [1, 1], True, [0, 0], 1);  leaky_relu_1 = arg74_1 = arg75_1 = None
            cat_2: f32[1, 256, 64, 64] = torch.ops.aten.cat.default([relu_8, convolution_34], 1);  relu_8 = convolution_34 = None
            convolution_35: f32[1, 128, 64, 64] = torch.ops.aten.convolution.default(cat_2, arg76_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  cat_2 = arg76_1 = None
            _native_batch_norm_legit_no_training_22 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_35, arg77_1, arg78_1, arg158_1, arg159_1, 0.1, 1e-05);  convolution_35 = arg77_1 = arg78_1 = arg158_1 = arg159_1 = None
            getitem_23: f32[1, 128, 64, 64] = _native_batch_norm_legit_no_training_22[0];  _native_batch_norm_legit_no_training_22 = None
            leaky_relu_2: f32[1, 128, 64, 64] = torch.ops.aten.leaky_relu.default(getitem_23);  getitem_23 = None
            convolution_36: f32[1, 128, 64, 64] = torch.ops.aten.convolution.default(leaky_relu_2, arg79_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  leaky_relu_2 = arg79_1 = None
            _native_batch_norm_legit_no_training_23 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_36, arg80_1, arg81_1, arg161_1, arg162_1, 0.1, 1e-05);  convolution_36 = arg80_1 = arg81_1 = arg161_1 = arg162_1 = None
            getitem_24: f32[1, 128, 64, 64] = _native_batch_norm_legit_no_training_23[0];  _native_batch_norm_legit_no_training_23 = None
            leaky_relu_3: f32[1, 128, 64, 64] = torch.ops.aten.leaky_relu.default(getitem_24);  getitem_24 = None
            convolution_37: f32[1, 64, 128, 128] = torch.ops.aten.convolution.default(leaky_relu_3, arg82_1, arg83_1, [2, 2], [0, 0], [1, 1], True, [0, 0], 1);  leaky_relu_3 = arg82_1 = arg83_1 = None
            cat_3: f32[1, 128, 128, 128] = torch.ops.aten.cat.default([relu_4, convolution_37], 1);  relu_4 = convolution_37 = None
            convolution_38: f32[1, 64, 128, 128] = torch.ops.aten.convolution.default(cat_3, arg84_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  cat_3 = arg84_1 = None
            _native_batch_norm_legit_no_training_24 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_38, arg85_1, arg86_1, arg164_1, arg165_1, 0.1, 1e-05);  convolution_38 = arg85_1 = arg86_1 = arg164_1 = arg165_1 = None
            getitem_25: f32[1, 64, 128, 128] = _native_batch_norm_legit_no_training_24[0];  _native_batch_norm_legit_no_training_24 = None
            leaky_relu_4: f32[1, 64, 128, 128] = torch.ops.aten.leaky_relu.default(getitem_25);  getitem_25 = None
            convolution_39: f32[1, 64, 128, 128] = torch.ops.aten.convolution.default(leaky_relu_4, arg87_1, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  leaky_relu_4 = arg87_1 = None
            _native_batch_norm_legit_no_training_25 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_39, arg88_1, arg89_1, arg167_1, arg168_1, 0.1, 1e-05);  convolution_39 = arg88_1 = arg89_1 = arg167_1 = arg168_1 = None
            getitem_26: f32[1, 64, 128, 128] = _native_batch_norm_legit_no_training_25[0];  _native_batch_norm_legit_no_training_25 = None
            leaky_relu_5: f32[1, 64, 128, 128] = torch.ops.aten.leaky_relu.default(getitem_26);  getitem_26 = None
            convolution_40: f32[1, 134, 128, 128] = torch.ops.aten.convolution.default(leaky_relu_5, arg90_1, arg91_1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1);  leaky_relu_5 = arg90_1 = arg91_1 = None
            split_with_sizes = torch.ops.aten.split_with_sizes.default(convolution_40, [1, 33, 33, 33, 33, 1], 1);  convolution_40 = None
            getitem_27: f32[1, 1, 128, 128] = split_with_sizes[0]
            getitem_28: f32[1, 33, 128, 128] = split_with_sizes[1]
            getitem_29: f32[1, 33, 128, 128] = split_with_sizes[2]
            getitem_30: f32[1, 33, 128, 128] = split_with_sizes[3]
            getitem_31: f32[1, 33, 128, 128] = split_with_sizes[4]
            getitem_32: f32[1, 1, 128, 128] = split_with_sizes[5];  split_with_sizes = None
            sigmoid: f32[1, 1, 128, 128] = torch.ops.aten.sigmoid.default(getitem_27);  getitem_27 = None
            sigmoid_1: f32[1, 1, 128, 128] = torch.ops.aten.sigmoid.default(getitem_32);  getitem_32 = None
            alias: f32[1, 1, 128, 128] = torch.ops.aten.alias.default(sigmoid)
            mul_16: f32[1, 1, 128, 128] = torch.ops.aten.mul.Tensor(sigmoid_1, alias);  sigmoid_1 = alias = None
            _tensor_constant0: i64[2] = self._tensor_constant0
            lift_fresh_copy: i64[2] = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
            clone: i64[2] = torch.ops.aten.clone.default(lift_fresh_copy);  lift_fresh_copy = None
            _tensor_constant1: f64[] = self._tensor_constant1
            lift_fresh_copy_1: f64[] = torch.ops.aten.lift_fresh_copy.default(_tensor_constant1);  _tensor_constant1 = None
            mul_17: f64[2] = torch.ops.aten.mul.Tensor(lift_fresh_copy_1, clone);  lift_fresh_copy_1 = clone = None
            _tensor_constant2: f64[] = self._tensor_constant2
            lift_fresh_copy_2: f64[] = torch.ops.aten.lift_fresh_copy.default(_tensor_constant2);  _tensor_constant2 = None
            div_4: f64[2] = torch.ops.aten.div.Tensor(mul_17, lift_fresh_copy_2);  mul_17 = lift_fresh_copy_2 = None
            _softmax: f32[1, 33, 128, 128] = torch.ops.aten._softmax.default(getitem_28, 1, False)
            _softmax_1: f32[1, 33, 128, 128] = torch.ops.aten._softmax.default(getitem_29, 1, False)
            _softmax_2: f32[1, 33, 128, 128] = torch.ops.aten._softmax.default(getitem_30, 1, False)
            _softmax_3: f32[1, 33, 128, 128] = torch.ops.aten._softmax.default(getitem_31, 1, False)
            arange_2: i64[34] = torch.ops.aten.arange.start_step(0, 34, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
            lt_2: b8[34] = torch.ops.aten.lt.Scalar(arange_2, 17.0)
            _to_copy_7: f32[34] = torch.ops.aten._to_copy.default(arange_2, dtype = torch.float32)
            mul_18: f32[34] = torch.ops.aten.mul.Tensor(_to_copy_7, 0.06060606060606061);  _to_copy_7 = None
            add_20: f32[34] = torch.ops.aten.add.Tensor(mul_18, -1);  mul_18 = None
            sub_11: i64[34] = torch.ops.aten.sub.Tensor(33, arange_2);  arange_2 = None
            _to_copy_8: f32[34] = torch.ops.aten._to_copy.default(sub_11, dtype = torch.float32);  sub_11 = None
            mul_19: f32[34] = torch.ops.aten.mul.Tensor(_to_copy_8, 0.06060606060606061);  _to_copy_8 = None
            sub_12: f32[34] = torch.ops.aten.sub.Tensor(1, mul_19);  mul_19 = None
            where_4: f32[34] = torch.ops.aten.where.self(lt_2, add_20, sub_12);  lt_2 = add_20 = sub_12 = None
            clamp_2: f32[34] = torch.ops.aten.clamp.default(where_4, -1, 1);  where_4 = None
            abs_2: f32[34] = torch.ops.aten.abs.default(clamp_2)
            sub_13: f32[34] = torch.ops.aten.sub.Tensor(2, abs_2);  abs_2 = None
            div_5: f32[34] = torch.ops.aten.div.Tensor(clamp_2, sub_13);  clamp_2 = sub_13 = None
            slice_1: f32[33] = torch.ops.aten.slice.Tensor(div_5, 0, 0, -1)
            slice_2: f32[33] = torch.ops.aten.slice.Tensor(div_5, 0, 1, 9223372036854775807);  div_5 = None
            add_21: f32[33] = torch.ops.aten.add.Tensor(slice_1, slice_2);  slice_1 = slice_2 = None
            div_6: f32[33] = torch.ops.aten.div.Tensor(add_21, 2);  add_21 = None
            view_2: f32[1, 33, 1, 1] = torch.ops.aten.view.default(div_6, [1, -1, 1, 1]);  div_6 = None
            select_3: f64[] = torch.ops.aten.select.int(div_4, 0, 1)
            mul_20: f32[1, 33, 1, 1] = torch.ops.aten.mul.Tensor(view_2, select_3);  view_2 = select_3 = None
            arange_3: i64[34] = torch.ops.aten.arange.start_step(0, 34, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
            lt_3: b8[34] = torch.ops.aten.lt.Scalar(arange_3, 17.0)
            _to_copy_9: f32[34] = torch.ops.aten._to_copy.default(arange_3, dtype = torch.float32)
            mul_21: f32[34] = torch.ops.aten.mul.Tensor(_to_copy_9, 0.06060606060606061);  _to_copy_9 = None
            add_22: f32[34] = torch.ops.aten.add.Tensor(mul_21, -1);  mul_21 = None
            sub_14: i64[34] = torch.ops.aten.sub.Tensor(33, arange_3);  arange_3 = None
            _to_copy_10: f32[34] = torch.ops.aten._to_copy.default(sub_14, dtype = torch.float32);  sub_14 = None
            mul_22: f32[34] = torch.ops.aten.mul.Tensor(_to_copy_10, 0.06060606060606061);  _to_copy_10 = None
            sub_15: f32[34] = torch.ops.aten.sub.Tensor(1, mul_22);  mul_22 = None
            where_5: f32[34] = torch.ops.aten.where.self(lt_3, add_22, sub_15);  lt_3 = add_22 = sub_15 = None
            clamp_3: f32[34] = torch.ops.aten.clamp.default(where_5, -1, 1);  where_5 = None
            abs_3: f32[34] = torch.ops.aten.abs.default(clamp_3)
            sub_16: f32[34] = torch.ops.aten.sub.Tensor(2, abs_3);  abs_3 = None
            div_7: f32[34] = torch.ops.aten.div.Tensor(clamp_3, sub_16);  clamp_3 = sub_16 = None
            slice_3: f32[33] = torch.ops.aten.slice.Tensor(div_7, 0, 0, -1)
            slice_4: f32[33] = torch.ops.aten.slice.Tensor(div_7, 0, 1, 9223372036854775807);  div_7 = None
            add_23: f32[33] = torch.ops.aten.add.Tensor(slice_3, slice_4);  slice_3 = slice_4 = None
            div_8: f32[33] = torch.ops.aten.div.Tensor(add_23, 2);  add_23 = None
            view_3: f32[1, 33, 1, 1] = torch.ops.aten.view.default(div_8, [1, -1, 1, 1]);  div_8 = None
            select_4: f64[] = torch.ops.aten.select.int(div_4, 0, 0);  div_4 = None
            mul_23: f32[1, 33, 1, 1] = torch.ops.aten.mul.Tensor(view_3, select_4);  view_3 = select_4 = None
            mul_24: f32[1, 33, 128, 128] = torch.ops.aten.mul.Tensor(_softmax, mul_20)
            sum_1: f32[1, 1, 128, 128] = torch.ops.aten.sum.dim_IntList(mul_24, [1], True);  mul_24 = None
            mul_25: f32[1, 33, 128, 128] = torch.ops.aten.mul.Tensor(_softmax_1, mul_23)
            sum_2: f32[1, 1, 128, 128] = torch.ops.aten.sum.dim_IntList(mul_25, [1], True);  mul_25 = None
            mul_26: f32[1, 33, 128, 128] = torch.ops.aten.mul.Tensor(_softmax_2, mul_20);  _softmax_2 = None
            sum_3: f32[1, 1, 128, 128] = torch.ops.aten.sum.dim_IntList(mul_26, [1], True);  mul_26 = None
            mul_27: f32[1, 33, 128, 128] = torch.ops.aten.mul.Tensor(_softmax_3, mul_23);  _softmax_3 = None
            sum_4: f32[1, 1, 128, 128] = torch.ops.aten.sum.dim_IntList(mul_27, [1], True);  mul_27 = None
            alias_1: f32[1, 1, 128, 128] = torch.ops.aten.alias.default(sum_2)
            mul_28: f32[1, 1, 128, 128] = torch.ops.aten.mul.Tensor(sum_3, alias_1);  alias_1 = None
            alias_2: f32[1, 1, 128, 128] = torch.ops.aten.alias.default(sum_1)
            mul_29: f32[1, 1, 128, 128] = torch.ops.aten.mul.Tensor(sum_4, alias_2);  alias_2 = None
            sub_17: f32[1, 1, 128, 128] = torch.ops.aten.sub.Tensor(mul_28, mul_29);  mul_28 = mul_29 = None
            alias_3: f32[1, 1, 128, 128] = torch.ops.aten.alias.default(sum_1)
            mul_30: f32[1, 1, 128, 128] = torch.ops.aten.mul.Tensor(sum_3, alias_3);  alias_3 = None
            alias_4: f32[1, 1, 128, 128] = torch.ops.aten.alias.default(sum_2)
            mul_31: f32[1, 1, 128, 128] = torch.ops.aten.mul.Tensor(sum_4, alias_4);  alias_4 = None
            add_24: f32[1, 1, 128, 128] = torch.ops.aten.add.Tensor(mul_30, mul_31);  mul_30 = mul_31 = None
            mul_32: f32[1, 1, 128, 128] = torch.ops.aten.mul.Tensor(sub_17, mul_16);  sub_17 = None
            mean_1: f32[1] = torch.ops.aten.mean.dim(mul_32, [1, 2, 3]);  mul_32 = None
            mul_33: f32[1, 1, 128, 128] = torch.ops.aten.mul.Tensor(add_24, mul_16);  add_24 = None
            mean_2: f32[1] = torch.ops.aten.mean.dim(mul_33, [1, 2, 3]);  mul_33 = None
            atan2: f32[1] = torch.ops.aten.atan2.default(mean_1, mean_2);  mean_1 = mean_2 = None
            mul_34: f32[1] = torch.ops.aten.mul.Tensor(atan2, 180);  atan2 = None
            div_9: f32[1] = torch.ops.aten.div.Tensor(mul_34, 3.141592653589793);  mul_34 = None
            arange_4: i64[128] = torch.ops.aten.arange.start_step(0, 128, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
            lt_4: b8[128] = torch.ops.aten.lt.Scalar(arange_4, 64.0)
            _to_copy_11: f32[128] = torch.ops.aten._to_copy.default(arange_4, dtype = torch.float32)
            mul_35: f32[128] = torch.ops.aten.mul.Tensor(_to_copy_11, 0.015748031496062992);  _to_copy_11 = None
            add_25: f32[128] = torch.ops.aten.add.Tensor(mul_35, -1);  mul_35 = None
            sub_18: i64[128] = torch.ops.aten.sub.Tensor(127, arange_4);  arange_4 = None
            _to_copy_12: f32[128] = torch.ops.aten._to_copy.default(sub_18, dtype = torch.float32);  sub_18 = None
            mul_36: f32[128] = torch.ops.aten.mul.Tensor(_to_copy_12, 0.015748031496062992);  _to_copy_12 = None
            sub_19: f32[128] = torch.ops.aten.sub.Tensor(1, mul_36);  mul_36 = None
            where_6: f32[128] = torch.ops.aten.where.self(lt_4, add_25, sub_19);  lt_4 = add_25 = sub_19 = None
            arange_5: i64[128] = torch.ops.aten.arange.start_step(0, 128, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
            lt_5: b8[128] = torch.ops.aten.lt.Scalar(arange_5, 64.0)
            _to_copy_13: f32[128] = torch.ops.aten._to_copy.default(arange_5, dtype = torch.float32)
            mul_37: f32[128] = torch.ops.aten.mul.Tensor(_to_copy_13, 0.015748031496062992);  _to_copy_13 = None
            add_26: f32[128] = torch.ops.aten.add.Tensor(mul_37, -1);  mul_37 = None
            sub_20: i64[128] = torch.ops.aten.sub.Tensor(127, arange_5);  arange_5 = None
            _to_copy_14: f32[128] = torch.ops.aten._to_copy.default(sub_20, dtype = torch.float32);  sub_20 = None
            mul_38: f32[128] = torch.ops.aten.mul.Tensor(_to_copy_14, 0.015748031496062992);  _to_copy_14 = None
            sub_21: f32[128] = torch.ops.aten.sub.Tensor(1, mul_38);  mul_38 = None
            where_7: f32[128] = torch.ops.aten.where.self(lt_5, add_26, sub_21);  lt_5 = add_26 = sub_21 = None
            view_4: f32[128, 1] = torch.ops.aten.view.default(where_6, [-1, 1]);  where_6 = None
            expand_2: f32[128, 128] = torch.ops.aten.expand.default(view_4, [128, 128]);  view_4 = None
            view_5: f32[1, 128] = torch.ops.aten.view.default(where_7, [1, -1]);  where_7 = None
            expand_3: f32[128, 128] = torch.ops.aten.expand.default(view_5, [128, 128]);  view_5 = None
            add_27: f32[128, 128] = torch.ops.aten.add.Tensor(expand_3, 1);  expand_3 = None
            div_10: f32[128, 128] = torch.ops.aten.div.Tensor(add_27, 2);  add_27 = None
            mul_39: f32[128, 128] = torch.ops.aten.mul.Tensor(div_10, 511);  div_10 = None
            add_28: f32[128, 128] = torch.ops.aten.add.Tensor(expand_2, 1);  expand_2 = None
            div_11: f32[128, 128] = torch.ops.aten.div.Tensor(add_28, 2);  add_28 = None
            mul_40: f32[128, 128] = torch.ops.aten.mul.Tensor(div_11, 511);  div_11 = None
            arange_6: i64[34] = torch.ops.aten.arange.start_step(0, 34, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
            lt_6: b8[34] = torch.ops.aten.lt.Scalar(arange_6, 17.0)
            _to_copy_15: f32[34] = torch.ops.aten._to_copy.default(arange_6, dtype = torch.float32)
            mul_41: f32[34] = torch.ops.aten.mul.Tensor(_to_copy_15, 0.06060606060606061);  _to_copy_15 = None
            add_29: f32[34] = torch.ops.aten.add.Tensor(mul_41, -1);  mul_41 = None
            sub_22: i64[34] = torch.ops.aten.sub.Tensor(33, arange_6);  arange_6 = None
            _to_copy_16: f32[34] = torch.ops.aten._to_copy.default(sub_22, dtype = torch.float32);  sub_22 = None
            mul_42: f32[34] = torch.ops.aten.mul.Tensor(_to_copy_16, 0.06060606060606061);  _to_copy_16 = None
            sub_23: f32[34] = torch.ops.aten.sub.Tensor(1, mul_42);  mul_42 = None
            where_8: f32[34] = torch.ops.aten.where.self(lt_6, add_29, sub_23);  lt_6 = add_29 = sub_23 = None
            clamp_4: f32[34] = torch.ops.aten.clamp.default(where_8, -1, 1);  where_8 = None
            abs_4: f32[34] = torch.ops.aten.abs.default(clamp_4)
            sub_24: f32[34] = torch.ops.aten.sub.Tensor(2, abs_4);  abs_4 = None
            div_12: f32[34] = torch.ops.aten.div.Tensor(clamp_4, sub_24);  clamp_4 = sub_24 = None
            slice_5: f32[33] = torch.ops.aten.slice.Tensor(div_12, 0, 1, 9223372036854775807)
            slice_6: f32[33] = torch.ops.aten.slice.Tensor(div_12, 0, 0, -1);  div_12 = None
            sub_25: f32[33] = torch.ops.aten.sub.Tensor(slice_5, slice_6);  slice_5 = slice_6 = None
            view_6: f32[1, 33, 1, 1] = torch.ops.aten.view.default(sub_25, [1, -1, 1, 1]);  sub_25 = None
            arange_7: i64[34] = torch.ops.aten.arange.start_step(0, 34, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
            lt_7: b8[34] = torch.ops.aten.lt.Scalar(arange_7, 17.0)
            _to_copy_17: f32[34] = torch.ops.aten._to_copy.default(arange_7, dtype = torch.float32)
            mul_43: f32[34] = torch.ops.aten.mul.Tensor(_to_copy_17, 0.06060606060606061);  _to_copy_17 = None
            add_30: f32[34] = torch.ops.aten.add.Tensor(mul_43, -1);  mul_43 = None
            sub_26: i64[34] = torch.ops.aten.sub.Tensor(33, arange_7);  arange_7 = None
            _to_copy_18: f32[34] = torch.ops.aten._to_copy.default(sub_26, dtype = torch.float32);  sub_26 = None
            mul_44: f32[34] = torch.ops.aten.mul.Tensor(_to_copy_18, 0.06060606060606061);  _to_copy_18 = None
            sub_27: f32[34] = torch.ops.aten.sub.Tensor(1, mul_44);  mul_44 = None
            where_9: f32[34] = torch.ops.aten.where.self(lt_7, add_30, sub_27);  lt_7 = add_30 = sub_27 = None
            clamp_5: f32[34] = torch.ops.aten.clamp.default(where_9, -1, 1);  where_9 = None
            abs_5: f32[34] = torch.ops.aten.abs.default(clamp_5)
            sub_28: f32[34] = torch.ops.aten.sub.Tensor(2, abs_5);  abs_5 = None
            div_13: f32[34] = torch.ops.aten.div.Tensor(clamp_5, sub_28);  clamp_5 = sub_28 = None
            slice_7: f32[33] = torch.ops.aten.slice.Tensor(div_13, 0, 1, 9223372036854775807)
            slice_8: f32[33] = torch.ops.aten.slice.Tensor(div_13, 0, 0, -1);  div_13 = None
            sub_29: f32[33] = torch.ops.aten.sub.Tensor(slice_7, slice_8);  slice_7 = slice_8 = None
            view_7: f32[1, 33, 1, 1] = torch.ops.aten.view.default(sub_29, [1, -1, 1, 1]);  sub_29 = None
            clamp_min_2: f32[1, 33, 1, 1] = torch.ops.aten.clamp_min.default(view_6, 1e-06);  view_6 = None
            div_14: f32[1, 33, 128, 128] = torch.ops.aten.div.Tensor(_softmax, clamp_min_2);  _softmax = clamp_min_2 = None
            clamp_min_3: f32[1, 33, 1, 1] = torch.ops.aten.clamp_min.default(view_7, 1e-06);  view_7 = None
            div_15: f32[1, 33, 128, 128] = torch.ops.aten.div.Tensor(_softmax_1, clamp_min_3);  _softmax_1 = clamp_min_3 = None
            unsqueeze_2: f32[1, 128, 128] = torch.ops.aten.unsqueeze.default(mul_39, 0);  mul_39 = None
            unsqueeze_3: f32[1, 1, 128, 128] = torch.ops.aten.unsqueeze.default(unsqueeze_2, 1);  unsqueeze_2 = None
            add_31: f32[1, 33, 128, 128] = torch.ops.aten.add.Tensor(mul_20, unsqueeze_3);  mul_20 = unsqueeze_3 = None
            unsqueeze_4: f32[1, 128, 128] = torch.ops.aten.unsqueeze.default(mul_40, 0);  mul_40 = None
            unsqueeze_5: f32[1, 1, 128, 128] = torch.ops.aten.unsqueeze.default(unsqueeze_4, 1);  unsqueeze_4 = None
            add_32: f32[1, 33, 128, 128] = torch.ops.aten.add.Tensor(mul_23, unsqueeze_5);  mul_23 = unsqueeze_5 = None
            mul_45: f32[1, 33, 128, 128] = torch.ops.aten.mul.Tensor(add_31, div_14);  add_31 = None
            mul_46: f32[1, 33, 128, 128] = torch.ops.aten.mul.Tensor(mul_45, sigmoid);  mul_45 = None
            sum_5: f32[1] = torch.ops.aten.sum.dim_IntList(mul_46, [1, 2, 3]);  mul_46 = None
            mul_47: f32[1, 33, 128, 128] = torch.ops.aten.mul.Tensor(div_14, sigmoid);  div_14 = None
            sum_6: f32[1] = torch.ops.aten.sum.dim_IntList(mul_47, [1, 2, 3]);  mul_47 = None
            clamp_min_4: f32[1] = torch.ops.aten.clamp_min.default(sum_6, 1e-06);  sum_6 = None
            div_16: f32[1] = torch.ops.aten.div.Tensor(sum_5, clamp_min_4);  sum_5 = clamp_min_4 = None
            mul_48: f32[1, 33, 128, 128] = torch.ops.aten.mul.Tensor(add_32, div_15);  add_32 = None
            mul_49: f32[1, 33, 128, 128] = torch.ops.aten.mul.Tensor(mul_48, sigmoid);  mul_48 = None
            sum_7: f32[1] = torch.ops.aten.sum.dim_IntList(mul_49, [1, 2, 3]);  mul_49 = None
            mul_50: f32[1, 33, 128, 128] = torch.ops.aten.mul.Tensor(div_15, sigmoid);  div_15 = None
            sum_8: f32[1] = torch.ops.aten.sum.dim_IntList(mul_50, [1, 2, 3]);  mul_50 = None
            clamp_min_5: f32[1] = torch.ops.aten.clamp_min.default(sum_8, 1e-06);  sum_8 = None
            div_17: f32[1] = torch.ops.aten.div.Tensor(sum_7, clamp_min_5);  sum_7 = clamp_min_5 = None
            slice_9: f32[1] = torch.ops.aten.slice.Tensor(div_16, 0, 0, 9223372036854775807);  div_16 = None
            unsqueeze_6: f32[1, 1] = torch.ops.aten.unsqueeze.default(slice_9, 1);  slice_9 = None
            slice_10: f32[1] = torch.ops.aten.slice.Tensor(div_17, 0, 0, 9223372036854775807);  div_17 = None
            unsqueeze_7: f32[1, 1] = torch.ops.aten.unsqueeze.default(slice_10, 1);  slice_10 = None
            slice_11: f32[1] = torch.ops.aten.slice.Tensor(div_9, 0, 0, 9223372036854775807);  div_9 = None
            unsqueeze_8: f32[1, 1] = torch.ops.aten.unsqueeze.default(slice_11, 1);  slice_11 = None
            cat_4: f32[1, 3] = torch.ops.aten.cat.default([unsqueeze_6, unsqueeze_7, unsqueeze_8], 1);  unsqueeze_6 = unsqueeze_7 = unsqueeze_8 = None
            return (getitem_28, getitem_29, getitem_30, getitem_31, cat_4, sigmoid, clamp_1, sum_3, sum_4, mul_16, sum_1, sum_2)

Graph Signature: **Removed this to spare some lines**
Symbol to range: {}

Unfortunately, this architecture results in the following error: torch.onnx._internal.diagnostics.infra.context.RuntimeErrorWithDiagnostic: Unsupported FX nodes: {'call_function': ['aten.roll.default', 'aten.var.correction']}.

Is there any guideline on how to solve this problem and implement support for these two operations? Thank you, and sorry for the long post.

justinchuby commented 10 months ago

Thanks for reporting this!

The decompositions for all aten operators are under https://github.com/microsoft/onnxscript/tree/main/onnxscript/function_libs/torch_lib/ops

For aten.roll, we need a variation of https://github.com/microsoft/onnxscript/blob/b7f215ea130e455bd0cfe999d551389bc0718489/onnxscript/function_libs/torch_lib/ops/core.py#L6897 to handle complex inputs.

For aten.var.correction we need to implement https://github.com/microsoft/onnxscript/blob/b7f215ea130e455bd0cfe999d551389bc0718489/onnxscript/function_libs/torch_lib/ops/core.py#L8261

justinchuby commented 10 months ago

Please follow this guide https://github.com/microsoft/onnxscript/wiki/TorchLib-function-authoring-guide if you would like to contribute. Thank you!

luisfmnunes commented 9 months ago

@justinchuby, thank you for all the material you provided.

The implementation of aten::roll for complex types seems to work: all the tests pass. I basically copied the structure of the real-valued implementation, applied the roll to each channel (real and imaginary) individually, and then concatenated the results.

onnxscript/tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyEagerCPU::test_complex_output_match_opinfo__roll_cpu_complex64 PASSED                                                                                                                                                                  [  7%]
onnxscript/tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyEagerCPU::test_output_match_opinfo__roll_cpu_bool PASSED                                                                                                                                                                               [ 14%]
onnxscript/tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyEagerCPU::test_output_match_opinfo__roll_cpu_int64 PASSED                                                                                                                                                                              [ 21%]
onnxscript/tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyEagerCPU::test_output_match_opinfo__roll_cpu_int32 PASSED                                                                                                                                                                              [ 28%]
onnxscript/tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyEagerCPU::test_output_match_opinfo__roll_cpu_float16 PASSED                                                                                                                                                                            [ 35%]
onnxscript/tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyEagerCPU::test_output_match_opinfo__roll_cpu_float32 PASSED                                                                                                                                                                            [ 42%]
onnxscript/tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyFullGraphCPU::test_complex_output_match_opinfo__roll_cpu_complex64 PASSED                                                                                                                                                              [ 50%]
onnxscript/tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__roll_cpu_int64 PASSED                                                                                                                                                                          [ 57%]
onnxscript/tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__roll_cpu_bool PASSED                                                                                                                                                                           [ 64%]
onnxscript/tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__roll_cpu_int32 PASSED                                                                                                                                                                          [ 71%]
onnxscript/tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__roll_cpu_float16 PASSED                                                                                                                                                                        [ 78%]
onnxscript/tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__roll_cpu_float32 PASSED                                                                                                                                                                        [ 85%]
onnxscript/tests/function_libs/torch_lib/ops_test.py::TestFunctionValidity::test_function_has_op_schema_315_aten_roll PASSED                                                                                                                                                                                      [ 92%]
onnxscript/tests/function_libs/torch_lib/ops_test.py::TestFunctionValidity::test_function_has_op_schema_316_aten_roll_complex PASSED  
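
The per-channel approach described above can be illustrated without any tensor library. The following is a minimal pure-Python sketch, not the actual torch_lib implementation: it assumes a 1-D complex input stored as a list of [real, imag] pairs (mirroring the trailing dimension of size 2 in a real-valued view of a complex tensor), and the helper names roll_1d, roll_complex_1d and as_pairs are hypothetical.

```python
def roll_1d(values, shift):
    """Cyclically shift a list to the right by `shift` positions."""
    n = len(values)
    if n == 0:
        return []
    k = shift % n  # normalise negative and oversized shifts
    return values[n - k:] + values[:n - k]


def roll_complex_1d(pairs, shift):
    """Roll a complex sequence stored as [real, imag] pairs.

    Each channel is rolled on its own and the two channels are then
    stacked back together along the trailing (size-2) axis.
    """
    real = [p[0] for p in pairs]
    imag = [p[1] for p in pairs]
    rolled_real = roll_1d(real, shift)
    rolled_imag = roll_1d(imag, shift)
    return [[r, i] for r, i in zip(rolled_real, rolled_imag)]


def as_pairs(zs):
    """Convert Python complex numbers to [real, imag] pairs."""
    return [[z.real, z.imag] for z in zs]


# Rolling the channels separately matches rolling the complex values directly.
signal = [1 + 2j, 3 + 4j, 5 + 6j, 7 + 8j]
expected = as_pairs(roll_1d(signal, 1))
actual = roll_complex_1d(as_pairs(signal), 1)
assert actual == expected
```

Because a roll only permutes positions along the rolled axis and never mixes the real and imaginary parts, treating the two channels independently gives the same result as rolling the complex values themselves.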

I can't say the same about aten::var, though. I basically used the same logic as aten::var_mean and got the same skipped and xfail tests, such as onnxscript/tests/function_libs/torch_lib/ops_test.py::TestOutputConsistencyFullGraphCPU::test_output_match_opinfo__var_unbiased_cpu_float32 SKIPPED (Skip: fixme: Inferred shape and existing shape differ in rank). It seems that the new overload of torch.var is always being called, ignoring the old (unbiased) overload. This is strange behaviour: the following test seems to have called the new overload, passing the True intended for unbiased as dims (where it is implicitly converted to 1), and for some reason correction receives a value of 5 (a coincidence with the input shape?).

[Screenshot: test]

Using NumPy with the same input, I obtained the same Actual Output values when using the NumPy equivalent of a correction of 5.0.
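
For reference, the relationship between the old unbiased flag and the newer correction argument can be sketched in plain Python. This is only an illustration of the variance formula, not the torch_lib code, and the function name var_with_correction is hypothetical. With N elements, the sum of squared deviations is divided by max(0, N - correction), so unbiased=True corresponds to correction=1 and unbiased=False to correction=0; NumPy exposes the same adjustment through its ddof argument.

```python
import math
import statistics


def var_with_correction(values, correction=1):
    """Variance of a flat list with a configurable correction term.

    Divides the sum of squared deviations by max(0, N - correction).
    """
    n = len(values)
    mean = sum(values) / n
    squared_dev = sum((x - mean) ** 2 for x in values)
    denom = max(0, n - correction)
    if denom == 0:
        # Follow IEEE float division: 0/0 is nan, x/0 is inf for x > 0.
        return math.nan if squared_dev == 0 else math.inf
    return squared_dev / denom


data = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]

# correction=0 is the population variance (unbiased=False).
assert math.isclose(var_with_correction(data, 0), statistics.pvariance(data))
# correction=1 is the sample variance (unbiased=True).
assert math.isclose(var_with_correction(data, 1), statistics.variance(data))
# A larger correction shrinks the divisor further and inflates the result.
assert var_with_correction(data, 5) > var_with_correction(data, 1)
```

A correction that is accidentally much larger than 1 therefore shows up as outputs that are consistently scaled up relative to the expected values, which is one way to spot an argument being bound to the wrong parameter.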

The conversion to ONNX now succeeds with the following script:


    state_dict = torch.load(checkpoint, map_location=torch.device('cpu'))    
    # model.load_state_dict(state_dict["model"])
    model.eval()

    logger.info(f"Exporting model to {onnx_model}")
    # export(
    #     model,
    #     torch.randn(1, 1, 512, 512),
    #     onnx_model,
    #     export_params=True,
    #     do_constant_folding=True,
    #     opset_version=18,
    #     input_names=["image"],
    #     output_names=["center", "grid", "pose_2d", "seg", "img_sup", "seg_sup"],
    #     dynamic_axes={
    #         'image':{0: "batch_size", 2: "height", 3: "width"},
    #         'center': {0: "batch_size"},
    #         "grid": {0: "batch_size"},
    #         "pose_2d": {0: "batch_size"},
    #         "seg": {0: "batch_size"},
    #         "img_sup": {0: "batch_size"},
    #         "seg_sup": {0: "batch_size"}
    #     }
    # )

    # print(torch.export.export(model, (torch.randn(1, 1, 512, 512, dtype=torch.float32),)))
    # TorchDynamo exports correctly but there are still unsupported onnxscript ops.
    # Waiting Issue Response to solve this.

    onnx_prog = torch.onnx.dynamo_export(model, torch.randn(1, 1, 512, 512, dtype=torch.float32))
    onnx_prog.save(onnx_model.as_posix(), model_state_dict=state_dict["model"])

    onnx_model = onnx.load(onnx_model)
    onnx.checker.check_model(onnx_model)
    print(onnx.helper.printable_graph(onnx_model.graph))

However, I get the following warnings:

/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/onnx/_internal/exporter.py:130: UserWarning: torch.onnx.dynamo_export only implements opset version 18 for now. If you need to use a different opset version, please register them with register_custom_op.
  warnings.warn(
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/onnx/_internal/fx/passes/readability.py:53: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
  new_node = self.module.graph.get_attr(normalized_name)
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/fx/graph.py:1377: UserWarning: Node layer0_1_running_mean target layer0/1/running_mean layer0/1/running_mean of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/fx/graph.py:1377: UserWarning: Node layer0_1_running_var target layer0/1/running_var layer0/1/running_var of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/fx/graph.py:1377: UserWarning: Node layer1_1_0_bn1_running_mean target layer1/1/0/bn1/running_mean layer1/1/0/bn1/running_mean of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/fx/graph.py:1377: UserWarning: Node layer1_1_0_bn1_running_var target layer1/1/0/bn1/running_var layer1/1/0/bn1/running_var of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/fx/graph.py:1377: UserWarning: Node layer1_1_0_bn2_running_mean target layer1/1/0/bn2/running_mean layer1/1/0/bn2/running_mean of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/fx/graph.py:1377: UserWarning: Node layer1_1_0_bn2_running_var target layer1/1/0/bn2/running_var layer1/1/0/bn2/running_var of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/fx/graph.py:1377: UserWarning: Node layer1_1_1_bn1_running_mean target layer1/1/1/bn1/running_mean layer1/1/1/bn1/running_mean of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/fx/graph.py:1377: UserWarning: Node layer1_1_1_bn1_running_var target layer1/1/1/bn1/running_var layer1/1/1/bn1/running_var of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/fx/graph.py:1377: UserWarning: Node layer1_1_1_bn2_running_mean target layer1/1/1/bn2/running_mean layer1/1/1/bn2/running_mean of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/fx/graph.py:1377: UserWarning: Node layer1_1_1_bn2_running_var target layer1/1/1/bn2/running_var layer1/1/1/bn2/running_var of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/fx/graph.py:1377: UserWarning: Node layer2_0_bn1_running_mean target layer2/0/bn1/running_mean layer2/0/bn1/running_mean of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/fx/graph.py:1377: UserWarning: Node layer2_0_bn1_running_var target layer2/0/bn1/running_var layer2/0/bn1/running_var of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/fx/graph.py:1377: UserWarning: Node layer2_0_bn2_running_mean target layer2/0/bn2/running_mean layer2/0/bn2/running_mean of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '

# Many repeated Warnings with the same signature.

/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/fx/graph.py:1377: UserWarning: Node decoder_layer2_conv_4_running_var target decoder/layer2/conv/4/running_var decoder/layer2/conv/4/running_var of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '
/home/griaule/Melo/onnxscript/onnxscript/function_libs/torch_lib/graph_building.py:971: UserWarning: ONNX model is invalid: [ShapeInferenceError] (op_type:Add, node name: Add_56): [TypeInferenceError] Inferred elem type differs from existing elem type: (7) vs (1)
  warnings.warn(f"ONNX model is invalid: {e}", stacklevel=1)

and when I load the model with onnx.load and use onnx.checker.check_model, it raises the following exception:

Traceback (most recent call last):
  File "/home/griaule/Fingerprint-2DPose-Dense-Voting/model2onnx.py", line 88, in <module>
    main(parse_args())
  File "/home/griaule/Fingerprint-2DPose-Dense-Voting/model2onnx.py", line 75, in main
    onnx.checker.check_model(onnx_model)
  File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnx/checker.py", line 148, in check_model
    C.check_model(protobuf_string, full_check, skip_opset_compatibility_check)
onnx.onnx_cpp2py_export.checker.ValidationError: preprocess_tv.img_grad.weight_x initializer name is not unique
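
One way to see which initializers trip this check is to count the names. The sketch below is illustrative only: `duplicate_names` is a hypothetical helper, and the example list simply reuses a few names from the logs in this thread. With a loaded model, the names could be collected with `[t.name for t in model.graph.initializer]`.

```python
from collections import Counter


def duplicate_names(names):
    """Return the sorted names that occur more than once (hypothetical helper)."""
    counts = Counter(names)
    return sorted(name for name, n in counts.items() if n > 1)


# Illustrative input, not the real initializer list of the exported model.
example = [
    "preprocess_tv.img_grad.weight_x",
    "preprocess_tv.img_grad.weight_y",
    "preprocess_tv.img_grad.weight_x",
    "pixels_out.bias",
]
print(duplicate_names(example))  # ['preprocess_tv.img_grad.weight_x']
```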

Yet, when I try to load it as an inference model for onnxruntime, I get the following exception:

>>> import onnxruntime as ort
>>> ort.InferenceSession("out/20231119_175136/best.onnx")
2023-11-27 17:05:35.103468663 [W:onnxruntime:, graph.cc:1256 Graph] Duplicate initializer (dense, sparse or ConstantNode): 'preprocess_tv.img_grad.weight_x' the model will use the latest encountered initializer. Please, fix your model.
2023-11-27 17:05:35.103490273 [W:onnxruntime:, graph.cc:1283 Graph] Initializer preprocess_tv.img_grad.weight_x appears in graph inputs and will not be treated as constant value/weight. This may prevent some of the graph optimizations, like const folding. Move it out of graph inputs if there is no need to override it, by either re-generating the model with latest exporter/converter or with the tool onnxruntime/tools/python/remove_initializer_from_input.py.
2023-11-27 17:05:35.103494888 [W:onnxruntime:, graph.cc:1256 Graph] Duplicate initializer (dense, sparse or ConstantNode): 'preprocess_tv.img_grad.weight_y' the model will use the latest encountered initializer. Please, fix your model.
2023-11-27 17:05:35.103498240 [W:onnxruntime:, graph.cc:1283 Graph] Initializer preprocess_tv.img_grad.weight_y appears in graph inputs and will not be treated as constant value/weight. This may prevent some of the graph optimizations, like const folding. Move it out of graph inputs if there is no need to override it, by either re-generating the model with latest exporter/converter or with the tool onnxruntime/tools/python/remove_initializer_from_input.py.
2023-11-27 17:05:35.103501922 [W:onnxruntime:, graph.cc:1256 Graph] Duplicate initializer (dense, sparse or ConstantNode): 'input_layer.1.weight_avg' the model will use the latest encountered initializer. Please, fix your model.
2023-11-27 17:05:35.103505156 [W:onnxruntime:, graph.cc:1283 Graph] Initializer input_layer.1.weight_avg appears in graph inputs and will not be treated as constant value/weight. This may prevent some of the graph optimizations, like const folding. Move it out of graph inputs if there is no need to override it, by either re-generating the model with latest exporter/converter or with the tool onnxruntime/tools/python/remove_initializer_from_input.py.
2023-11-27 17:05:35.103508790 [W:onnxruntime:, graph.cc:1256 Graph] Duplicate initializer (dense, sparse or ConstantNode): 'input_layer.1.conv_grad.weight_x' the model will use the latest encountered initializer. Please, fix your model.
2023-11-27 17:05:35.103511971 [W:onnxruntime:, graph.cc:1283 Graph] Initializer input_layer.1.conv_grad.weight_x appears in graph inputs and will not be treated as constant value/weight. This may prevent some of the graph optimizations, like const folding. Move it out of graph inputs if there is no need to override it, by either re-generating the model with latest exporter/converter or with the tool onnxruntime/tools/python/remove_initializer_from_input.py.
2023-11-27 17:05:35.103515655 [W:onnxruntime:, graph.cc:1256 Graph] Duplicate initializer (dense, sparse or ConstantNode): 'input_layer.1.conv_grad.weight_y' the model will use the latest encountered initializer. Please, fix your model.
2023-11-27 17:05:35.103518761 [W:onnxruntime:, graph.cc:1283 Graph] Initializer input_layer.1.conv_grad.weight_y appears in graph inputs and will not be treated as constant value/weight. This may prevent some of the graph optimizations, like const folding. Move it out of graph inputs if there is no need to override it, by either re-generating the model with latest exporter/converter or with the tool onnxruntime/tools/python/remove_initializer_from_input.py.
2023-11-27 17:05:35.103522334 [W:onnxruntime:, graph.cc:1256 Graph] Duplicate initializer (dense, sparse or ConstantNode): 'input_layer.1.conv_gaussian.gkern2d' the model will use the latest encountered initializer. Please, fix your model.
2023-11-27 17:05:35.103525487 [W:onnxruntime:, graph.cc:1283 Graph] Initializer input_layer.1.conv_gaussian.gkern2d appears in graph inputs and will not be treated as constant value/weight. This may prevent some of the graph optimizations, like const folding. Move it out of graph inputs if there is no need to override it, by either re-generating the model with latest exporter/converter or with the tool onnxruntime/tools/python/remove_initializer_from_input.py.

# Many Repeated warnings for different Parameters

2023-11-27 17:05:35.105126349 [W:onnxruntime:, graph.cc:1283 Graph] Initializer pixels_out.weight appears in graph inputs and will not be treated as constant value/weight. This may prevent some of the graph optimizations, like const folding. Move it out of graph inputs if there is no need to override it, by either re-generating the model with latest exporter/converter or with the tool onnxruntime/tools/python/remove_initializer_from_input.py.
2023-11-27 17:05:35.105131066 [W:onnxruntime:, graph.cc:1256 Graph] Duplicate initializer (dense, sparse or ConstantNode): 'pixels_out.bias' the model will use the latest encountered initializer. Please, fix your model.
2023-11-27 17:05:35.105135000 [W:onnxruntime:, graph.cc:1283 Graph] Initializer pixels_out.bias appears in graph inputs and will not be treated as constant value/weight. This may prevent some of the graph optimizations, like const folding. Move it out of graph inputs if there is no need to override it, by either re-generating the model with latest exporter/converter or with the tool onnxruntime/tools/python/remove_initializer_from_input.py.
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 419, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 452, in _create_inference_session
    sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from out/20231119_175136/best.onnx failed:Node (models_units_FastCartoonTexture_preprocess_tv_1_0) Op (models_units_FastCartoonTexture_preprocess_tv_1) [ShapeInferenceError] (op_type:Div, node name: Div_15): B has inconsistent type tensor(int64

I also used the onnx.helper.printable_graph tool to export the graph structure of the ONNX protobuf. The result is in onnx_graph.txt

Sorry for the long response. I'm kind of lost on how to proceed. Thank you.

justinchuby commented 9 months ago

Thanks for doing the experiments! Let's solve this in a few steps.

  1. For aten::var, if you could create a pull request with your implementation and tests, we can look at it together to make sure it is correct.
  2. The warnings are from PyTorch Dynamo. I would ignore them unless we solve the rest and still aren't able to get a correct model.
  3. The model check error may be related to https://github.com/microsoft/onnxscript/pull/1184. Try installing onnxscript from the GitHub main branch. If that does not fix it, this may be a bug in our implementation. I will look into this.
  4. We can deal with the ORT type error later. This may be a type promotion case we are not handling correctly in the converter, but it could also be due to other things that can be solved with (1).
luisfmnunes commented 9 months ago

I also used ONNX GraphSurgeon to check the types. It seems that all the _val_ tensors have their dtype set to None or 0 (I'm not sure whether this is expected during conversion or what might cause it). Here (onnx_graph_data.txt) is the output I logged for all nodes of the graph.

justinchuby commented 9 months ago

Can you share the ONNX model itself? You may zip it and attach it here.

luisfmnunes commented 9 months ago

I submitted PR #1186. The model is larger than 25 MB and GitHub refuses the upload, so I uploaded it to Drive.

luisfmnunes commented 9 months ago

Thanks for doing the experiments! Let's solve this in a few steps.

  1. For aten::var, if you could create a pull request with your implementation and tests, we can look at it together to make sure it is correct.
  2. The warnings are from PyTorch Dynamo. I would ignore them unless we solve the rest and still aren't able to get a correct model.
  3. The model check error may be related to Fix value_info names in symbolic shape export #1184. Try installing onnxscript from the GitHub main branch. If that does not fix it, this may be a bug in our implementation. I will look into this.
  4. We can deal with the ORT type error later. This may be a type promotion case we are not handling correctly in the converter, but it could also be due to other things that can be solved with (1).

OK, so I was able to get rid of the duplicate-parameter warnings and the onnx.checker.check_model exception by removing the model_state_dict argument from onnx_prog.save() and loading the parameters beforehand with model.load_state_dict. The code below works for model conversion.

    state_dict = torch.load(checkpoint, map_location=torch.device('cpu'))    
    model.load_state_dict(state_dict["model"])
    model.eval()

    logger.info(f"Exporting model to {onnx_model}")
    # export(
    #     model,
    #     torch.randn(1, 1, 512, 512),
    #     onnx_model,
    #     export_params=True,
    #     do_constant_folding=True,
    #     opset_version=18,
    #     input_names=["image"],
    #     output_names=["center", "grid", "pose_2d", "seg", "img_sup", "seg_sup"],
    #     dynamic_axes={
    #         'image':{0: "batch_size", 2: "height", 3: "width"},
    #         'center': {0: "batch_size"},
    #         "grid": {0: "batch_size"},
    #         "pose_2d": {0: "batch_size"},
    #         "seg": {0: "batch_size"},
    #         "img_sup": {0: "batch_size"},
    #         "seg_sup": {0: "batch_size"}
    #     }
    # )

    # print(torch.export.export(model, (torch.randn(1, 1, 512, 512, dtype=torch.float32),)))
    # TorchDynamo exports correctly but there are still unsupported onnxscript ops.
    # Waiting Issue Response to solve this.

    onnx_prog = torch.onnx.dynamo_export(model, torch.randn(1, 1, 512, 512, dtype=torch.float32))
    onnx_prog.save(onnx_model.as_posix())

    onnx_model = onnx.load(onnx_model)
    onnx.checker.check_model(onnx_model)
    print(onnx.helper.printable_graph(onnx_model.graph))

The problem is that the model is still incorrect: torch.onnx.dynamo_export shows the warning /home/griaule/Melo/onnxscript/onnxscript/function_libs/torch_lib/graph_building.py:971: UserWarning: ONNX model is invalid: [ShapeInferenceError] (op_type:Add, node name: Add_56): [TypeInferenceError] Inferred elem type differs from existing elem type: (7) vs (1). When I try to use the model in ONNX Runtime, it still raises the inconsistent-type error:

>>> import onnxruntime as ort
>>> ort.InferenceSession("best.onnx")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 419, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 452, in _create_inference_session
    sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from best.onnx failed:Node (models_units_FastCartoonTexture_preprocess_tv_1_0) Op (models_units_FastCartoonTexture_preprocess_tv_1) [ShapeInferenceError] (op_type:Div, node name: Div_15): B has inconsistent type tensor(int64)
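
For reference, the numbers in the `Inferred elem type differs from existing elem type: (7) vs (1)` warning are ONNX `TensorProto` data type codes. The sketch below decodes them with a hand-copied subset of that enum. `describe_elem_type_mismatch` is a hypothetical helper, and it assumes the first number is the inferred type and the second the existing one, as the wording of the message suggests.

```python
# Subset of the ONNX TensorProto.DataType enum values.
ONNX_ELEM_TYPES = {
    1: "FLOAT",
    2: "UINT8",
    3: "INT8",
    4: "UINT16",
    5: "INT16",
    6: "INT32",
    7: "INT64",
    8: "STRING",
    9: "BOOL",
    10: "FLOAT16",
    11: "DOUBLE",
}


def describe_elem_type_mismatch(inferred, existing):
    """Translate the numeric codes from the warning into readable type names."""
    inferred_name = ONNX_ELEM_TYPES.get(inferred, f"UNKNOWN({inferred})")
    existing_name = ONNX_ELEM_TYPES.get(existing, f"UNKNOWN({existing})")
    return f"inferred {inferred_name} vs existing {existing_name}"


print(describe_elem_type_mismatch(7, 1))  # inferred INT64 vs existing FLOAT
```

Read this way, the warning says an int64 tensor reaches a place where a float32 tensor was recorded, which matches the `tensor(int64)` complaint from ONNX Runtime.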

Below is the FastCartoonTexture nn.Module class that seems to be failing:


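# Imports assumed from usage; they were not shown in the original snippet.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class ImageGradient(nn.Module):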
    def __init__(self):
        super().__init__()
        kernel_x = [[1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0]]
        kernel_x = torch.FloatTensor(kernel_x).unsqueeze(0).unsqueeze(0)
        kernel_y = [[1.0, 2.0, 1.0], [0.0, 0.0, 0.0], [-1.0, -2.0, -1.0]]
        kernel_y = torch.FloatTensor(kernel_y).unsqueeze(0).unsqueeze(0)
        self.weight_x = nn.Parameter(data=kernel_x, requires_grad=False)
        self.weight_y = nn.Parameter(data=kernel_y, requires_grad=False)

    def forward(self, x):
        grad_x = F.conv2d(x, self.weight_x, padding=1)
        grad_y = F.conv2d(x, self.weight_y, padding=1)
        return grad_x, grad_y

class FastCartoonTexture(nn.Module):
    def __init__(self, sigma=2.5, eps=1e-6) -> None:
        super().__init__()
        self.sigma = sigma
        self.eps = eps
        self.cmin = 0.3
        self.cmax = 0.7
        self.lim = 20

        self.img_grad = ImageGradient()

    def lowpass_filtering(self, img, L):
        img_fft = torch.fft.fftshift(torch.fft.fft2(img), dim=(-2, -1)) * L

        img_rec = torch.fft.ifft2(torch.fft.fftshift(img_fft, dim=(-2, -1)))
        img_rec = torch.real(img_rec)

        return img_rec

    def gradient_norm(self, img):
        Gx, Gy = self.img_grad(img)
        return torch.sqrt(Gx ** 2 + Gy ** 2) + self.eps

    def forward(self, input):
        H, W = input.size(-2), input.size(-1)
        grid_y, grid_x = torch.meshgrid(torch.linspace(-0.5, 0.5, H), torch.linspace(-0.5, 0.5, W), indexing="ij")
        grid_radius = torch.sqrt(grid_x ** 2 + grid_y ** 2) + self.eps

        L = (1.0 / (1 + (2 * np.pi * grid_radius * self.sigma) ** 4)).type_as(input)[None, None]

        grad_img1 = self.gradient_norm(input)
        grad_img1 = self.lowpass_filtering(grad_img1, L)

        img_low = self.lowpass_filtering(input, L)
        grad_img2 = self.gradient_norm(img_low)
        grad_img2 = self.lowpass_filtering(grad_img2, L)

        diff = grad_img1 - grad_img2
        flag = torch.abs(grad_img1)
        diff = torch.where(flag > 1, diff / flag.clamp_min(self.eps), torch.zeros_like(diff))

        weight = (diff - self.cmin) / (self.cmax - self.cmin)
        weight = torch.clamp(weight, 0, 1)

        cartoon = weight * img_low + (1 - weight) * input
        texture = (input - cartoon + self.lim) * 255 / (2 * self.lim)
        texture = torch.clamp(texture, 0, 255)
        return texture
justinchuby commented 9 months ago

Could you share your PyTorch version? I would make sure it is the latest torch-nightly build.

luisfmnunes commented 9 months ago

This might be the problem. Using the package dunder I got the following:

>>> torch.__version__
'2.1.1+cu118'

I'll try it out tomorrow using the latest torch-nightly build.

luisfmnunes commented 9 months ago

I installed the torch-nightly build and still get the same warning regarding ShapeInferenceError and TypeInferenceError. The only difference now is that the indexing of the nodes changed:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 419, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 452, in _create_inference_session
    sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from best.onnx failed:Node (models_units_FastCartoonTexture_preprocess_tv_1_149) Op (models_units_FastCartoonTexture_preprocess_tv_1) [ShapeInferenceError] (op_type:Div, node name: Div_148): B has inconsistent type tensor(int64
luisfmnunes commented 9 months ago

I changed the backbone from a manually built ResNet to the backbone from timm and retrained the model. The new model architecture seems to have a much more straightforward graph (checked in Netron), but the same problem still occurs in Div_148. Here is the new model.

Python 3.11.4 (main, Jul  5 2023, 13:45:01) [GCC 11.2.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import onnxruntime as ort
>>> ort.InferenceSession("best.onnx")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 419, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 452, in _create_inference_session
    sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Load model from best.onnx failed:Node (models_units_FastCartoonTexture_preprocess_tv_1_7) Op (models_units_FastCartoonTexture_preprocess_tv_1) [ShapeInferenceError] (op_type:Div, node name: Div_148): B has inconsistent type tensor(int64)
>>> 

I also noticed the following UserWarning from torch.onnx.dynamo_export, which might be a clue that something goes wrong during graph building.

/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/torch/onnx/_internal/fx/passes/readability.py:53: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
  new_node = self.module.graph.get_attr(normalized_name)
justinchuby commented 9 months ago

Looks like we need a CastLike here:

https://github.com/microsoft/onnxscript/blob/9e7485861040a5e9770e2dac5fc5d7d6ceba9f2b/onnxscript/function_libs/torch_lib/ops/core.py#L4533-L4537

cc @fatcat-z

justinchuby commented 9 months ago

@luisfmnunes do you think the model layer is using linspace there?

luisfmnunes commented 9 months ago

Most likely. The nn.Module that seems to be the problem is the FastCartoonTexture class, and its forward function builds the meshgrid from linspace tensors.

class FastCartoonTexture(nn.Module):
    def __init__(self, sigma=2.5, eps=1e-6) -> None:
        super().__init__()
        self.sigma = sigma
        self.eps = eps
        self.cmin = 0.3
        self.cmax = 0.7
        self.lim = 20

        self.img_grad = ImageGradient()

    def lowpass_filtering(self, img, L):
        img_fft = torch.fft.fftshift(torch.fft.fft2(img), dim=(-2, -1)) * L

        img_rec = torch.fft.ifft2(torch.fft.fftshift(img_fft, dim=(-2, -1)))
        img_rec = torch.real(img_rec)

        return img_rec

    def gradient_norm(self, img):
        Gx, Gy = self.img_grad(img)
        return torch.sqrt(Gx ** 2 + Gy ** 2) + self.eps

    def forward(self, input):
        H, W = input.size(-2), input.size(-1)
        grid_y, grid_x = torch.meshgrid(torch.linspace(-0.5, 0.5, H), torch.linspace(-0.5, 0.5, W), indexing="ij")
        grid_radius = torch.sqrt(grid_x ** 2 + grid_y ** 2).type_as(input) + self.eps

        L = (1.0 / (1 + (2 * np.pi * grid_radius * self.sigma) ** 4)).type_as(input)[None, None]

        grad_img1 = self.gradient_norm(input)
        grad_img1 = self.lowpass_filtering(grad_img1, L)

        img_low = self.lowpass_filtering(input, L)
        grad_img2 = self.gradient_norm(img_low)
        grad_img2 = self.lowpass_filtering(grad_img2, L)

        diff = grad_img1 - grad_img2
        flag = torch.abs(grad_img1)
        diff = torch.where(flag > 1, diff / flag.clamp_min(self.eps), torch.zeros_like(diff))

        weight = (diff - self.cmin) / (self.cmax - self.cmin)
        weight = torch.clamp(weight, 0, 1)

        cartoon = weight * img_low + (1 - weight) * input
        texture = (input - cartoon + self.lim) * 255 / (2 * self.lim)
        texture = torch.clamp(texture, 0, 255)
        return texture
justinchuby commented 9 months ago

Great - we will need to fix linspace. If you like, please feel free to add a CastLike in the lines above and see if it gives you the correct model. I will create a fix this week.

luisfmnunes commented 9 months ago

Cool, the CastLike solves the Div problem, but the Mul in the return statement right after it also has an inconsistent type (probably due to a type mismatch between start and range_tensor). If I also cast range_tensor like start, the following error happens later in the graph:

>>> ort.InferenceSession("best.onnx")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 419, in __init__
    self._create_inference_session(providers, provider_options, disabled_optimizers)
  File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 463, in _create_inference_session
    sess.initialize_session(providers, provider_options, disabled_optimizers)
onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : FAIL : Node (_inline_models_model_zoo_DenseHoughVoter_voter_1aten_rsub_186) Op (aten_rsub) [ShapeInferenceError] (op_type:Sub, node name: n3): B has inconsistent type tensor(int64)

Edit: The problem is still the linspace. When I passed the dtype keyword to the linspace calls, I got a correct model.
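
To illustrate why the mismatch surfaces in Div first and then in later ops, here is a toy pure-Python model of strictly typed binary ops. It is not onnxscript code: `Typed`, `strict_binary`, `cast_like` and `linspace_step` are hypothetical names, and the step formula is only an assumption about how a linspace step might be computed. The point is that an int64 value derived from `steps` has to be cast to the float type of `start` and `end` before the two can meet in the same op.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Typed:
    """A scalar tagged with a dtype name, standing in for a typed tensor."""
    value: float
    dtype: str  # e.g. "float32" or "int64"


def strict_binary(op_name, a, b, fn):
    """Apply fn only when both operands share a dtype, like ONNX Div/Mul/Sub."""
    if a.dtype != b.dtype:
        raise TypeError(f"{op_name}: B has inconsistent type tensor({b.dtype})")
    return Typed(fn(a.value, b.value), a.dtype)


def cast_like(x, target):
    """Toy CastLike: convert x to the dtype carried by target."""
    caster = float if target.dtype.startswith("float") else int
    return Typed(caster(x.value), target.dtype)


def linspace_step(start, end, steps, use_cast_like):
    """Compute (end - start) / (steps - 1) under strict typing."""
    span = strict_binary("Sub", end, start, lambda p, q: p - q)
    denom = strict_binary("Sub", steps, Typed(1, steps.dtype), lambda p, q: p - q)
    if use_cast_like:
        denom = cast_like(denom, span)
    return strict_binary("Div", span, denom, lambda p, q: p / q)


start, end, steps = Typed(-0.5, "float32"), Typed(0.5, "float32"), Typed(5, "int64")
try:
    linspace_step(start, end, steps, use_cast_like=False)
except TypeError as exc:
    print(exc)  # Div: B has inconsistent type tensor(int64)
print(linspace_step(start, end, steps, use_cast_like=True))
# Typed(value=0.25, dtype='float32')
```

Every op that later consumes an uncast int64 value would fail the same way, which is consistent with the error moving from Div to Mul and then to Sub as each cast was added.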

Thank you very much for all your attention and help @justinchuby. I'll check what is failing on my PR and try to contribute aten::roll for complex inputs and var (dim and correction).

justinchuby commented 9 months ago

FYI, if functions with too many if branches bother you because of performance, you may consider https://github.com/microsoft/onnxscript/pull/1178

luisfmnunes commented 9 months ago

Well, although the model is now correct (meaning it has a valid graph), it seems the model parameters were not loaded. This happens whether I load the weights beforehand with nn.Module.load_state_dict or pass the state dict as a keyword argument during export. The results seem pretty much random and diverge a lot from the torch model output.

image

justinchuby commented 9 months ago

Could you share a script you use for export and comparison?

luisfmnunes commented 9 months ago

Sure.

The export script:

import torch
import onnx
import yaml
from torch.onnx import export
from pathlib import Path
from loguru import logger
from argparse import ArgumentParser

from models.model_zoo import GRIDNET4, GRIDTIMMNET4

def main(args):

    torch.set_default_device("cpu")

    root = args.model_dir
    config_file = root / "configs.yaml"
    checkpoint = root / args.model_version
    logger.info(f"Exporting model {root} to ONNX")

    onnx_model = checkpoint.parent / checkpoint.with_suffix(".onnx").name

    logger.info(f"Loading Model config file {config_file}")
    with open(config_file, "r") as f:
        config = yaml.load(f.read(), yaml.Loader)

    if config["exp_name"] == "gridnet4":
        model = GRIDNET4(
            num_pose_2d=config["num_pose_2d"],
            num_layers=config["num_layers"],
            img_ppi=config["img_ppi"],
            middle_shape=config["middle_shape"],
            with_tv=config["with_tv"],
            with_enh=config["with_enh"],
            bin_type=config["bin_type"],
            activate=config["activate"],
            pretrained=False
        )
    else:
        model = GRIDTIMMNET4(**config)

    logger.info(f"Loading model checkpoint {checkpoint}")

    state_dict = torch.load(checkpoint, map_location=torch.device('cpu'))

    # model.load_state_dict(state_dict["model"]) # Load model state_dict previously
    model.eval()

    logger.info(f"Exporting model to {onnx_model}")

    # TorchDynamo exports correctly but there are still unsupported onnxscript ops.
    # Waiting Issue Response to solve this.

    # onnx_export_options = torch.onnx.ExportOptions(dynamic_shapes=True)

    onnx_prog = torch.onnx.dynamo_export(
        model, torch.randn(1, 1, 512, 512, dtype=torch.float32), 
        # export_options=onnx_export_options
    )
    # onnx_prog.save(onnx_model.as_posix()) #version without model dict load
    onnx_prog.save(onnx_model.as_posix(), model_state_dict=state_dict["model"])

    onnx_model = onnx.load(onnx_model)
    print(onnx.helper.printable_graph(onnx_model.graph))
    onnx.checker.check_model(onnx_model)

def parse_args():

    parser = ArgumentParser()

    parser.add_argument("model_dir", help="Parent directory from model", type=Path)
    parser.add_argument("--model_version", help="Version of exported checkpoint", type=str, default="best.pth")

    return parser.parse_args()

if __name__ == "__main__":
    main(parse_args())

The Comparison Script (Edit: added a function to compare the parameters)

import cv2
import yaml
import torch
import onnx
import numpy as np
import onnxruntime as ort

from pathlib import Path
from loguru import logger
from argparse import ArgumentParser

from deploy_gridnet import process_img
from models.model_zoo import GRIDNET4, GRIDTIMMNET4

def get_onnx_tensor_dict(onnx_load):
    return {t.name: onnx.numpy_helper.to_array(t) for t in onnx_load.graph.initializer}

def compare_onnx_graph_and_state_dict(onnx_dict, state_dict):
    torch_keys = [k for k in state_dict.keys() if k not in onnx_dict]
    onnx_keys = [k for k in onnx_dict.keys() if k not in state_dict]
    for k, v in onnx_dict.items():
        if k in onnx_keys: continue
        is_close = np.isclose(
            v,
            state_dict[k].numpy()
        )
        if not is_close.all():
            logger.warning(
                f"Parameter {k} is diverging. {is_close}"
            )

    logger.warning(f"ONNX Keys not in PyTorch {onnx_keys}")
    logger.warning(f"PyTorch Keys not in ONNX {torch_keys}")

def main(args):

    config_file = args.root / "configs.yaml"
    logger.info(f"Reading config file {config_file}")
    with open(config_file, "r") as f:
        config = yaml.load(f, yaml.Loader)

    if config.get("architecture", None):
        model = GRIDTIMMNET4(**config)
    else:
        model = GRIDNET4(**config)

    checkpoint_file = args.root / f"{args.name}.pth"
    onnx_file = checkpoint_file.with_suffix(".onnx")

    logger.info(f"Loading Checkpoint {checkpoint_file}")
    checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint["model"])
    model.eval()

    logger.info(f"Loading ONNX Model {onnx_file}")
    onnx_load = onnx.load(onnx_file)
    onnx_model = ort.InferenceSession(onnx_file)

    logger.info("Comparing ONNX Graph and PyTorch State Dict")
    compare_onnx_graph_and_state_dict(
        get_onnx_tensor_dict(onnx_load),
        checkpoint["model"],
    )

    image_file = Path(__file__).parent / "image/1_1.tif"
    logger.info(f"Loading image {image_file}")
    # im = cv2.imread(
    #     image_file.as_posix(), 
    #     cv2.IMREAD_GRAYSCALE,
    # ).astype(np.float32)
    im = np.random.randn(512, 512).astype(np.float32)
    # im, _, _ = process_img(im, 500, None)
    print(im.shape)
    im_tensor = torch.from_numpy(im)[None][None]
    print(im_tensor.shape)
    with torch.no_grad():
        torch_out = model(im_tensor)
    onnx_out = onnx_model.run(["cat_4"], {"l_input_": im[None, None]}) 

    print(torch_out["pose_2d"])
    print(onnx_out)

def parse_args():

    parser = ArgumentParser()

    parser.add_argument("root", type=Path, help="Path to output models")
    parser.add_argument(
        "--name", "-n", type=str, help="Name of model file", default="best"
    )

    return parser.parse_args()

if __name__ == "__main__":
    main(parse_args())

Edit: After manually checking some parameters, they seem to have been converted correctly. I'll write a script to verify that everything is in order.

Edit2: It seems there are no deviations in the parameters, which leads me to believe that the graph itself is incorrect. I will have to verify it further.
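
Once the parameters are ruled out, one quick way to probe the graph is to check whether the output reacts to the input at all. The sketch below is only an illustration and is not part of the original scripts: `output_depends_on_input`, `constant_model`, and `mean_model` are hypothetical names, and the two stand-in "models" are plain NumPy functions rather than the exported network.

```python
import numpy as np

def output_depends_on_input(run_fn, shape, trials=3, seed=0):
    """Return True if run_fn produces different outputs for different random inputs."""
    rng = np.random.default_rng(seed)
    reference = np.asarray(run_fn(rng.standard_normal(shape).astype(np.float32)))
    for _ in range(trials):
        candidate = np.asarray(run_fn(rng.standard_normal(shape).astype(np.float32)))
        # Any change in shape or values means the input reaches the output.
        if candidate.shape != reference.shape or not np.allclose(candidate, reference):
            return True
    return False

# Hypothetical stand-ins for a real model: one ignores its input, one does not.
constant_model = lambda x: np.zeros(4, dtype=np.float32)
mean_model = lambda x: x.mean(axis=(2, 3))

print(output_depends_on_input(constant_model, (1, 1, 8, 8)))
print(output_depends_on_input(mean_model, (1, 1, 8, 8)))
```

With an ONNX Runtime session, `run_fn` could wrap `session.run(None, {input_name: x})[0]`, where `input_name` is whatever input name the exported graph declares.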

Edit3: The ONNX model always returns the same result regardless of the input. I also get a broadcast error if I enable dynamic shapes and pass a shape different from the one used in the export function ([1, 1, 512, 512]). I'll upload the model if you wish to check it out. Model

(384, 384)
torch.Size([1, 1, 384, 384])
2023-11-29 16:50:38.636816288 [E:onnxruntime:, sequential_executor.cc:514 ExecuteKernel] Non-zero status code returned while running Mul node. Name:'_inline_aten_mul_complex_token_108n19' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:540 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 384 by 512

Traceback (most recent call last):
  File "/home/griaule/Fingerprint-2DPose-Dense-Voting/onnx_sanity_check.py", line 104, in <module>
    main(parse_args())
  File "/home/griaule/Fingerprint-2DPose-Dense-Voting/onnx_sanity_check.py", line 78, in main
    onnx_out = onnx_model.run(
               ^^^^^^^^^^^^^^^
  File "/home/griaule/miniconda3/envs/torch2.0/lib/python3.11/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 220, in run
    return self._sess.run(output_names, input_feed, run_options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Mul node. Name:'_inline_aten_mul_complex_token_108n19' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:540 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 384 by 512
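
The "384 by 512" message is consistent with the Mul node multiplying the 384-sized input by a tensor whose size still reflects the 512-sized export input. ONNX element-wise ops follow NumPy-style broadcasting, so the rule can be reproduced with NumPy alone. This is a sketch under an assumption: `baked_constant_shape` is a hypothetical tensor frozen at export time; the thread does not confirm which tensor in the graph carries the 512 dimension.

```python
import numpy as np

def can_broadcast(shape_a, shape_b):
    """Check whether two shapes are compatible under NumPy-style broadcasting."""
    try:
        np.broadcast_shapes(shape_a, shape_b)
        return True
    except ValueError:
        return False

export_shape = (1, 1, 512, 512)   # input shape used in the export call
runtime_shape = (1, 1, 384, 384)  # input shape passed at inference time

# Hypothetical tensor whose size was fixed from the export-time input.
baked_constant_shape = (512,)

print(can_broadcast(export_shape, baked_constant_shape))   # True
print(can_broadcast(runtime_shape, baked_constant_shape))  # False: 384 vs 512
```

An axis can only be broadcast when one side has size 1, which is the condition the ONNX Runtime error reports as violated.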

Edit4: Comparing the ONNX Graph and the State Dict, the following result is seen: image

Edit5: It seems I broke a reference in aten_linspace when using CastLike. After I reverted my changes to aten_linspace and forced torch.float32 in the linspace calls, I got a result very similar to PyTorch's.
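
To put a number on "very similar", the two outputs can be compared with a tolerance instead of by eye. The helper below is a minimal sketch, not code from the original scripts: `compare_outputs` and its tolerance defaults are assumptions chosen for illustration.

```python
import numpy as np

def compare_outputs(torch_out, onnx_out, rtol=1e-3, atol=1e-5):
    """Summarize how closely two non-empty output arrays agree."""
    a = np.asarray(torch_out, dtype=np.float64)
    b = np.asarray(onnx_out, dtype=np.float64)
    if a.shape != b.shape:
        # Shapes must match before an element-wise comparison makes sense.
        return {"same_shape": False, "max_abs_diff": None, "all_close": False}
    return {
        "same_shape": True,
        "max_abs_diff": float(np.abs(a - b).max()),
        "all_close": bool(np.allclose(a, b, rtol=rtol, atol=atol)),
    }

# Toy values standing in for the PyTorch and ONNX Runtime outputs.
print(compare_outputs([1.0, 2.0], [1.0, 2.0000001]))
```

In the comparison script this could be called as `compare_outputs(torch_out["pose_2d"].numpy(), onnx_out[0])`, assuming the "cat_4" output requested from the session corresponds to `pose_2d`.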

justinchuby commented 9 months ago

So we are good?

luisfmnunes commented 9 months ago

So we are good?

Yeah, we're good. There might be a problem in the complex operations, as mentioned in Edit4, but it is all good. If I eventually need the model to handle any shape without preprocessing, I'll open a new issue. I'm closing this now.

Thank you once again @justinchuby.