pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Slower inference performance when switching from `torchxla_trace_once` to `openxla` compile backend #5430

Open GleasonK opened 1 year ago

GleasonK commented 1 year ago

🐛 Bug

It looks like torchxla_trace_once is deprecated in favor of openxla, but when I tried to make that migration in some benchmark testing, I saw a new warning message and some performance regressions. I found this while running the ResNet inference benchmark from openxla-benchmark on GPU.

To Reproduce

Colab repro.

Steps to reproduce the behavior:

  1. Run colab with torchxla_trace_once - should dump files.
  2. Run colab with openxla - should dump files (restart runtime if it does not)

Hopefully that provides enough information to be useful, if not I am happy to help further.

Expected behavior

On-par performance and HLO graph generation between the two backends (openxla and torchxla_trace_once).
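A minimal, backend-agnostic sketch for checking the "on-par performance" expectation: warm each callable up, time it, and report its median latency relative to a baseline. In the actual repro, each callable would presumably run a model compiled with `torch.compile(model, backend="torchxla_trace_once")` or `backend="openxla"` and would need to wait for the device to finish (for example by copying the output to CPU) before returning. That wiring is an assumption and is not shown; `benchmark` and `compare` are hypothetical helper names.

```python
import statistics
import time
from typing import Callable, Dict


def benchmark(fn: Callable[[], object], warmup: int = 3, iters: int = 20) -> float:
    """Median wall-clock latency of fn() in milliseconds, measured after warm-up."""
    for _ in range(warmup):
        fn()  # warm-up calls absorb one-time costs such as tracing and compilation
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        samples.append((time.perf_counter() - start) * 1000.0)
    return statistics.median(samples)


def compare(candidates: Dict[str, Callable[[], object]], baseline: str,
            warmup: int = 3, iters: int = 20) -> Dict[str, float]:
    """Each candidate's median latency divided by the baseline's (> 1.0 means slower)."""
    medians = {name: benchmark(fn, warmup, iters) for name, fn in candidates.items()}
    return {name: m / medians[baseline] for name, m in medians.items()}


if __name__ == "__main__":
    # Placeholder workloads; swap in the two compiled models for a real comparison.
    ratios = compare(
        {"baseline": lambda: sum(range(10_000)), "candidate": lambda: sum(range(50_000))},
        baseline="baseline",
    )
    for name, ratio in ratios.items():
        print(f"{name}: {ratio:.2f}x baseline latency")
```

Using the median rather than the mean keeps a single slow iteration (for example, an unexpected recompilation) from dominating the comparison.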

Environment

Additional context

Output traces: save_ir.zip

JackCaoG commented 1 year ago

Thanks! I took a quick look at the HLO dump, and the HLO for openxla (which is aot_torchxla_trace_once renamed) is longer than the one for torchxla_trace_once. @wconstab @shunting314 do you know what additional work the aot backend does? I thought that for inference alone the FX graph should be the same.

wconstab commented 1 year ago

I'm not too familiar with what the aot_ backend does differently. What's the history of that backend (is it based on the dynamo backend shunting wrote or something else?)

JackCaoG commented 1 year ago

openxla == aot_torchxla_trace_once, and the only difference between torchxla_trace_once and aot_torchxla_trace_once is the aot_autograd wrapper:

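from torch._dynamo.backends.common import aot_autograd

# torchxla_trace_once is defined earlier in the same file (torch/_dynamo/backends/torchxla.py).
# openxla: the same compiler, used as the forward compiler inside AOTAutograd.
aot_torchxla_trace_once = aot_autograd(
    fw_compiler=torchxla_trace_once,
)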

https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/backends/torchxla.py#L49-L51

I guess the question is: will aot_autograd do anything if only the model's forward function is being called?

shunting314 commented 1 year ago

Maybe it's due to functionalization (https://github.com/pytorch/pytorch/blob/main/torch/_functorch/aot_autograd.py#L1534), which is done even for inference.
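To make the functionalization point concrete, here is a toy sketch; it is not AOTAutograd's actual implementation. Graphs are lists of `(output, op, inputs)` tuples, and, following PyTorch naming, an op ending in `_` mutates its first input in place. The pass rewrites each in-place op as an out-of-place op on a fresh name and, if a graph input was mutated, appends a `copy_` that writes the result back at the end. `aot_like` is a hypothetical stand-in for `aot_autograd(fw_compiler=...)`: it runs the pass even when only the forward is used, then hands the rewritten graph to the inner compiler.

```python
from typing import Callable, Dict, List, Tuple

# Toy graph node: (output_name, op_name, input_names).
# By convention, an op ending in "_" mutates its first input in place and
# its output_name must equal that input, e.g. x.add_(y) -> ("x", "add_", ("x", "y")).
Node = Tuple[str, str, Tuple[str, ...]]


def functionalize(nodes: List[Node], graph_inputs: List[str]) -> List[Node]:
    """Rewrite in-place ops out-of-place; write mutated graph inputs back with copy_."""
    latest: Dict[str, str] = {}    # original name -> newest functional version
    versions: Dict[str, int] = {}  # original name -> number of in-place updates seen
    result: List[Node] = []
    for dest, op, args in nodes:
        new_args = tuple(latest.get(a, a) for a in args)
        if op.endswith("_"):
            assert dest == args[0], "toy convention: in-place ops write to their first input"
            versions[dest] = versions.get(dest, 0) + 1
            fresh = f"{dest}_v{versions[dest]}"
            result.append((fresh, op[:-1], new_args))  # out-of-place variant, fresh name
            latest[dest] = fresh
        else:
            latest.pop(dest, None)  # a fresh definition shadows earlier versions
            result.append((dest, op, new_args))
    # Callers still hold the original input tensors, so mutations must be written back.
    for name in graph_inputs:
        if latest.get(name, name) != name:
            result.append((name, "copy_", (name, latest[name])))
    return result


def aot_like(fw_compiler: Callable[[List[Node]], object], graph_inputs: List[str]):
    """Hypothetical stand-in for aot_autograd(fw_compiler=...): the pass always runs."""
    def compile_graph(nodes: List[Node]):
        return fw_compiler(functionalize(nodes, graph_inputs))
    return compile_graph


if __name__ == "__main__":
    # A residual add written in place, like `out += identity` in a ResNet block.
    graph = [("h", "conv", ("x", "w")), ("h", "add_", ("h", "x")), ("out", "relu", ("h",))]
    for node in functionalize(graph, ["x", "w"]):
        print(node)
```

The point of the sketch is that the wrapper changes the graph before the inner compiler ever sees it, so two backends sharing the same inner compiler can still receive different graphs.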

JackCaoG commented 1 year ago

I will take a look at the FX graph.

JackCaoG commented 1 year ago

I dumped the FX graph for ResNet-18 using openxla (aot) and torchxla_trace_once. Their FX graphs look quite different.

openxla

def forward(self, primals_61, primals_123, primals_1, primals_2, primals_3, primals_63, primals_64, primals_4, primals_5, primals_6, primals_66, primals_67, primals_7, primals_8, primals_9, primals_69, primals_70, primals_10, primals_11, primals_12, primals_72, primals_73, primals_13, primals_14, primals_15, primals_75, primals_76, primals_16, primals_22, primals_17, primals_18, primals_78, primals_79, primals_23, primals_24, primals_84, primals_85, primals_19, primals_20, primals_21, primals_81, primals_82, primals_25, primals_26, primals_27, primals_87, primals_88, primals_28, primals_29, primals_30, primals_90, primals_91, primals_31, primals_37, primals_32, primals_33, primals_93, primals_94, primals_38, primals_39, primals_99, primals_100, primals_34, primals_35, primals_36, primals_96, primals_97, primals_40, primals_41, primals_42, primals_102, primals_103, primals_43, primals_44, primals_45, primals_105, primals_106, primals_46, primals_52, primals_47, primals_48, primals_108, primals_109, primals_53, primals_54, primals_114, primals_115, primals_49, primals_50, primals_51, primals_111, primals_112, primals_55, primals_56, primals_57, primals_117, primals_118, primals_58, primals_59, primals_60, primals_120, primals_121, primals_62):
    t = torch.ops.aten.t.default(primals_61);  primals_61 = None
    convolution = torch.ops.aten.convolution.default(primals_123, primals_1, None, [2, 2], [3, 3], [1, 1], False, [0, 0], 1);  primals_123 = primals_1 = None
    _native_batch_norm_legit_no_training = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution, primals_2, primals_3, primals_63, primals_64, 0.1, 1e-05);  primals_2 = primals_3 = primals_63 = primals_64 = None
    getitem = _native_batch_norm_legit_no_training[0]
    getitem_1 = _native_batch_norm_legit_no_training[1]
    getitem_2 = _native_batch_norm_legit_no_training[2];  _native_batch_norm_legit_no_training = None
    relu = torch.ops.aten.relu.default(getitem);  getitem = None
    max_pool2d_forward = torch.ops.xla.max_pool2d_forward.default(relu, [3, 3], [2, 2], [1, 1])
    convolution_1 = torch.ops.aten.convolution.default(max_pool2d_forward, primals_4, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  primals_4 = None
    _native_batch_norm_legit_no_training_1 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_1, primals_5, primals_6, primals_66, primals_67, 0.1, 1e-05);  primals_5 = primals_6 = primals_66 = primals_67 = None
    getitem_3 = _native_batch_norm_legit_no_training_1[0]
    getitem_4 = _native_batch_norm_legit_no_training_1[1]
    getitem_5 = _native_batch_norm_legit_no_training_1[2];  _native_batch_norm_legit_no_training_1 = None
    relu_1 = torch.ops.aten.relu.default(getitem_3);  getitem_3 = None
    convolution_2 = torch.ops.aten.convolution.default(relu_1, primals_7, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  primals_7 = None
    _native_batch_norm_legit_no_training_2 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_2, primals_8, primals_9, primals_69, primals_70, 0.1, 1e-05);  primals_8 = primals_9 = primals_69 = primals_70 = None
    getitem_6 = _native_batch_norm_legit_no_training_2[0]
    getitem_7 = _native_batch_norm_legit_no_training_2[1]
    getitem_8 = _native_batch_norm_legit_no_training_2[2];  _native_batch_norm_legit_no_training_2 = None
    add = torch.ops.aten.add.Tensor(getitem_6, max_pool2d_forward);  getitem_6 = None
    relu_2 = torch.ops.aten.relu.default(add);  add = None
    convolution_3 = torch.ops.aten.convolution.default(relu_2, primals_10, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  primals_10 = None
    _native_batch_norm_legit_no_training_3 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_3, primals_11, primals_12, primals_72, primals_73, 0.1, 1e-05);  primals_11 = primals_12 = primals_72 = primals_73 = None
    getitem_9 = _native_batch_norm_legit_no_training_3[0]
    getitem_10 = _native_batch_norm_legit_no_training_3[1]
    getitem_11 = _native_batch_norm_legit_no_training_3[2];  _native_batch_norm_legit_no_training_3 = None
    relu_3 = torch.ops.aten.relu.default(getitem_9);  getitem_9 = None
    convolution_4 = torch.ops.aten.convolution.default(relu_3, primals_13, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  primals_13 = None
    _native_batch_norm_legit_no_training_4 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_4, primals_14, primals_15, primals_75, primals_76, 0.1, 1e-05);  primals_14 = primals_15 = primals_75 = primals_76 = None
    getitem_12 = _native_batch_norm_legit_no_training_4[0]
    getitem_13 = _native_batch_norm_legit_no_training_4[1]
    getitem_14 = _native_batch_norm_legit_no_training_4[2];  _native_batch_norm_legit_no_training_4 = None
    add_1 = torch.ops.aten.add.Tensor(getitem_12, relu_2);  getitem_12 = None
    relu_4 = torch.ops.aten.relu.default(add_1);  add_1 = None
    convolution_5 = torch.ops.aten.convolution.default(relu_4, primals_16, None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1);  primals_16 = None
    convolution_7 = torch.ops.aten.convolution.default(relu_4, primals_22, None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1);  primals_22 = None
    _native_batch_norm_legit_no_training_5 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_5, primals_17, primals_18, primals_78, primals_79, 0.1, 1e-05);  primals_17 = primals_18 = primals_78 = primals_79 = None
    _native_batch_norm_legit_no_training_7 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_7, primals_23, primals_24, primals_84, primals_85, 0.1, 1e-05);  primals_23 = primals_24 = primals_84 = primals_85 = None
    getitem_15 = _native_batch_norm_legit_no_training_5[0]
    getitem_16 = _native_batch_norm_legit_no_training_5[1]
    getitem_17 = _native_batch_norm_legit_no_training_5[2];  _native_batch_norm_legit_no_training_5 = None
    getitem_21 = _native_batch_norm_legit_no_training_7[0]
    getitem_22 = _native_batch_norm_legit_no_training_7[1]
    getitem_23 = _native_batch_norm_legit_no_training_7[2];  _native_batch_norm_legit_no_training_7 = None
    relu_5 = torch.ops.aten.relu.default(getitem_15);  getitem_15 = None
    convolution_6 = torch.ops.aten.convolution.default(relu_5, primals_19, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  primals_19 = None
    _native_batch_norm_legit_no_training_6 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_6, primals_20, primals_21, primals_81, primals_82, 0.1, 1e-05);  primals_20 = primals_21 = primals_81 = primals_82 = None
    getitem_18 = _native_batch_norm_legit_no_training_6[0]
    getitem_19 = _native_batch_norm_legit_no_training_6[1]
    getitem_20 = _native_batch_norm_legit_no_training_6[2];  _native_batch_norm_legit_no_training_6 = None
    add_2 = torch.ops.aten.add.Tensor(getitem_18, getitem_21);  getitem_18 = getitem_21 = None
    relu_6 = torch.ops.aten.relu.default(add_2);  add_2 = None
    convolution_8 = torch.ops.aten.convolution.default(relu_6, primals_25, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  primals_25 = None
    _native_batch_norm_legit_no_training_8 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_8, primals_26, primals_27, primals_87, primals_88, 0.1, 1e-05);  primals_26 = primals_27 = primals_87 = primals_88 = None
    getitem_24 = _native_batch_norm_legit_no_training_8[0]
    getitem_25 = _native_batch_norm_legit_no_training_8[1]
    getitem_26 = _native_batch_norm_legit_no_training_8[2];  _native_batch_norm_legit_no_training_8 = None
    relu_7 = torch.ops.aten.relu.default(getitem_24);  getitem_24 = None
    convolution_9 = torch.ops.aten.convolution.default(relu_7, primals_28, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  primals_28 = None
    _native_batch_norm_legit_no_training_9 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_9, primals_29, primals_30, primals_90, primals_91, 0.1, 1e-05);  primals_29 = primals_30 = primals_90 = primals_91 = None
    getitem_27 = _native_batch_norm_legit_no_training_9[0]
    getitem_28 = _native_batch_norm_legit_no_training_9[1]
    getitem_29 = _native_batch_norm_legit_no_training_9[2];  _native_batch_norm_legit_no_training_9 = None
    add_3 = torch.ops.aten.add.Tensor(getitem_27, relu_6);  getitem_27 = None
    relu_8 = torch.ops.aten.relu.default(add_3);  add_3 = None
    convolution_10 = torch.ops.aten.convolution.default(relu_8, primals_31, None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1);  primals_31 = None
    convolution_12 = torch.ops.aten.convolution.default(relu_8, primals_37, None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1);  primals_37 = None
    _native_batch_norm_legit_no_training_10 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_10, primals_32, primals_33, primals_93, primals_94, 0.1, 1e-05);  primals_32 = primals_33 = primals_93 = primals_94 = None
    _native_batch_norm_legit_no_training_12 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_12, primals_38, primals_39, primals_99, primals_100, 0.1, 1e-05);  primals_38 = primals_39 = primals_99 = primals_100 = None
    getitem_30 = _native_batch_norm_legit_no_training_10[0]
    getitem_31 = _native_batch_norm_legit_no_training_10[1]
    getitem_32 = _native_batch_norm_legit_no_training_10[2];  _native_batch_norm_legit_no_training_10 = None
    getitem_36 = _native_batch_norm_legit_no_training_12[0]
    getitem_37 = _native_batch_norm_legit_no_training_12[1]
    getitem_38 = _native_batch_norm_legit_no_training_12[2];  _native_batch_norm_legit_no_training_12 = None
    relu_9 = torch.ops.aten.relu.default(getitem_30);  getitem_30 = None
    convolution_11 = torch.ops.aten.convolution.default(relu_9, primals_34, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  primals_34 = None
    _native_batch_norm_legit_no_training_11 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_11, primals_35, primals_36, primals_96, primals_97, 0.1, 1e-05);  primals_35 = primals_36 = primals_96 = primals_97 = None
    getitem_33 = _native_batch_norm_legit_no_training_11[0]
    getitem_34 = _native_batch_norm_legit_no_training_11[1]
    getitem_35 = _native_batch_norm_legit_no_training_11[2];  _native_batch_norm_legit_no_training_11 = None
    add_4 = torch.ops.aten.add.Tensor(getitem_33, getitem_36);  getitem_33 = getitem_36 = None
    relu_10 = torch.ops.aten.relu.default(add_4);  add_4 = None
    convolution_13 = torch.ops.aten.convolution.default(relu_10, primals_40, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  primals_40 = None
    _native_batch_norm_legit_no_training_13 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_13, primals_41, primals_42, primals_102, primals_103, 0.1, 1e-05);  primals_41 = primals_42 = primals_102 = primals_103 = None
    getitem_39 = _native_batch_norm_legit_no_training_13[0]
    getitem_40 = _native_batch_norm_legit_no_training_13[1]
    getitem_41 = _native_batch_norm_legit_no_training_13[2];  _native_batch_norm_legit_no_training_13 = None
    relu_11 = torch.ops.aten.relu.default(getitem_39);  getitem_39 = None
    convolution_14 = torch.ops.aten.convolution.default(relu_11, primals_43, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  primals_43 = None
    _native_batch_norm_legit_no_training_14 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_14, primals_44, primals_45, primals_105, primals_106, 0.1, 1e-05);  primals_44 = primals_45 = primals_105 = primals_106 = None
    getitem_42 = _native_batch_norm_legit_no_training_14[0]
    getitem_43 = _native_batch_norm_legit_no_training_14[1]
    getitem_44 = _native_batch_norm_legit_no_training_14[2];  _native_batch_norm_legit_no_training_14 = None
    add_5 = torch.ops.aten.add.Tensor(getitem_42, relu_10);  getitem_42 = None
    relu_12 = torch.ops.aten.relu.default(add_5);  add_5 = None
    convolution_15 = torch.ops.aten.convolution.default(relu_12, primals_46, None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1);  primals_46 = None
    convolution_17 = torch.ops.aten.convolution.default(relu_12, primals_52, None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1);  primals_52 = None
    _native_batch_norm_legit_no_training_15 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_15, primals_47, primals_48, primals_108, primals_109, 0.1, 1e-05);  primals_47 = primals_48 = primals_108 = primals_109 = None
    _native_batch_norm_legit_no_training_17 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_17, primals_53, primals_54, primals_114, primals_115, 0.1, 1e-05);  primals_53 = primals_54 = primals_114 = primals_115 = None
    getitem_45 = _native_batch_norm_legit_no_training_15[0]
    getitem_46 = _native_batch_norm_legit_no_training_15[1]
    getitem_47 = _native_batch_norm_legit_no_training_15[2];  _native_batch_norm_legit_no_training_15 = None
    getitem_51 = _native_batch_norm_legit_no_training_17[0]
    getitem_52 = _native_batch_norm_legit_no_training_17[1]
    getitem_53 = _native_batch_norm_legit_no_training_17[2];  _native_batch_norm_legit_no_training_17 = None
    relu_13 = torch.ops.aten.relu.default(getitem_45);  getitem_45 = None
    convolution_16 = torch.ops.aten.convolution.default(relu_13, primals_49, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  primals_49 = None
    _native_batch_norm_legit_no_training_16 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_16, primals_50, primals_51, primals_111, primals_112, 0.1, 1e-05);  primals_50 = primals_51 = primals_111 = primals_112 = None
    getitem_48 = _native_batch_norm_legit_no_training_16[0]
    getitem_49 = _native_batch_norm_legit_no_training_16[1]
    getitem_50 = _native_batch_norm_legit_no_training_16[2];  _native_batch_norm_legit_no_training_16 = None
    add_6 = torch.ops.aten.add.Tensor(getitem_48, getitem_51);  getitem_48 = getitem_51 = None
    relu_14 = torch.ops.aten.relu.default(add_6);  add_6 = None
    convolution_18 = torch.ops.aten.convolution.default(relu_14, primals_55, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  primals_55 = None
    _native_batch_norm_legit_no_training_18 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_18, primals_56, primals_57, primals_117, primals_118, 0.1, 1e-05);  primals_56 = primals_57 = primals_117 = primals_118 = None
    getitem_54 = _native_batch_norm_legit_no_training_18[0]
    getitem_55 = _native_batch_norm_legit_no_training_18[1]
    getitem_56 = _native_batch_norm_legit_no_training_18[2];  _native_batch_norm_legit_no_training_18 = None
    relu_15 = torch.ops.aten.relu.default(getitem_54);  getitem_54 = None
    convolution_19 = torch.ops.aten.convolution.default(relu_15, primals_58, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1);  primals_58 = None
    _native_batch_norm_legit_no_training_19 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_19, primals_59, primals_60, primals_120, primals_121, 0.1, 1e-05);  primals_59 = primals_60 = primals_120 = primals_121 = None
    getitem_57 = _native_batch_norm_legit_no_training_19[0]
    getitem_58 = _native_batch_norm_legit_no_training_19[1]
    getitem_59 = _native_batch_norm_legit_no_training_19[2];  _native_batch_norm_legit_no_training_19 = None
    add_7 = torch.ops.aten.add.Tensor(getitem_57, relu_14);  getitem_57 = None
    relu_16 = torch.ops.aten.relu.default(add_7);  add_7 = None
    mean = torch.ops.aten.mean.dim(relu_16, [-1, -2], True)
    view = torch.ops.aten.view.default(mean, [4, 512]);  mean = None
    addmm = torch.ops.aten.addmm.default(primals_62, view, t);  primals_62 = None
    return (t, convolution, getitem_1, getitem_2, relu, max_pool2d_forward, convolution_1, getitem_4, getitem_5, relu_1, convolution_2, getitem_7, getitem_8, relu_2, convolution_3, getitem_10, getitem_11, relu_3, convolution_4, getitem_13, getitem_14, relu_4, convolution_5, convolution_7, getitem_16, getitem_17, getitem_22, getitem_23, relu_5, convolution_6, getitem_19, getitem_20, relu_6, convolution_8, getitem_25, getitem_26, relu_7, convolution_9, getitem_28, getitem_29, relu_8, convolution_10, convolution_12, getitem_31, getitem_32, getitem_37, getitem_38, relu_9, convolution_11, getitem_34, getitem_35, relu_10, convolution_13, getitem_40, getitem_41, relu_11, convolution_14, getitem_43, getitem_44, relu_12, convolution_15, convolution_17, getitem_46, getitem_47, getitem_52, getitem_53, relu_13, convolution_16, getitem_49, getitem_50, relu_14, convolution_18, getitem_55, getitem_56, relu_15, convolution_19, getitem_58, getitem_59, relu_16, view, addmm)

torchxla_trace_once

def forward(self, l_x_ : torch.Tensor):
    l__self___conv1 = self.L__self___conv1(l_x_);  l_x_ = None
    l__self___bn1 = self.L__self___bn1(l__self___conv1);  l__self___conv1 = None
    l__self___relu = self.L__self___relu(l__self___bn1);  l__self___bn1 = None
    l__self___maxpool = self.L__self___maxpool(l__self___relu);  l__self___relu = None
    getattr_l__self___layer1___0___conv1 = self.getattr_L__self___layer1___0___conv1(l__self___maxpool)
    getattr_l__self___layer1___0___bn1 = self.getattr_L__self___layer1___0___bn1(getattr_l__self___layer1___0___conv1);  getattr_l__self___layer1___0___conv1 = None
    getattr_l__self___layer1___0___relu = self.getattr_L__self___layer1___0___relu(getattr_l__self___layer1___0___bn1);  getattr_l__self___layer1___0___bn1 = None
    getattr_l__self___layer1___0___conv2 = self.getattr_L__self___layer1___0___conv2(getattr_l__self___layer1___0___relu);  getattr_l__self___layer1___0___relu = None
    getattr_l__self___layer1___0___bn2 = self.getattr_L__self___layer1___0___bn2(getattr_l__self___layer1___0___conv2);  getattr_l__self___layer1___0___conv2 = None
    getattr_l__self___layer1___0___bn2 += l__self___maxpool;  iadd = getattr_l__self___layer1___0___bn2;  getattr_l__self___layer1___0___bn2 = l__self___maxpool = None
    getattr_l__self___layer1___0___relu_1 = self.getattr_L__self___layer1___0___relu(iadd);  iadd = None
    getattr_l__self___layer1___1___conv1 = self.getattr_L__self___layer1___1___conv1(getattr_l__self___layer1___0___relu_1)
    getattr_l__self___layer1___1___bn1 = self.getattr_L__self___layer1___1___bn1(getattr_l__self___layer1___1___conv1);  getattr_l__self___layer1___1___conv1 = None
    getattr_l__self___layer1___1___relu = self.getattr_L__self___layer1___1___relu(getattr_l__self___layer1___1___bn1);  getattr_l__self___layer1___1___bn1 = None
    getattr_l__self___layer1___1___conv2 = self.getattr_L__self___layer1___1___conv2(getattr_l__self___layer1___1___relu);  getattr_l__self___layer1___1___relu = None
    getattr_l__self___layer1___1___bn2 = self.getattr_L__self___layer1___1___bn2(getattr_l__self___layer1___1___conv2);  getattr_l__self___layer1___1___conv2 = None
    getattr_l__self___layer1___1___bn2 += getattr_l__self___layer1___0___relu_1;  iadd_1 = getattr_l__self___layer1___1___bn2;  getattr_l__self___layer1___1___bn2 = getattr_l__self___layer1___0___relu_1 = None
    getattr_l__self___layer1___1___relu_1 = self.getattr_L__self___layer1___1___relu(iadd_1);  iadd_1 = None
    getattr_l__self___layer2___0___conv1 = self.getattr_L__self___layer2___0___conv1(getattr_l__self___layer1___1___relu_1)
    getattr_l__self___layer2___0___downsample_0 = self.getattr_L__self___layer2___0___downsample_0(getattr_l__self___layer1___1___relu_1);  getattr_l__self___layer1___1___relu_1 = None
    getattr_l__self___layer2___0___bn1 = self.getattr_L__self___layer2___0___bn1(getattr_l__self___layer2___0___conv1);  getattr_l__self___layer2___0___conv1 = None
    getattr_l__self___layer2___0___downsample_1 = self.getattr_L__self___layer2___0___downsample_1(getattr_l__self___layer2___0___downsample_0);  getattr_l__self___layer2___0___downsample_0 = None
    getattr_l__self___layer2___0___relu = self.getattr_L__self___layer2___0___relu(getattr_l__self___layer2___0___bn1);  getattr_l__self___layer2___0___bn1 = None
    getattr_l__self___layer2___0___conv2 = self.getattr_L__self___layer2___0___conv2(getattr_l__self___layer2___0___relu);  getattr_l__self___layer2___0___relu = None
    getattr_l__self___layer2___0___bn2 = self.getattr_L__self___layer2___0___bn2(getattr_l__self___layer2___0___conv2);  getattr_l__self___layer2___0___conv2 = None
    getattr_l__self___layer2___0___bn2 += getattr_l__self___layer2___0___downsample_1;  iadd_2 = getattr_l__self___layer2___0___bn2;  getattr_l__self___layer2___0___bn2 = getattr_l__self___layer2___0___downsample_1 = None
    getattr_l__self___layer2___0___relu_1 = self.getattr_L__self___layer2___0___relu(iadd_2);  iadd_2 = None
    getattr_l__self___layer2___1___conv1 = self.getattr_L__self___layer2___1___conv1(getattr_l__self___layer2___0___relu_1)
    getattr_l__self___layer2___1___bn1 = self.getattr_L__self___layer2___1___bn1(getattr_l__self___layer2___1___conv1);  getattr_l__self___layer2___1___conv1 = None
    getattr_l__self___layer2___1___relu = self.getattr_L__self___layer2___1___relu(getattr_l__self___layer2___1___bn1);  getattr_l__self___layer2___1___bn1 = None
    getattr_l__self___layer2___1___conv2 = self.getattr_L__self___layer2___1___conv2(getattr_l__self___layer2___1___relu);  getattr_l__self___layer2___1___relu = None
    getattr_l__self___layer2___1___bn2 = self.getattr_L__self___layer2___1___bn2(getattr_l__self___layer2___1___conv2);  getattr_l__self___layer2___1___conv2 = None
    getattr_l__self___layer2___1___bn2 += getattr_l__self___layer2___0___relu_1;  iadd_3 = getattr_l__self___layer2___1___bn2;  getattr_l__self___layer2___1___bn2 = getattr_l__self___layer2___0___relu_1 = None
    getattr_l__self___layer2___1___relu_1 = self.getattr_L__self___layer2___1___relu(iadd_3);  iadd_3 = None
    getattr_l__self___layer3___0___conv1 = self.getattr_L__self___layer3___0___conv1(getattr_l__self___layer2___1___relu_1)
    getattr_l__self___layer3___0___downsample_0 = self.getattr_L__self___layer3___0___downsample_0(getattr_l__self___layer2___1___relu_1);  getattr_l__self___layer2___1___relu_1 = None
    getattr_l__self___layer3___0___bn1 = self.getattr_L__self___layer3___0___bn1(getattr_l__self___layer3___0___conv1);  getattr_l__self___layer3___0___conv1 = None
    getattr_l__self___layer3___0___downsample_1 = self.getattr_L__self___layer3___0___downsample_1(getattr_l__self___layer3___0___downsample_0);  getattr_l__self___layer3___0___downsample_0 = None
    getattr_l__self___layer3___0___relu = self.getattr_L__self___layer3___0___relu(getattr_l__self___layer3___0___bn1);  getattr_l__self___layer3___0___bn1 = None
    getattr_l__self___layer3___0___conv2 = self.getattr_L__self___layer3___0___conv2(getattr_l__self___layer3___0___relu);  getattr_l__self___layer3___0___relu = None
    getattr_l__self___layer3___0___bn2 = self.getattr_L__self___layer3___0___bn2(getattr_l__self___layer3___0___conv2);  getattr_l__self___layer3___0___conv2 = None
    getattr_l__self___layer3___0___bn2 += getattr_l__self___layer3___0___downsample_1;  iadd_4 = getattr_l__self___layer3___0___bn2;  getattr_l__self___layer3___0___bn2 = getattr_l__self___layer3___0___downsample_1 = None
    getattr_l__self___layer3___0___relu_1 = self.getattr_L__self___layer3___0___relu(iadd_4);  iadd_4 = None
    getattr_l__self___layer3___1___conv1 = self.getattr_L__self___layer3___1___conv1(getattr_l__self___layer3___0___relu_1)
    getattr_l__self___layer3___1___bn1 = self.getattr_L__self___layer3___1___bn1(getattr_l__self___layer3___1___conv1);  getattr_l__self___layer3___1___conv1 = None
    getattr_l__self___layer3___1___relu = self.getattr_L__self___layer3___1___relu(getattr_l__self___layer3___1___bn1);  getattr_l__self___layer3___1___bn1 = None
    getattr_l__self___layer3___1___conv2 = self.getattr_L__self___layer3___1___conv2(getattr_l__self___layer3___1___relu);  getattr_l__self___layer3___1___relu = None
    getattr_l__self___layer3___1___bn2 = self.getattr_L__self___layer3___1___bn2(getattr_l__self___layer3___1___conv2);  getattr_l__self___layer3___1___conv2 = None
    getattr_l__self___layer3___1___bn2 += getattr_l__self___layer3___0___relu_1;  iadd_5 = getattr_l__self___layer3___1___bn2;  getattr_l__self___layer3___1___bn2 = getattr_l__self___layer3___0___relu_1 = None
    getattr_l__self___layer3___1___relu_1 = self.getattr_L__self___layer3___1___relu(iadd_5);  iadd_5 = None
    getattr_l__self___layer4___0___conv1 = self.getattr_L__self___layer4___0___conv1(getattr_l__self___layer3___1___relu_1)
    getattr_l__self___layer4___0___downsample_0 = self.getattr_L__self___layer4___0___downsample_0(getattr_l__self___layer3___1___relu_1);  getattr_l__self___layer3___1___relu_1 = None
    getattr_l__self___layer4___0___bn1 = self.getattr_L__self___layer4___0___bn1(getattr_l__self___layer4___0___conv1);  getattr_l__self___layer4___0___conv1 = None
    getattr_l__self___layer4___0___downsample_1 = self.getattr_L__self___layer4___0___downsample_1(getattr_l__self___layer4___0___downsample_0);  getattr_l__self___layer4___0___downsample_0 = None
    getattr_l__self___layer4___0___relu = self.getattr_L__self___layer4___0___relu(getattr_l__self___layer4___0___bn1);  getattr_l__self___layer4___0___bn1 = None
    getattr_l__self___layer4___0___conv2 = self.getattr_L__self___layer4___0___conv2(getattr_l__self___layer4___0___relu);  getattr_l__self___layer4___0___relu = None
    getattr_l__self___layer4___0___bn2 = self.getattr_L__self___layer4___0___bn2(getattr_l__self___layer4___0___conv2);  getattr_l__self___layer4___0___conv2 = None
    getattr_l__self___layer4___0___bn2 += getattr_l__self___layer4___0___downsample_1;  iadd_6 = getattr_l__self___layer4___0___bn2;  getattr_l__self___layer4___0___bn2 = getattr_l__self___layer4___0___downsample_1 = None
    getattr_l__self___layer4___0___relu_1 = self.getattr_L__self___layer4___0___relu(iadd_6);  iadd_6 = None
    getattr_l__self___layer4___1___conv1 = self.getattr_L__self___layer4___1___conv1(getattr_l__self___layer4___0___relu_1)
    getattr_l__self___layer4___1___bn1 = self.getattr_L__self___layer4___1___bn1(getattr_l__self___layer4___1___conv1);  getattr_l__self___layer4___1___conv1 = None
    getattr_l__self___layer4___1___relu = self.getattr_L__self___layer4___1___relu(getattr_l__self___layer4___1___bn1);  getattr_l__self___layer4___1___bn1 = None
    getattr_l__self___layer4___1___conv2 = self.getattr_L__self___layer4___1___conv2(getattr_l__self___layer4___1___relu);  getattr_l__self___layer4___1___relu = None
    getattr_l__self___layer4___1___bn2 = self.getattr_L__self___layer4___1___bn2(getattr_l__self___layer4___1___conv2);  getattr_l__self___layer4___1___conv2 = None
    getattr_l__self___layer4___1___bn2 += getattr_l__self___layer4___0___relu_1;  iadd_7 = getattr_l__self___layer4___1___bn2;  getattr_l__self___layer4___1___bn2 = getattr_l__self___layer4___0___relu_1 = None
    getattr_l__self___layer4___1___relu_1 = self.getattr_L__self___layer4___1___relu(iadd_7);  iadd_7 = None
    l__self___avgpool = self.L__self___avgpool(getattr_l__self___layer4___1___relu_1);  getattr_l__self___layer4___1___relu_1 = None
    flatten = torch.flatten(l__self___avgpool, 1);  l__self___avgpool = None
    l__self___fc = self.L__self___fc(flatten);  flatten = None
    return l__self___fc

The openxla one is also quite a bit longer and looks nicer. Let me play with create_functionalized_graph a bit.
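One way to quantify "quite different" is to summarize each printed FX graph: count the `torch.ops.<namespace>.<op>` calls and the number of values in its `return` statement. The sketch below works on the printed text of a graph, such as the two dumps above; `summarize_fx_dump` is a hypothetical helper. Module calls like `self.L__self___conv1(...)` in the torchxla_trace_once dump are not counted as ops. Even by eye, the return count stands out: the openxla graph returns many intermediates (`t`, `convolution`, `getitem_1`, ...) alongside `addmm`, while the torchxla_trace_once graph returns only `l__self___fc`.

```python
import re
from collections import Counter
from typing import Dict, Union

# torch.ops.<namespace>.<op>, e.g. torch.ops.aten.convolution.default -> ("aten", "convolution")
_OP_RE = re.compile(r"torch\.ops\.(\w+)\.(\w+)")
# The body of a "return ..." line, without trailing whitespace.
_RETURN_RE = re.compile(r"^\s*return\s+(.+?)\s*$", re.MULTILINE)


def summarize_fx_dump(text: str) -> Dict[str, Union[Counter, int]]:
    """Count torch.ops calls and the values returned by the last return statement."""
    ops = Counter(f"{ns}.{name}" for ns, name in _OP_RE.findall(text))
    returns = _RETURN_RE.findall(text)
    num_outputs = 0
    if returns:
        body = returns[-1]
        if body.startswith("(") and body.endswith(")"):
            body = body[1:-1]  # unwrap a returned tuple
        num_outputs = len([item for item in body.split(",") if item.strip()])
    return {"ops": ops, "num_outputs": num_outputs}


if __name__ == "__main__":
    dump = """def forward(self, primals_1, primals_2):
    t = torch.ops.aten.t.default(primals_1);  primals_1 = None
    relu = torch.ops.aten.relu.default(primals_2)
    return (t, relu)
"""
    summary = summarize_fx_dump(dump)
    print(summary["ops"].most_common(), summary["num_outputs"])
```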

JackCaoG commented 1 year ago

Hmm, at first it didn't seem like create_functionalized_graph was being called, but I was wrong: this function is called when using the aot backend. I'm trying to see if there is a way to bypass it.

bdhirsh commented 1 year ago

AOTAutograd will do a few things in the inference path:

Are the warnings posted somewhere? I don't see a link to them. On the perf difference, though: this is a bit of a shot in the dark, but I wonder if the copy_() input mutations that show up after the batchnorm call could cause perf issues with XLA?

shunting314 commented 1 year ago

@JackCaoG the graph difference is probably due to the make_fx call in create_functionalized_graph (link).

JackCaoG commented 1 year ago

The warning message should be fixed by https://github.com/pytorch/pytorch/pull/107260. @GleasonK after my PR is merged, you can use openxla_eval while I debug this aot backend speed regression.

JackCaoG commented 1 year ago

I was able to compare the IR and the counters. Two things stand out:

  1. The aot backend uses xla::_native_batch_norm_legit instead of xla::native_batch_norm. I think this is fine; I vaguely remember implementing _native_batch_norm_legit by calling native_batch_norm.
  2. There are a lot more clones in the aot backend. For ResNet-18 with the aot backend:
    Counter: xla::clone
      Value: 103
    Counter: xla::empty_symint
      Value: 80

@bdhirsh do you know where these clones come from?

For comparison, the non-aot backend:

Counter: xla::clone
  Value: 1
Counter: xla::empty_symint
  Value: 120

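Counter comparisons like this one can be scripted. The sketch below parses text in the `Counter: <name>` / `Value: <n>` shape quoted above (the shape printed by torch_xla's metrics report, e.g. `torch_xla.debug.metrics.metrics_report()`) and shows how each counter changed between two runs. `parse_counters` and `diff_counters` are hypothetical helpers; real reports also contain `Metric:` sections, which this regex simply skips.

```python
import re
from typing import Dict, Tuple

# A "Counter: <name>" line followed by its "Value: <n>" line; "Metric:" sections are ignored.
_COUNTER_RE = re.compile(r"Counter:\s*(\S+)\s*\n\s*Value:\s*(\d+)")


def parse_counters(report: str) -> Dict[str, int]:
    """Map each counter name in a metrics-report-style string to its integer value."""
    return {name: int(value) for name, value in _COUNTER_RE.findall(report)}


def diff_counters(before: str, after: str) -> Dict[str, Tuple[int, int, int]]:
    """Map counter name -> (before, after, after - before); a missing counter counts as 0."""
    a, b = parse_counters(before), parse_counters(after)
    return {
        name: (a.get(name, 0), b.get(name, 0), b.get(name, 0) - a.get(name, 0))
        for name in sorted(a.keys() | b.keys())
    }


if __name__ == "__main__":
    # The ResNet-18 counter values quoted in this thread.
    non_aot = "Counter: xla::clone\n  Value: 1\nCounter: xla::empty_symint\n  Value: 120\n"
    aot = "Counter: xla::clone\n  Value: 103\nCounter: xla::empty_symint\n  Value: 80\n"
    for name, (before, after, delta) in diff_counters(non_aot, aot).items():
        print(f"{name}: {before} -> {after} ({delta:+d})")
    # xla::clone: 1 -> 103 (+102)
    # xla::empty_symint: 120 -> 80 (-40)
```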
JackCaoG commented 1 year ago

IR for openxla (aot backend):

IR {
  %0 = f32[64,3,7,7]{0,3,2,1} xla::device_data(), xla_shape=f32[64,3,7,7]{0,3,2,1}, device=TPU:0
  %1 = f32[4,3,224,224]{3,2,1,0} xla::device_data(), xla_shape=f32[4,3,224,224]{3,2,1,0}, device=TPU:0
  %2 = f32[4,64,112,112]{3,2,1,0} aten::convolution_overrideable(%1, %0), xla_shape=f32[4,64,112,112]{3,2,1,0}, stride=(2, 2), padding=(3, 3), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=0
  %3 = f32[1000,512]{1,0} xla::device_data(), xla_shape=f32[1000,512]{1,0}, device=TPU:0
  %4 = f32[512,1000]{0,1} aten::permute(%3), xla_shape=f32[512,1000]{0,1}, dims=(1, 0), ROOT=1
  %5 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %6 = f32[0]{0} aten::expand(%5), xla_shape=f32[0]{0}, size=(0), ROOT=2
  %7 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %8 = f32[0]{0} aten::expand(%7), xla_shape=f32[0]{0}, size=(0), ROOT=3
  %9 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %10 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %11 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %12 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %13 = (f32[4,64,112,112]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}) aten::native_batch_norm(%2, %12, %11, %10, %9), num_outputs=4, xla_shape=(f32[4,64,112,112]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}), training=0, eps=1e-05
  %14 = f32[4,64,112,112]{3,2,1,0} aten::relu(%13.0), xla_shape=f32[4,64,112,112]{3,2,1,0}, ROOT=4
  %15 = (f32[4,64,56,56]{3,2,1,0}, u32[4,64,56,56]{3,2,1,0}) aten::max_pool2d(%14), num_outputs=2, xla_shape=(f32[4,64,56,56]{3,2,1,0}, u32[4,64,56,56]{3,2,1,0}), spatial_dim_count=2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), ceil_mode=0, ROOT=5
  %16 = f32[64,64,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[64,64,3,3]{1,0,3,2}, device=TPU:0
  %17 = f32[4,64,56,56]{3,2,1,0} aten::convolution_overrideable(%15.0, %16), xla_shape=f32[4,64,56,56]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=6
  %18 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %19 = f32[0]{0} aten::expand(%18), xla_shape=f32[0]{0}, size=(0), ROOT=7
  %20 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %21 = f32[0]{0} aten::expand(%20), xla_shape=f32[0]{0}, size=(0), ROOT=8
  %22 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %23 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %24 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %25 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %26 = (f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}) aten::native_batch_norm(%17, %25, %24, %23, %22), num_outputs=4, xla_shape=(f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}), training=0, eps=1e-05
  %27 = f32[4,64,56,56]{3,2,1,0} aten::relu(%26.0), xla_shape=f32[4,64,56,56]{3,2,1,0}, ROOT=9
  %28 = f32[64,64,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[64,64,3,3]{1,0,3,2}, device=TPU:0
  %29 = f32[4,64,56,56]{3,2,1,0} aten::convolution_overrideable(%27, %28), xla_shape=f32[4,64,56,56]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=10
  %30 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %31 = f32[0]{0} aten::expand(%30), xla_shape=f32[0]{0}, size=(0), ROOT=11
  %32 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %33 = f32[0]{0} aten::expand(%32), xla_shape=f32[0]{0}, size=(0), ROOT=12
  %34 = f32[] prim::Constant(), xla_shape=f32[], value=1
  %35 = f32[4,64,56,56]{3,2,1,0} aten::expand(%34), xla_shape=f32[4,64,56,56]{3,2,1,0}, size=(4, 64, 56, 56)
  %36 = f32[4,64,56,56]{3,2,1,0} aten::mul(%15.0, %35), xla_shape=f32[4,64,56,56]{3,2,1,0}
  %37 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %38 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %39 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %40 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %41 = (f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}) aten::native_batch_norm(%29, %40, %39, %38, %37), num_outputs=4, xla_shape=(f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}), training=0, eps=1e-05
  %42 = f32[4,64,56,56]{3,2,1,0} aten::add(%41.0, %36), xla_shape=f32[4,64,56,56]{3,2,1,0}
  %43 = f32[4,64,56,56]{3,2,1,0} aten::relu(%42), xla_shape=f32[4,64,56,56]{3,2,1,0}, ROOT=13
  %44 = f32[64,64,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[64,64,3,3]{1,0,3,2}, device=TPU:0
  %45 = f32[4,64,56,56]{3,2,1,0} aten::convolution_overrideable(%43, %44), xla_shape=f32[4,64,56,56]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=14
  %46 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %47 = f32[0]{0} aten::expand(%46), xla_shape=f32[0]{0}, size=(0), ROOT=15
  %48 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %49 = f32[0]{0} aten::expand(%48), xla_shape=f32[0]{0}, size=(0), ROOT=16
  %50 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %51 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %52 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %53 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %54 = (f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}) aten::native_batch_norm(%45, %53, %52, %51, %50), num_outputs=4, xla_shape=(f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}), training=0, eps=1e-05
  %55 = f32[4,64,56,56]{3,2,1,0} aten::relu(%54.0), xla_shape=f32[4,64,56,56]{3,2,1,0}, ROOT=17
  %56 = f32[64,64,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[64,64,3,3]{1,0,3,2}, device=TPU:0
  %57 = f32[4,64,56,56]{3,2,1,0} aten::convolution_overrideable(%55, %56), xla_shape=f32[4,64,56,56]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=18
  %58 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %59 = f32[0]{0} aten::expand(%58), xla_shape=f32[0]{0}, size=(0), ROOT=19
  %60 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %61 = f32[0]{0} aten::expand(%60), xla_shape=f32[0]{0}, size=(0), ROOT=20
  %62 = f32[] prim::Constant(), xla_shape=f32[], value=1
  %63 = f32[4,64,56,56]{3,2,1,0} aten::expand(%62), xla_shape=f32[4,64,56,56]{3,2,1,0}, size=(4, 64, 56, 56)
  %64 = f32[4,64,56,56]{3,2,1,0} aten::mul(%43, %63), xla_shape=f32[4,64,56,56]{3,2,1,0}
  %65 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %66 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %67 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %68 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %69 = (f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}) aten::native_batch_norm(%57, %68, %67, %66, %65), num_outputs=4, xla_shape=(f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}), training=0, eps=1e-05
  %70 = f32[4,64,56,56]{3,2,1,0} aten::add(%69.0, %64), xla_shape=f32[4,64,56,56]{3,2,1,0}
  %71 = f32[4,64,56,56]{3,2,1,0} aten::relu(%70), xla_shape=f32[4,64,56,56]{3,2,1,0}, ROOT=21
  %72 = f32[128,64,3,3]{0,1,3,2} xla::device_data(), xla_shape=f32[128,64,3,3]{0,1,3,2}, device=TPU:0
  %73 = f32[4,128,28,28]{3,2,1,0} aten::convolution_overrideable(%71, %72), xla_shape=f32[4,128,28,28]{3,2,1,0}, stride=(2, 2), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=22
  %74 = f32[128,64,1,1]{0,1,3,2} xla::device_data(), xla_shape=f32[128,64,1,1]{0,1,3,2}, device=TPU:0
  %75 = f32[4,128,28,28]{3,2,1,0} aten::convolution_overrideable(%71, %74), xla_shape=f32[4,128,28,28]{3,2,1,0}, stride=(2, 2), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=23
  %76 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %77 = f32[0]{0} aten::expand(%76), xla_shape=f32[0]{0}, size=(0), ROOT=24
  %78 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %79 = f32[0]{0} aten::expand(%78), xla_shape=f32[0]{0}, size=(0), ROOT=25
  %80 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %81 = f32[0]{0} aten::expand(%80), xla_shape=f32[0]{0}, size=(0), ROOT=26
  %82 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %83 = f32[0]{0} aten::expand(%82), xla_shape=f32[0]{0}, size=(0), ROOT=27
  %84 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %85 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %86 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %87 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %88 = (f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}) aten::native_batch_norm(%73, %87, %86, %85, %84), num_outputs=4, xla_shape=(f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}), training=0, eps=1e-05
  %89 = f32[4,128,28,28]{3,2,1,0} aten::relu(%88.0), xla_shape=f32[4,128,28,28]{3,2,1,0}, ROOT=28
  %90 = f32[128,128,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[128,128,3,3]{1,0,3,2}, device=TPU:0
  %91 = f32[4,128,28,28]{3,2,1,0} aten::convolution_overrideable(%89, %90), xla_shape=f32[4,128,28,28]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=29
  %92 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %93 = f32[0]{0} aten::expand(%92), xla_shape=f32[0]{0}, size=(0), ROOT=30
  %94 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %95 = f32[0]{0} aten::expand(%94), xla_shape=f32[0]{0}, size=(0), ROOT=31
  %96 = f32[] prim::Constant(), xla_shape=f32[], value=1
  %97 = f32[4,128,28,28]{3,2,1,0} aten::expand(%96), xla_shape=f32[4,128,28,28]{3,2,1,0}, size=(4, 128, 28, 28)
  %98 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %99 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %100 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %101 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %102 = (f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}) aten::native_batch_norm(%75, %101, %100, %99, %98), num_outputs=4, xla_shape=(f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}), training=0, eps=1e-05
  %103 = f32[4,128,28,28]{3,2,1,0} aten::mul(%102.0, %97), xla_shape=f32[4,128,28,28]{3,2,1,0}
  %104 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %105 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %106 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %107 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %108 = (f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}) aten::native_batch_norm(%91, %107, %106, %105, %104), num_outputs=4, xla_shape=(f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}), training=0, eps=1e-05
  %109 = f32[4,128,28,28]{3,2,1,0} aten::add(%108.0, %103), xla_shape=f32[4,128,28,28]{3,2,1,0}
  %110 = f32[4,128,28,28]{3,2,1,0} aten::relu(%109), xla_shape=f32[4,128,28,28]{3,2,1,0}, ROOT=32
  %111 = f32[128,128,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[128,128,3,3]{1,0,3,2}, device=TPU:0
  %112 = f32[4,128,28,28]{3,2,1,0} aten::convolution_overrideable(%110, %111), xla_shape=f32[4,128,28,28]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=33
  %113 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %114 = f32[0]{0} aten::expand(%113), xla_shape=f32[0]{0}, size=(0), ROOT=34
  %115 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %116 = f32[0]{0} aten::expand(%115), xla_shape=f32[0]{0}, size=(0), ROOT=35
  %117 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %118 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %119 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %120 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %121 = (f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}) aten::native_batch_norm(%112, %120, %119, %118, %117), num_outputs=4, xla_shape=(f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}), training=0, eps=1e-05
  %122 = f32[4,128,28,28]{3,2,1,0} aten::relu(%121.0), xla_shape=f32[4,128,28,28]{3,2,1,0}, ROOT=36
  %123 = f32[128,128,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[128,128,3,3]{1,0,3,2}, device=TPU:0
  %124 = f32[4,128,28,28]{3,2,1,0} aten::convolution_overrideable(%122, %123), xla_shape=f32[4,128,28,28]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=37
  %125 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %126 = f32[0]{0} aten::expand(%125), xla_shape=f32[0]{0}, size=(0), ROOT=38
  %127 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %128 = f32[0]{0} aten::expand(%127), xla_shape=f32[0]{0}, size=(0), ROOT=39
  %129 = f32[] prim::Constant(), xla_shape=f32[], value=1
  %130 = f32[4,128,28,28]{3,2,1,0} aten::expand(%129), xla_shape=f32[4,128,28,28]{3,2,1,0}, size=(4, 128, 28, 28)
  %131 = f32[4,128,28,28]{3,2,1,0} aten::mul(%110, %130), xla_shape=f32[4,128,28,28]{3,2,1,0}
  %132 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %133 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %134 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %135 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %136 = (f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}) aten::native_batch_norm(%124, %135, %134, %133, %132), num_outputs=4, xla_shape=(f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}), training=0, eps=1e-05
  %137 = f32[4,128,28,28]{3,2,1,0} aten::add(%136.0, %131), xla_shape=f32[4,128,28,28]{3,2,1,0}
  %138 = f32[4,128,28,28]{3,2,1,0} aten::relu(%137), xla_shape=f32[4,128,28,28]{3,2,1,0}, ROOT=40
  %139 = f32[256,128,3,3]{0,1,3,2} xla::device_data(), xla_shape=f32[256,128,3,3]{0,1,3,2}, device=TPU:0
  %140 = f32[4,256,14,14]{3,2,1,0} aten::convolution_overrideable(%138, %139), xla_shape=f32[4,256,14,14]{3,2,1,0}, stride=(2, 2), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=41
  %141 = f32[256,128,1,1]{0,1,3,2} xla::device_data(), xla_shape=f32[256,128,1,1]{0,1,3,2}, device=TPU:0
  %142 = f32[4,256,14,14]{3,2,1,0} aten::convolution_overrideable(%138, %141), xla_shape=f32[4,256,14,14]{3,2,1,0}, stride=(2, 2), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=42
  %143 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %144 = f32[0]{0} aten::expand(%143), xla_shape=f32[0]{0}, size=(0), ROOT=43
  %145 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %146 = f32[0]{0} aten::expand(%145), xla_shape=f32[0]{0}, size=(0), ROOT=44
  %147 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %148 = f32[0]{0} aten::expand(%147), xla_shape=f32[0]{0}, size=(0), ROOT=45
  %149 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %150 = f32[0]{0} aten::expand(%149), xla_shape=f32[0]{0}, size=(0), ROOT=46
  %151 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %152 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %153 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %154 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %155 = (f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}) aten::native_batch_norm(%140, %154, %153, %152, %151), num_outputs=4, xla_shape=(f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}), training=0, eps=1e-05
  %156 = f32[4,256,14,14]{3,2,1,0} aten::relu(%155.0), xla_shape=f32[4,256,14,14]{3,2,1,0}, ROOT=47
  %157 = f32[256,256,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[256,256,3,3]{1,0,3,2}, device=TPU:0
  %158 = f32[4,256,14,14]{3,2,1,0} aten::convolution_overrideable(%156, %157), xla_shape=f32[4,256,14,14]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=48
  %159 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %160 = f32[0]{0} aten::expand(%159), xla_shape=f32[0]{0}, size=(0), ROOT=49
  %161 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %162 = f32[0]{0} aten::expand(%161), xla_shape=f32[0]{0}, size=(0), ROOT=50
  %163 = f32[] prim::Constant(), xla_shape=f32[], value=1
  %164 = f32[4,256,14,14]{3,2,1,0} aten::expand(%163), xla_shape=f32[4,256,14,14]{3,2,1,0}, size=(4, 256, 14, 14)
  %165 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %166 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %167 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %168 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %169 = (f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}) aten::native_batch_norm(%142, %168, %167, %166, %165), num_outputs=4, xla_shape=(f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}), training=0, eps=1e-05
  %170 = f32[4,256,14,14]{3,2,1,0} aten::mul(%169.0, %164), xla_shape=f32[4,256,14,14]{3,2,1,0}
  %171 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %172 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %173 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %174 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %175 = (f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}) aten::native_batch_norm(%158, %174, %173, %172, %171), num_outputs=4, xla_shape=(f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}), training=0, eps=1e-05
  %176 = f32[4,256,14,14]{3,2,1,0} aten::add(%175.0, %170), xla_shape=f32[4,256,14,14]{3,2,1,0}
  %177 = f32[4,256,14,14]{3,2,1,0} aten::relu(%176), xla_shape=f32[4,256,14,14]{3,2,1,0}, ROOT=51
  %178 = f32[256,256,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[256,256,3,3]{1,0,3,2}, device=TPU:0
  %179 = f32[4,256,14,14]{3,2,1,0} aten::convolution_overrideable(%177, %178), xla_shape=f32[4,256,14,14]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=52
  %180 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %181 = f32[0]{0} aten::expand(%180), xla_shape=f32[0]{0}, size=(0), ROOT=53
  %182 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %183 = f32[0]{0} aten::expand(%182), xla_shape=f32[0]{0}, size=(0), ROOT=54
  %184 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %185 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %186 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %187 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %188 = (f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}) aten::native_batch_norm(%179, %187, %186, %185, %184), num_outputs=4, xla_shape=(f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}), training=0, eps=1e-05
  %189 = f32[4,256,14,14]{3,2,1,0} aten::relu(%188.0), xla_shape=f32[4,256,14,14]{3,2,1,0}, ROOT=55
  %190 = f32[256,256,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[256,256,3,3]{1,0,3,2}, device=TPU:0
  %191 = f32[4,256,14,14]{3,2,1,0} aten::convolution_overrideable(%189, %190), xla_shape=f32[4,256,14,14]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=56
  %192 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %193 = f32[0]{0} aten::expand(%192), xla_shape=f32[0]{0}, size=(0), ROOT=57
  %194 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %195 = f32[0]{0} aten::expand(%194), xla_shape=f32[0]{0}, size=(0), ROOT=58
  %196 = f32[] prim::Constant(), xla_shape=f32[], value=1
  %197 = f32[4,256,14,14]{3,2,1,0} aten::expand(%196), xla_shape=f32[4,256,14,14]{3,2,1,0}, size=(4, 256, 14, 14)
  %198 = f32[4,256,14,14]{3,2,1,0} aten::mul(%177, %197), xla_shape=f32[4,256,14,14]{3,2,1,0}
  %199 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %200 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %201 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %202 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %203 = (f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}) aten::native_batch_norm(%191, %202, %201, %200, %199), num_outputs=4, xla_shape=(f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}), training=0, eps=1e-05
  %204 = f32[4,256,14,14]{3,2,1,0} aten::add(%203.0, %198), xla_shape=f32[4,256,14,14]{3,2,1,0}
  %205 = f32[4,256,14,14]{3,2,1,0} aten::relu(%204), xla_shape=f32[4,256,14,14]{3,2,1,0}, ROOT=59
  %206 = f32[512,256,3,3]{0,1,3,2} xla::device_data(), xla_shape=f32[512,256,3,3]{0,1,3,2}, device=TPU:0
  %207 = f32[4,512,7,7]{3,2,1,0} aten::convolution_overrideable(%205, %206), xla_shape=f32[4,512,7,7]{3,2,1,0}, stride=(2, 2), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=60
  %208 = f32[512,256,1,1]{0,1,3,2} xla::device_data(), xla_shape=f32[512,256,1,1]{0,1,3,2}, device=TPU:0
  %209 = f32[4,512,7,7]{3,2,1,0} aten::convolution_overrideable(%205, %208), xla_shape=f32[4,512,7,7]{3,2,1,0}, stride=(2, 2), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=61
  %210 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %211 = f32[0]{0} aten::expand(%210), xla_shape=f32[0]{0}, size=(0), ROOT=62
  %212 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %213 = f32[0]{0} aten::expand(%212), xla_shape=f32[0]{0}, size=(0), ROOT=63
  %214 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %215 = f32[0]{0} aten::expand(%214), xla_shape=f32[0]{0}, size=(0), ROOT=64
  %216 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %217 = f32[0]{0} aten::expand(%216), xla_shape=f32[0]{0}, size=(0), ROOT=65
  %218 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %219 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %220 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %221 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %222 = (f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}) aten::native_batch_norm(%207, %221, %220, %219, %218), num_outputs=4, xla_shape=(f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}), training=0, eps=1e-05
  %223 = f32[4,512,7,7]{3,2,1,0} aten::relu(%222.0), xla_shape=f32[4,512,7,7]{3,2,1,0}, ROOT=66
  %224 = f32[512,512,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[512,512,3,3]{1,0,3,2}, device=TPU:0
  %225 = f32[4,512,7,7]{3,2,1,0} aten::convolution_overrideable(%223, %224), xla_shape=f32[4,512,7,7]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=67
  %226 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %227 = f32[0]{0} aten::expand(%226), xla_shape=f32[0]{0}, size=(0), ROOT=68
  %228 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %229 = f32[0]{0} aten::expand(%228), xla_shape=f32[0]{0}, size=(0), ROOT=69
  %230 = f32[] prim::Constant(), xla_shape=f32[], value=1
  %231 = f32[4,512,7,7]{3,2,1,0} aten::expand(%230), xla_shape=f32[4,512,7,7]{3,2,1,0}, size=(4, 512, 7, 7)
  %232 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %233 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %234 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %235 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %236 = (f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}) aten::native_batch_norm(%209, %235, %234, %233, %232), num_outputs=4, xla_shape=(f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}), training=0, eps=1e-05
  %237 = f32[4,512,7,7]{3,2,1,0} aten::mul(%236.0, %231), xla_shape=f32[4,512,7,7]{3,2,1,0}
  %238 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %239 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %240 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %241 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %242 = (f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}) aten::native_batch_norm(%225, %241, %240, %239, %238), num_outputs=4, xla_shape=(f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}), training=0, eps=1e-05
  %243 = f32[4,512,7,7]{3,2,1,0} aten::add(%242.0, %237), xla_shape=f32[4,512,7,7]{3,2,1,0}
  %244 = f32[4,512,7,7]{3,2,1,0} aten::relu(%243), xla_shape=f32[4,512,7,7]{3,2,1,0}, ROOT=70
  %245 = f32[512,512,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[512,512,3,3]{1,0,3,2}, device=TPU:0
  %246 = f32[4,512,7,7]{3,2,1,0} aten::convolution_overrideable(%244, %245), xla_shape=f32[4,512,7,7]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=71
  %247 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %248 = f32[0]{0} aten::expand(%247), xla_shape=f32[0]{0}, size=(0), ROOT=72
  %249 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %250 = f32[0]{0} aten::expand(%249), xla_shape=f32[0]{0}, size=(0), ROOT=73
  %251 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %252 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %253 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %254 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %255 = (f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}) aten::native_batch_norm(%246, %254, %253, %252, %251), num_outputs=4, xla_shape=(f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}), training=0, eps=1e-05
  %256 = f32[4,512,7,7]{3,2,1,0} aten::relu(%255.0), xla_shape=f32[4,512,7,7]{3,2,1,0}, ROOT=74
  %257 = f32[512,512,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[512,512,3,3]{1,0,3,2}, device=TPU:0
  %258 = f32[4,512,7,7]{3,2,1,0} aten::convolution_overrideable(%256, %257), xla_shape=f32[4,512,7,7]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=75
  %259 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %260 = f32[0]{0} aten::expand(%259), xla_shape=f32[0]{0}, size=(0), ROOT=76
  %261 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %262 = f32[0]{0} aten::expand(%261), xla_shape=f32[0]{0}, size=(0), ROOT=77
  %263 = f32[] prim::Constant(), xla_shape=f32[], value=1
  %264 = f32[4,512,7,7]{3,2,1,0} aten::expand(%263), xla_shape=f32[4,512,7,7]{3,2,1,0}, size=(4, 512, 7, 7)
  %265 = f32[4,512,7,7]{3,2,1,0} aten::mul(%244, %264), xla_shape=f32[4,512,7,7]{3,2,1,0}
  %266 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %267 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %268 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %269 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %270 = (f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}) aten::native_batch_norm(%258, %269, %268, %267, %266), num_outputs=4, xla_shape=(f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}), training=0, eps=1e-05
  %271 = f32[4,512,7,7]{3,2,1,0} aten::add(%270.0, %265), xla_shape=f32[4,512,7,7]{3,2,1,0}
  %272 = f32[4,512,7,7]{3,2,1,0} aten::relu(%271), xla_shape=f32[4,512,7,7]{3,2,1,0}, ROOT=78
  %273 = f32[4,512,1,1]{3,2,1,0} aten::mean(%272), xla_shape=f32[4,512,1,1]{3,2,1,0}, dimensions=(3, 2), keep_reduced_dimensions=1, dtype=-1
  %274 = f32[4,512]{1,0} aten::view(%273), xla_shape=f32[4,512]{1,0}, output_size=(4, 512), ROOT=79
  %275 = f32[1000]{0} xla::device_data(), xla_shape=f32[1000]{0}, device=TPU:0
  %276 = f32[4,1000]{1,0} aten::addmm(%274, %4, %275), xla_shape=f32[4,1000]{1,0}, ROOT=80
}
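Dumps this long are hard to compare by eye. A minimal sketch for summarizing them, assuming the one-node-per-line `%N = <shape> <namespace::op>(...)` layout seen above: it builds a histogram of op names and counts the nodes tagged `ROOT=` (which appear to mark graph outputs), so the aot and non-aot dumps can be diffed. The helpers are hypothetical, not part of torch_xla, and the sample input is a few lines trimmed from the dumps in this thread.

```python
import re
from collections import Counter

# Captures the op name of an IR node line, e.g. "aten::relu" in
#   %14 = f32[4,64,112,112]{3,2,1,0} aten::relu(%13.0), ..., ROOT=4
# The op is the first "namespace::name(" token after the result type,
# which also works for tuple result types such as native_batch_norm's.
_NODE_RE = re.compile(r"^\s*%\d+\s*=\s*.*?\s([A-Za-z_]\w*::\w+)\(")
_ROOT_RE = re.compile(r"\bROOT=\d+")


def op_histogram(ir_text: str) -> Counter:
    """Count how many nodes of each op kind appear in an IR dump."""
    ops = Counter()
    for line in ir_text.splitlines():
        m = _NODE_RE.match(line)
        if m:
            ops[m.group(1)] += 1
    return ops


def count_roots(ir_text: str) -> int:
    """Count nodes tagged ROOT=<i> in an IR dump."""
    return len(_ROOT_RE.findall(ir_text))


def diff_histograms(a: Counter, b: Counter) -> dict:
    """Return {op: (count_in_a, count_in_b)} for ops whose counts differ."""
    return {op: (a[op], b[op]) for op in sorted(a.keys() | b.keys()) if a[op] != b[op]}


# Trimmed excerpts from the aot and non-aot dumps.
aot_ir = """
  %5 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %6 = f32[0]{0} aten::expand(%5), xla_shape=f32[0]{0}, size=(0), ROOT=2
  %14 = f32[4,64,112,112]{3,2,1,0} aten::relu(%13.0), xla_shape=f32[4,64,112,112]{3,2,1,0}, ROOT=4
"""
non_aot_ir = """
  %42 = f32[4,64,112,112]{3,2,1,0} aten::relu(%41.0), xla_shape=f32[4,64,112,112]{3,2,1,0}
"""

print(diff_histograms(op_histogram(aot_ir), op_histogram(non_aot_ir)))
print(count_roots(aot_ir), count_roots(non_aot_ir))
```

Run on the complete dumps (for example the files in save_ir.zip), this would show which op kinds and how many outputs the aot graph adds relative to the non-aot one.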

IR for non-aot:

IR {
  %0 = f32[1000]{0} xla::device_data(), xla_shape=f32[1000]{0}, device=TPU:0
  %1 = f32[1000,512]{1,0} xla::device_data(), xla_shape=f32[1000,512]{1,0}, device=TPU:0
  %2 = f32[512,1000]{0,1} aten::permute(%1), xla_shape=f32[512,1000]{0,1}, dims=(1, 0)
  %3 = f32[] prim::Constant(), xla_shape=f32[], value=1
  %4 = f32[4,512,7,7]{3,2,1,0} aten::expand(%3), xla_shape=f32[4,512,7,7]{3,2,1,0}, size=(4, 512, 7, 7)
  %5 = f32[] prim::Constant(), xla_shape=f32[], value=1
  %6 = f32[4,512,7,7]{3,2,1,0} aten::expand(%5), xla_shape=f32[4,512,7,7]{3,2,1,0}, size=(4, 512, 7, 7)
  %7 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %8 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %9 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %10 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %11 = f32[512,256,1,1]{0,1,3,2} xla::device_data(), xla_shape=f32[512,256,1,1]{0,1,3,2}, device=TPU:0
  %12 = f32[] prim::Constant(), xla_shape=f32[], value=1
  %13 = f32[4,256,14,14]{3,2,1,0} aten::expand(%12), xla_shape=f32[4,256,14,14]{3,2,1,0}, size=(4, 256, 14, 14)
  %14 = f32[] prim::Constant(), xla_shape=f32[], value=1
  %15 = f32[4,256,14,14]{3,2,1,0} aten::expand(%14), xla_shape=f32[4,256,14,14]{3,2,1,0}, size=(4, 256, 14, 14)
  %16 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %17 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %18 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %19 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %20 = f32[256,128,1,1]{0,1,3,2} xla::device_data(), xla_shape=f32[256,128,1,1]{0,1,3,2}, device=TPU:0
  %21 = f32[] prim::Constant(), xla_shape=f32[], value=1
  %22 = f32[4,128,28,28]{3,2,1,0} aten::expand(%21), xla_shape=f32[4,128,28,28]{3,2,1,0}, size=(4, 128, 28, 28)
  %23 = f32[] prim::Constant(), xla_shape=f32[], value=1
  %24 = f32[4,128,28,28]{3,2,1,0} aten::expand(%23), xla_shape=f32[4,128,28,28]{3,2,1,0}, size=(4, 128, 28, 28)
  %25 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %26 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %27 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %28 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %29 = f32[128,64,1,1]{0,1,3,2} xla::device_data(), xla_shape=f32[128,64,1,1]{0,1,3,2}, device=TPU:0
  %30 = f32[] prim::Constant(), xla_shape=f32[], value=1
  %31 = f32[4,64,56,56]{3,2,1,0} aten::expand(%30), xla_shape=f32[4,64,56,56]{3,2,1,0}, size=(4, 64, 56, 56)
  %32 = f32[] prim::Constant(), xla_shape=f32[], value=1
  %33 = f32[4,64,56,56]{3,2,1,0} aten::expand(%32), xla_shape=f32[4,64,56,56]{3,2,1,0}, size=(4, 64, 56, 56)
  %34 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %35 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %36 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %37 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %38 = f32[64,3,7,7]{0,3,2,1} xla::device_data(), xla_shape=f32[64,3,7,7]{0,3,2,1}, device=TPU:0
  %39 = f32[4,3,224,224]{3,2,1,0} xla::device_data(), xla_shape=f32[4,3,224,224]{3,2,1,0}, device=TPU:0
  %40 = f32[4,64,112,112]{3,2,1,0} aten::convolution_overrideable(%39, %38), xla_shape=f32[4,64,112,112]{3,2,1,0}, stride=(2, 2), padding=(3, 3), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %41 = (f32[4,64,112,112]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}) aten::native_batch_norm(%40, %37, %36, %35, %34), num_outputs=4, xla_shape=(f32[4,64,112,112]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}), training=0, eps=1e-05
  %42 = f32[4,64,112,112]{3,2,1,0} aten::relu(%41.0), xla_shape=f32[4,64,112,112]{3,2,1,0}
  %43 = (f32[4,64,56,56]{3,2,1,0}, u32[4,64,56,56]{3,2,1,0}) aten::max_pool2d(%42), num_outputs=2, xla_shape=(f32[4,64,56,56]{3,2,1,0}, u32[4,64,56,56]{3,2,1,0}), spatial_dim_count=2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), ceil_mode=0
  %44 = f32[4,64,56,56]{3,2,1,0} aten::mul(%43.0, %33), xla_shape=f32[4,64,56,56]{3,2,1,0}
  %45 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %46 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %47 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %48 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %49 = f32[64,64,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[64,64,3,3]{1,0,3,2}, device=TPU:0
  %50 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %51 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %52 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %53 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %54 = f32[64,64,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[64,64,3,3]{1,0,3,2}, device=TPU:0
  %55 = f32[4,64,56,56]{3,2,1,0} aten::convolution_overrideable(%43.0, %54), xla_shape=f32[4,64,56,56]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %56 = (f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}) aten::native_batch_norm(%55, %53, %52, %51, %50), num_outputs=4, xla_shape=(f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}), training=0, eps=1e-05
  %57 = f32[4,64,56,56]{3,2,1,0} aten::relu(%56.0), xla_shape=f32[4,64,56,56]{3,2,1,0}
  %58 = f32[4,64,56,56]{3,2,1,0} aten::convolution_overrideable(%57, %49), xla_shape=f32[4,64,56,56]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %59 = (f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}) aten::native_batch_norm(%58, %48, %47, %46, %45), num_outputs=4, xla_shape=(f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}), training=0, eps=1e-05
  %60 = f32[4,64,56,56]{3,2,1,0} aten::add(%59.0, %44), xla_shape=f32[4,64,56,56]{3,2,1,0}
  %61 = f32[4,64,56,56]{3,2,1,0} aten::relu(%60), xla_shape=f32[4,64,56,56]{3,2,1,0}
  %62 = f32[4,64,56,56]{3,2,1,0} aten::mul(%61, %31), xla_shape=f32[4,64,56,56]{3,2,1,0}
  %63 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %64 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %65 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %66 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %67 = f32[64,64,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[64,64,3,3]{1,0,3,2}, device=TPU:0
  %68 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %69 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %70 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %71 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
  %72 = f32[64,64,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[64,64,3,3]{1,0,3,2}, device=TPU:0
  %73 = f32[4,64,56,56]{3,2,1,0} aten::convolution_overrideable(%61, %72), xla_shape=f32[4,64,56,56]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %74 = (f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}) aten::native_batch_norm(%73, %71, %70, %69, %68), num_outputs=4, xla_shape=(f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}), training=0, eps=1e-05
  %75 = f32[4,64,56,56]{3,2,1,0} aten::relu(%74.0), xla_shape=f32[4,64,56,56]{3,2,1,0}
  %76 = f32[4,64,56,56]{3,2,1,0} aten::convolution_overrideable(%75, %67), xla_shape=f32[4,64,56,56]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %77 = (f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}) aten::native_batch_norm(%76, %66, %65, %64, %63), num_outputs=4, xla_shape=(f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}), training=0, eps=1e-05
  %78 = f32[4,64,56,56]{3,2,1,0} aten::add(%77.0, %62), xla_shape=f32[4,64,56,56]{3,2,1,0}
  %79 = f32[4,64,56,56]{3,2,1,0} aten::relu(%78), xla_shape=f32[4,64,56,56]{3,2,1,0}
  %80 = f32[4,128,28,28]{3,2,1,0} aten::convolution_overrideable(%79, %29), xla_shape=f32[4,128,28,28]{3,2,1,0}, stride=(2, 2), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %81 = (f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}) aten::native_batch_norm(%80, %28, %27, %26, %25), num_outputs=4, xla_shape=(f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}), training=0, eps=1e-05
  %82 = f32[4,128,28,28]{3,2,1,0} aten::mul(%81.0, %24), xla_shape=f32[4,128,28,28]{3,2,1,0}
  %83 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %84 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %85 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %86 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %87 = f32[128,128,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[128,128,3,3]{1,0,3,2}, device=TPU:0
  %88 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %89 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %90 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %91 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %92 = f32[128,64,3,3]{0,1,3,2} xla::device_data(), xla_shape=f32[128,64,3,3]{0,1,3,2}, device=TPU:0
  %93 = f32[4,128,28,28]{3,2,1,0} aten::convolution_overrideable(%79, %92), xla_shape=f32[4,128,28,28]{3,2,1,0}, stride=(2, 2), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %94 = (f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}) aten::native_batch_norm(%93, %91, %90, %89, %88), num_outputs=4, xla_shape=(f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}), training=0, eps=1e-05
  %95 = f32[4,128,28,28]{3,2,1,0} aten::relu(%94.0), xla_shape=f32[4,128,28,28]{3,2,1,0}
  %96 = f32[4,128,28,28]{3,2,1,0} aten::convolution_overrideable(%95, %87), xla_shape=f32[4,128,28,28]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %97 = (f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}) aten::native_batch_norm(%96, %86, %85, %84, %83), num_outputs=4, xla_shape=(f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}), training=0, eps=1e-05
  %98 = f32[4,128,28,28]{3,2,1,0} aten::add(%97.0, %82), xla_shape=f32[4,128,28,28]{3,2,1,0}
  %99 = f32[4,128,28,28]{3,2,1,0} aten::relu(%98), xla_shape=f32[4,128,28,28]{3,2,1,0}
  %100 = f32[4,128,28,28]{3,2,1,0} aten::mul(%99, %22), xla_shape=f32[4,128,28,28]{3,2,1,0}
  %101 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %102 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %103 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %104 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %105 = f32[128,128,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[128,128,3,3]{1,0,3,2}, device=TPU:0
  %106 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %107 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %108 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %109 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
  %110 = f32[128,128,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[128,128,3,3]{1,0,3,2}, device=TPU:0
  %111 = f32[4,128,28,28]{3,2,1,0} aten::convolution_overrideable(%99, %110), xla_shape=f32[4,128,28,28]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %112 = (f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}) aten::native_batch_norm(%111, %109, %108, %107, %106), num_outputs=4, xla_shape=(f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}), training=0, eps=1e-05
  %113 = f32[4,128,28,28]{3,2,1,0} aten::relu(%112.0), xla_shape=f32[4,128,28,28]{3,2,1,0}
  %114 = f32[4,128,28,28]{3,2,1,0} aten::convolution_overrideable(%113, %105), xla_shape=f32[4,128,28,28]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %115 = (f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}) aten::native_batch_norm(%114, %104, %103, %102, %101), num_outputs=4, xla_shape=(f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}), training=0, eps=1e-05
  %116 = f32[4,128,28,28]{3,2,1,0} aten::add(%115.0, %100), xla_shape=f32[4,128,28,28]{3,2,1,0}
  %117 = f32[4,128,28,28]{3,2,1,0} aten::relu(%116), xla_shape=f32[4,128,28,28]{3,2,1,0}
  %118 = f32[4,256,14,14]{3,2,1,0} aten::convolution_overrideable(%117, %20), xla_shape=f32[4,256,14,14]{3,2,1,0}, stride=(2, 2), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %119 = (f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}) aten::native_batch_norm(%118, %19, %18, %17, %16), num_outputs=4, xla_shape=(f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}), training=0, eps=1e-05
  %120 = f32[4,256,14,14]{3,2,1,0} aten::mul(%119.0, %15), xla_shape=f32[4,256,14,14]{3,2,1,0}
  %121 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %122 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %123 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %124 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %125 = f32[256,256,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[256,256,3,3]{1,0,3,2}, device=TPU:0
  %126 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %127 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %128 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %129 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %130 = f32[256,128,3,3]{0,1,3,2} xla::device_data(), xla_shape=f32[256,128,3,3]{0,1,3,2}, device=TPU:0
  %131 = f32[4,256,14,14]{3,2,1,0} aten::convolution_overrideable(%117, %130), xla_shape=f32[4,256,14,14]{3,2,1,0}, stride=(2, 2), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %132 = (f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}) aten::native_batch_norm(%131, %129, %128, %127, %126), num_outputs=4, xla_shape=(f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}), training=0, eps=1e-05
  %133 = f32[4,256,14,14]{3,2,1,0} aten::relu(%132.0), xla_shape=f32[4,256,14,14]{3,2,1,0}
  %134 = f32[4,256,14,14]{3,2,1,0} aten::convolution_overrideable(%133, %125), xla_shape=f32[4,256,14,14]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %135 = (f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}) aten::native_batch_norm(%134, %124, %123, %122, %121), num_outputs=4, xla_shape=(f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}), training=0, eps=1e-05
  %136 = f32[4,256,14,14]{3,2,1,0} aten::add(%135.0, %120), xla_shape=f32[4,256,14,14]{3,2,1,0}
  %137 = f32[4,256,14,14]{3,2,1,0} aten::relu(%136), xla_shape=f32[4,256,14,14]{3,2,1,0}
  %138 = f32[4,256,14,14]{3,2,1,0} aten::mul(%137, %13), xla_shape=f32[4,256,14,14]{3,2,1,0}
  %139 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %140 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %141 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %142 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %143 = f32[256,256,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[256,256,3,3]{1,0,3,2}, device=TPU:0
  %144 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %145 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %146 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %147 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
  %148 = f32[256,256,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[256,256,3,3]{1,0,3,2}, device=TPU:0
  %149 = f32[4,256,14,14]{3,2,1,0} aten::convolution_overrideable(%137, %148), xla_shape=f32[4,256,14,14]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %150 = (f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}) aten::native_batch_norm(%149, %147, %146, %145, %144), num_outputs=4, xla_shape=(f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}), training=0, eps=1e-05
  %151 = f32[4,256,14,14]{3,2,1,0} aten::relu(%150.0), xla_shape=f32[4,256,14,14]{3,2,1,0}
  %152 = f32[4,256,14,14]{3,2,1,0} aten::convolution_overrideable(%151, %143), xla_shape=f32[4,256,14,14]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %153 = (f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}) aten::native_batch_norm(%152, %142, %141, %140, %139), num_outputs=4, xla_shape=(f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}), training=0, eps=1e-05
  %154 = f32[4,256,14,14]{3,2,1,0} aten::add(%153.0, %138), xla_shape=f32[4,256,14,14]{3,2,1,0}
  %155 = f32[4,256,14,14]{3,2,1,0} aten::relu(%154), xla_shape=f32[4,256,14,14]{3,2,1,0}
  %156 = f32[4,512,7,7]{3,2,1,0} aten::convolution_overrideable(%155, %11), xla_shape=f32[4,512,7,7]{3,2,1,0}, stride=(2, 2), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %157 = (f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}) aten::native_batch_norm(%156, %10, %9, %8, %7), num_outputs=4, xla_shape=(f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}), training=0, eps=1e-05
  %158 = f32[4,512,7,7]{3,2,1,0} aten::mul(%157.0, %6), xla_shape=f32[4,512,7,7]{3,2,1,0}
  %159 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %160 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %161 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %162 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %163 = f32[512,512,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[512,512,3,3]{1,0,3,2}, device=TPU:0
  %164 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %165 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %166 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %167 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %168 = f32[512,256,3,3]{0,1,3,2} xla::device_data(), xla_shape=f32[512,256,3,3]{0,1,3,2}, device=TPU:0
  %169 = f32[4,512,7,7]{3,2,1,0} aten::convolution_overrideable(%155, %168), xla_shape=f32[4,512,7,7]{3,2,1,0}, stride=(2, 2), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %170 = (f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}) aten::native_batch_norm(%169, %167, %166, %165, %164), num_outputs=4, xla_shape=(f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}), training=0, eps=1e-05
  %171 = f32[4,512,7,7]{3,2,1,0} aten::relu(%170.0), xla_shape=f32[4,512,7,7]{3,2,1,0}
  %172 = f32[4,512,7,7]{3,2,1,0} aten::convolution_overrideable(%171, %163), xla_shape=f32[4,512,7,7]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %173 = (f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}) aten::native_batch_norm(%172, %162, %161, %160, %159), num_outputs=4, xla_shape=(f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}), training=0, eps=1e-05
  %174 = f32[4,512,7,7]{3,2,1,0} aten::add(%173.0, %158), xla_shape=f32[4,512,7,7]{3,2,1,0}
  %175 = f32[4,512,7,7]{3,2,1,0} aten::relu(%174), xla_shape=f32[4,512,7,7]{3,2,1,0}
  %176 = f32[4,512,7,7]{3,2,1,0} aten::mul(%175, %4), xla_shape=f32[4,512,7,7]{3,2,1,0}
  %177 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %178 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %179 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %180 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %181 = f32[512,512,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[512,512,3,3]{1,0,3,2}, device=TPU:0
  %182 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %183 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %184 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %185 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
  %186 = f32[512,512,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[512,512,3,3]{1,0,3,2}, device=TPU:0
  %187 = f32[4,512,7,7]{3,2,1,0} aten::convolution_overrideable(%175, %186), xla_shape=f32[4,512,7,7]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %188 = (f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}) aten::native_batch_norm(%187, %185, %184, %183, %182), num_outputs=4, xla_shape=(f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}), training=0, eps=1e-05
  %189 = f32[4,512,7,7]{3,2,1,0} aten::relu(%188.0), xla_shape=f32[4,512,7,7]{3,2,1,0}
  %190 = f32[4,512,7,7]{3,2,1,0} aten::convolution_overrideable(%189, %181), xla_shape=f32[4,512,7,7]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %191 = (f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}) aten::native_batch_norm(%190, %180, %179, %178, %177), num_outputs=4, xla_shape=(f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}), training=0, eps=1e-05
  %192 = f32[4,512,7,7]{3,2,1,0} aten::add(%191.0, %176), xla_shape=f32[4,512,7,7]{3,2,1,0}
  %193 = f32[4,512,7,7]{3,2,1,0} aten::relu(%192), xla_shape=f32[4,512,7,7]{3,2,1,0}
  %194 = f32[4,512,1,1]{3,2,1,0} aten::mean(%193), xla_shape=f32[4,512,1,1]{3,2,1,0}, dimensions=(3, 2), keep_reduced_dimensions=1, dtype=-1
  %195 = f32[4,512]{1,0} aten::view(%194), xla_shape=f32[4,512]{1,0}, output_size=(4, 512)
  %196 = f32[4,1000]{1,0} aten::addmm(%195, %2, %0), xla_shape=f32[4,1000]{1,0}, ROOT=0
}

Another thing I noticed is that there are a lot of these (40 of them):

  %18 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %19 = f32[0]{0} aten::expand(%18), xla_shape=f32[0]{0}, size=(0), ROOT=7
  %20 = f32[] prim::Constant(), xla_shape=f32[], value=0
  %21 = f32[0]{0} aten::expand(%20), xla_shape=f32[0]{0}, size=(0), ROOT=8

in the IR, and they are not used by any other op; each one just expands a constant to a size-[0] tensor.

JackCaoG commented 1 year ago

Yeah, I can confirm that in https://github.com/pytorch/xla/blob/941a4c489add37db9d33d8f9dc3f263a47c3feed/torch_xla/core/dynamo_bridge.py#L240, where we execute the FX graph passed down by Dynamo, the AOT backend has 81 results while the non-AOT backend has 4. I'm not sure what these additional results the AOT backend requires are.

bdhirsh commented 1 year ago

hmm @JackCaoG is it possible to figure out which aten ops in the graph the extra xla::clone calls are coming from? 103 seems like a lot. Going off of my batch norm theory, I would have only expected an extra at::copy_() for every running_mean and running_var buffer, and I'm pretty sure there aren't 100 of them.

> Another thing I noticed is that there are a lot of these (40 of them):
>
>   %18 = f32[] prim::Constant(), xla_shape=f32[], value=0
>   %19 = f32[0]{0} aten::expand(%18), xla_shape=f32[0]{0}, size=(0), ROOT=7
>   %20 = f32[] prim::Constant(), xla_shape=f32[], value=0
>   %21 = f32[0]{0} aten::expand(%20), xla_shape=f32[0]{0}, size=(0), ROOT=8

hmm... that's creating a bunch of zero-sized tensors. It's not clear to me at all where that would be coming from through AOTAutograd. Is it possible to figure out which aten op LazyTensor dispatched on, that eventually desugared into those ops showing up in the graph?

JackCaoG commented 1 year ago

yea, let me do a debug build and try to figure out what happened.

JackCaoG commented 1 year ago

OK, I think I know what those tensors are. It seems that during the forward call of resnet18 there is one real output, but AOTAutograd needs to save a bunch of tensors for the backward pass:

(Pdb) len(tensors_saved_for_backwards)
161
(Pdb) len(fw_outs)
162

The logic is in https://github.com/pytorch/pytorch/blob/0434a2c7c8b7a5ee233c0c670b772eac25c27a2d/torch/_functorch/aot_autograd.py#L3032-L3034

That being said, I still don't know why we return a bunch of expand values.
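A simplified sketch of that split, assuming the compiled forward returns the user-visible outputs first and the saved-for-backward tensors after them. The helper name is hypothetical and this is not AOTAutograd's code; it only mirrors the counts from the pdb session (1 real output plus 161 saved tensors = 162 forward outputs):

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_forward_outputs(fw_outs: Sequence[T], num_forward_returns: int) -> Tuple[List[T], List[T]]:
    """Hypothetical helper: split a compiled forward's outputs into the values the
    user's forward() returns and the extra tensors kept alive for backward."""
    user_outs = list(fw_outs[:num_forward_returns])
    saved_for_backward = list(fw_outs[num_forward_returns:])
    return user_outs, saved_for_backward


# Placeholders standing in for the 162 tensors seen in the pdb session.
fw_outs = [f"tensor_{i}" for i in range(162)]
user_outs, saved = split_forward_outputs(fw_outs, num_forward_returns=1)
print(len(user_outs), len(saved))  # 1 161
```

Every saved entry is an extra output the traced graph has to produce, which is consistent with the AOT backend's graph having far more results than the non-AOT one.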

JackCaoG commented 1 year ago

@bdhirsh I was able to confirm that the expands of 0-size tensors come from empty_symint, and they are being captured by AOTAutograd as tensors_saved_for_backwards.

bdhirsh commented 1 year ago

@JackCaoG AOTAutograd tries to figure out if you're "doing inference", and if it detects that, then it will run an inference code path (that should involve not saving any tensors for backward). The inference check lives here. As long as your code runs under no_grad(), AOTAutograd should compile an inference-only graph.
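A simplified model of that inference check, assuming it comes down to grad mode (what torch.no_grad() turns off) plus requires_grad on the flattened graph inputs, with the model's parameters counted as inputs. FakeInput and needs_autograd are hypothetical names, not AOTAutograd's code:

```python
from dataclasses import dataclass
from typing import Iterable


@dataclass
class FakeInput:
    """Hypothetical stand-in for a graph input (a parameter, buffer, or data tensor)."""
    requires_grad: bool


def needs_autograd(flat_inputs: Iterable[FakeInput], grad_enabled: bool) -> bool:
    # A forward+backward (training) graph is only needed when grad mode is on
    # AND at least one input could receive a gradient. Otherwise an
    # inference-only graph, with nothing saved for backward, is enough.
    return grad_enabled and any(x.requires_grad for x in flat_inputs)


# Parameters require grad by default and model.eval() does not change that;
# the data batch does not require grad.
inputs = [FakeInput(requires_grad=True) for _ in range(3)] + [FakeInput(requires_grad=False)]

print(needs_autograd(inputs, grad_enabled=True))   # True: training path, tensors saved for backward
print(needs_autograd(inputs, grad_enabled=False))  # False: the case under torch.no_grad()
```

Under this model, calling resnet18.eval() alone is not enough to hit the inference path, because eval() only changes layer behavior (e.g. batch norm), not requires_grad or grad mode.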

Can we test that out in the repro?

JackCaoG commented 1 year ago

with torch.no_grad(): solved my issue: the number of outputs is back to 4 and the HLO is much shorter. @GleasonK I don't know if there is an easy way to run the model inside no_grad.

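@bdhirsh what I did was wrap everything:

    import torch
    import torchvision
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met

    device = xm.xla_device()

    with torch.no_grad():
      resnet18 = torchvision.models.resnet18()
      resnet18.eval()
      xla_resnet18 = torchvision.models.resnet18()
      xla_resnet18.load_state_dict(resnet18.state_dict())
      xla_resnet18.to(device)
      xla_resnet18.eval()
      # materialize the fake data for test purposes
      xm.mark_step()
      xm.wait_device_ops()
      met.clear_all()
      # `loader` comes from the Colab repro and yields (data, target) batches on `device`
      for data, _ in loader:
        dynamo_resnet18 = torch.compile(xla_resnet18, backend='openxla')
        output = dynamo_resnet18(data)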

Do we only need to wrap the execution (dynamo_resnet18 = torch.compile(xla_resnet18, backend='openxla')), only the init,

      resnet18 = torchvision.models.resnet18()
      resnet18.eval()
      xla_resnet18 = torchvision.models.resnet18()
      xla_resnet18.load_state_dict(resnet18.state_dict())
      xla_resnet18.to(device)

or both?

JackCaoG commented 1 year ago

I was able to confirm that wrapping just the execution works:

    dynamo_resnet18 = torch.compile(xla_resnet18, backend='openxla')
    with torch.no_grad():
      output = dynamo_resnet18(data)

JackCaoG commented 1 year ago

@GleasonK Let me know if you can verify that the regression is gone; then we can close this issue.

GleasonK commented 1 year ago

I won't be able to verify for a while (OOO), but I'm confident this should work. Happy to close and re-open if I find issues in the future. Thanks for the quick investigation!

I also think it's worth discussing whether this should be the recommended solution to the problem. That doesn't have to happen in this ticket, though.

JackCaoG commented 1 year ago

sg, I will close this issue for now.

JackCaoG commented 1 year ago

Actually, I found that even with torch.no_grad the non-AOT backend is still faster. Let me look into why.

JackCaoG commented 1 year ago

I was looking at the LLaMA profiles for openxla and openxla_eval and found that openxla triggers an additional computation at mark_step. It looks like it is trying to extract the value of an s64 integer. Looking into it.

JackCaoG commented 1 year ago

Looking at the IR... hmm, this might be a LLaMA thing; it seems to be a bool.

[ScheduleSyncTensorsGraph]
TensorsGraphInfo:
  generate (/src/llama/llama/generation.py:328)
  decorate_context (/src/pytorch/torch/utils/_contextlib.py:115)
  text_completion (/src/llama/llama/generation.py:367)
  main (example_text_completion.py:70)
  mp_main (example_text_completion.py:124)
  _CallAndUpdateTrace (/usr/local/lib/python3.8/site-packages/fire/core.py:691)
  _Fire (/usr/local/lib/python3.8/site-packages/fire/core.py:475)
  Fire (/usr/local/lib/python3.8/site-packages/fire/core.py:141)
  <module> (example_text_completion.py:130)

Hashes: (19aedcce2035cd2efc8e53e8a3a2b99b)

## BEGIN_GRAPH
IR {
  %0 = pred[1]{0} xla::device_data(), xla_shape=pred[1]{0}, device=SPMD:0
  %1 = pred[1]{0} xla::generic_slice(%0), xla_shape=pred[1]{0}, base_indices=(0), sizes=(1)
  %2 = pred[] aten::view(%1), xla_shape=pred[], output_size=(), ROOT=0
}
JackCaoG commented 1 year ago

OK, I found the problem, but I am still trying to figure out what happened here.

The issue is caused by

                if all(eos_reached):
                    break

and eos_reached is a tensor with the following IR:

IR {
  %0 = pred[1]{0} xla::device_data(), xla_shape=pred[1]{0}, device=SPMD:0, ROOT=0
}
(Pdb) eos_reached.size()
torch.Size([1])
(Pdb) eos_reached.dtype
torch.bool

What I didn't expect is that all(eos_reached) actually triggers an execution. What I don't understand is how using Dynamo's AOT backend has anything to do with this extra execution.
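A toy model in plain Python (not torch_xla code) of why all(eos_reached) forces an execution: the built-in all() iterates over the tensor and calls bool() on each element, and a lazy tensor can only answer bool() by running the pending graph (like the slice-and-view IR shown earlier) and reading the value back to the host. The class names are hypothetical:

```python
class ToyLazyTensor:
    """Toy stand-in for a lazy XLA tensor whose data lives on the device."""

    executions = 0  # how many graphs had to run to answer a Python-level query

    def __init__(self, device_values):
        self._device_values = list(device_values)

    def __iter__(self):
        # Iterating only records lazy per-element slices; nothing executes yet.
        return (ToyLazyElement(self, i) for i in range(len(self._device_values)))


class ToyLazyElement:
    """Toy stand-in for a pending slice+view of one element (pred[] in the IR)."""

    def __init__(self, parent, index):
        self._parent = parent
        self._index = index

    def __bool__(self):
        # Python control flow needs a concrete bool, so the pending graph must be
        # compiled/executed and its result copied to the host.
        ToyLazyTensor.executions += 1
        return bool(self._parent._device_values[self._index])


eos_reached = ToyLazyTensor([False])
if all(eos_reached):
    print("stop generating")
print(ToyLazyTensor.executions)  # 1
```

Because all() short-circuits, a multi-element tensor costs one execution per element until the first False.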

JackCaoG commented 1 year ago

OK, never mind, the real issue is elsewhere:

[ScheduleSyncTensorsGraph]
TensorsGraphInfo:
  mark_step (/src/pytorch/xla/torch_xla/core/xla_model.py:815)
  optimized_mod (/src/pytorch/xla/torch_xla/core/dynamo_bridge.py:315)
  forward (<eval_with_key>.11:5)
  _call_impl (/src/pytorch/torch/nn/modules/module.py:1527)
  _wrapped_call_impl (/src/pytorch/torch/nn/modules/module.py:1518)
  __call__ (/src/pytorch/torch/fx/graph_module.py:274)
  call_wrapped (/src/pytorch/torch/fx/graph_module.py:678)
  fwd (/src/pytorch/torch/_dynamo/backends/torchxla.py:46)
  g (/src/pytorch/torch/_functorch/aot_autograd.py:1483)
  rng_functionalization_wrapper (/src/pytorch/torch/_functorch/aot_autograd.py:1595)
  call_func_with_args (/src/pytorch/torch/_functorch/aot_autograd.py:1507)
  runtime_wrapper (/src/pytorch/torch/_functorch/aot_autograd.py:2441)
  g (/src/pytorch/torch/_functorch/aot_autograd.py:1483)
  forward (/src/pytorch/torch/_functorch/aot_autograd.py:3810)
  inner (/src/pytorch/torch/_dynamo/external_utils.py:17)
  _fn (/src/pytorch/torch/_dynamo/eval_frame.py:321)
  _generate_one_token (/src/llama/llama/generation.py:197)
  _fn (/src/pytorch/torch/_dynamo/eval_frame.py:321)
  generate (/src/llama/llama/generation.py:319)
  decorate_context (/src/pytorch/torch/utils/_contextlib.py:115)
  text_completion (/src/llama/llama/generation.py:373)
  main (example_text_completion.py:70)
  mp_main (example_text_completion.py:124)
  _CallAndUpdateTrace (/usr/local/lib/python3.8/site-packages/fire/core.py:691)
  _Fire (/usr/local/lib/python3.8/site-packages/fire/core.py:475)
  Fire (/usr/local/lib/python3.8/site-packages/fire/core.py:141)
  <module> (example_text_completion.py:130)

Hashes: (8cd46c5be00db4f982f612eb9ccdddc3)

## BEGIN_GRAPH
IR {
  %0 = s64[] xla::device_data(), xla_shape=s64[], device=SPMD:0
  %1 = s64[1]{0} aten::as_strided(%0), xla_shape=s64[1]{0}, size=(1), stride=(1), storage_offset=0, ROOT=0
}
JackCaoG commented 1 year ago

Ah ok found it, the additional execution is caused by

        aliased_out = aliased_base_tensor.as_strided(size, stride, storage_offset)

in https://github.com/pytorch/pytorch/blob/35b2b3ee476513fa8c1d19fafbf64538edafc3ea/torch/_functorch/aot_autograd.py#L668-L669

HloModule IrToHlo.4, entry_computation_layout={(s64[])->(s64[1]{0})}

ENTRY %IrToHlo.4 (p0.1: s64[]) -> (s64[1]) {
  %p0.1 = s64[] parameter(0), sharding={replicated}, metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="<listcomp>@dynamo_bridge.py" source_line=409}
  %reshape.2 = s64[1]{0} reshape(s64[] %p0.1), metadata={op_type="aten__as_strided" op_name="aten__as_strided" source_file="gen_alias_from_base@aot_autograd.py" source_line=669}
  ROOT %tuple.3 = (s64[1]{0}) tuple(s64[1]{0} %reshape.2)
}

@bdhirsh these as_strided calls are problematic for XLA, I guess for both inference and training. Do you know what they are for?

JackCaoG commented 1 year ago

The fundamental issue here is that PyTorch assumes it can modify tensors between graph executions, but for torch_xla any small modification results in a graph execution (since we can't update anything in place on an XLA device, we always execute a new computation and assign the result to the original C++ tensor).

bdhirsh commented 1 year ago

So, there are a handful of (hopefully rare) places where as_strided() is called in AOTAutograd.

In the limit, ideally we'd probably want to replace every as_strided() call with a corresponding chain of "simple" view ops that can get the same behavior. But this will be a decent chunk of work, and I'm not sure that anyone will have time to try to do this any time soon. In this particular case, it looks like AOTAutograd made a best-effort attempt to re-use autograd's existing view-replay mechanism (code), but wasn't able to (probably because this condition failed?)

@JackCaoG I'm wondering if as an alternative to unblock - if you inspect that as_strided() call and look at the sizes/strides passed in, they're probably pretty simple. Could XLA update their lowering for as_strided() to detect when the size/stride passed in corresponds to a "simple" view op (e.g. the size is the same as the input, and the strides are just contiguous strides), and manually call the view lowering?
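A minimal sketch of the check described above, in plain Python and assuming the input is contiguous: as_strided(size, stride, storage_offset) just reads the input's elements in order, and so could be lowered as a plain view/reshape, when the offset is 0, the element counts match, and the strides are exactly the contiguous strides for size. The function names are hypothetical, not torch_xla code:

```python
from math import prod
from typing import Sequence, Tuple


def contiguous_strides(size: Sequence[int]) -> Tuple[int, ...]:
    """Row-major (C-contiguous) strides, in elements, for a given shape."""
    strides = []
    running = 1
    for dim in reversed(size):
        strides.append(running)
        running *= dim
    return tuple(reversed(strides))


def as_strided_is_reshape(input_size: Sequence[int], size: Sequence[int],
                          stride: Sequence[int], storage_offset: int = 0) -> bool:
    """True if as_strided on a contiguous input of shape input_size is
    equivalent to reshaping it to `size`."""
    return (
        storage_offset == 0
        and prod(size) == prod(input_size)
        and tuple(stride) == contiguous_strides(size)
    )


# The case from the HLO dump: s64[] -> s64[1] with size=(1), stride=(1), offset 0.
print(as_strided_is_reshape(input_size=(), size=(1,), stride=(1,), storage_offset=0))  # True
# A transposed layout is not a plain reshape.
print(as_strided_is_reshape(input_size=(4,), size=(2, 2), stride=(1, 2)))  # False
```

The stride comparison is deliberately strict: the stride of a size-1 dimension does not affect which elements are read, so this check can reject some calls that are in fact plain views; those would simply fall back to the general lowering.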

JackCaoG commented 1 year ago

Thanks Brian. The issue is not necessarily that the as_strided call is complex, but that this op is not captured in the FX graph (it is called directly on the input tensor). This forces us to execute an additional graph to materialize the input before running the Dynamo graph, which slows down inference.

bdhirsh commented 1 year ago

Ah I see. Unfortunately, AOTAutograd can't make any guarantees that every op will be part of the FX graph that we send to fw_compiler/bw_compiler - to ensure that all edge cases can be handled properly, AOTAutograd will sometimes generate a "runtime epilogue" that gets run after the compiled graph.

There are some details here, but one example is that if an output aliases an input, we can't just blindly put it in the graph - in order to satisfy the autograd engine, we need to re-run the view outside of the custom autograd.Function object that we generated.

Just brainstorming, but I wonder - If XLA sees that the "second" graph that they trace out only consists of view ops (as it sounds like is the case in this example) - would it be possible to skip all the expensive compilation steps and just immediately return the correct view? Since no real compute is happening if all you have is a bunch of view ops in the graph.
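A toy version of that idea, operating on op names like those in the IR dumps in this thread. The set of ops treated as view-only is an illustrative assumption, not a list torch_xla actually uses:

```python
from typing import Iterable

# Illustrative assumption: ops that only reinterpret existing data.
VIEW_ONLY_OPS = frozenset({"aten::view", "aten::as_strided", "xla::generic_slice"})
# Leaf nodes that just reference existing device buffers.
INPUT_OPS = frozenset({"xla::device_data"})


def is_view_only_graph(op_names: Iterable[str]) -> bool:
    """True if every node is either an input or a view op, i.e. no real compute is pending."""
    return all(op in INPUT_OPS or op in VIEW_ONLY_OPS for op in op_names)


# The extra graph from the trace: device_data -> as_strided.
extra_graph = ["xla::device_data", "aten::as_strided"]
# A fragment of the ResNet graph, which does real compute.
resnet_fragment = ["xla::device_data", "aten::convolution_overrideable", "aten::relu"]

print(is_view_only_graph(extra_graph))      # True: candidate for skipping compilation
print(is_view_only_graph(resnet_fragment))  # False: must be compiled and executed
```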

ysiraichi commented 6 months ago

@JackCaoG @bdhirsh I actually have thought of something that could be a solution for this issue.

After https://github.com/pytorch/pytorch/pull/121007 got merged, we actually have access to the chain of view functions that should be applied to the input. PyTorch currently applies that to the input. As @JackCaoG mentioned, that would still cause graph compilation+execution.

You are probably aware that PyTorch/XLA doesn't actually have the notion of tensor views. It just materializes an entire new tensor. In order to support it, FunctionalTensorWrapper was created. Currently, we apply the view operations on the input, so as to get an output FunctionalTensorWrapper that has input (or whatever is the base of input) as its base.

Here's my idea: we don't actually need to re-run the view operations, again. We have everything we need for creating a new FunctionalTensorWrapper (i.e. XLA tensor) while maintaining the aliasing relation I mentioned above. Let i be our input, out be the concrete output with incorrect aliasing information, and fout be the FunctionalTensorWrapper for the output we got in the functionalization step:

FunctionalTensorWrapper(
    base=i.base,
    view_metas=i.view_metas + fout.view_metas,
    value=out.value,
)

Let me know what you think.
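A toy Python model of this proposal, assuming a wrapper only needs a value, a base, and a list of recorded view functions. ToyFunctionalWrapper and wrap_output_as_view are hypothetical names that mirror the pseudocode above, not the real FunctionalTensorWrapper API. The key property is that the compiled output's value is reused as-is and only aliasing metadata is assembled, so no view op runs on the input; when i is not itself a view, i is used as the base ("or whatever is the base of input"):

```python
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional


@dataclass
class ToyFunctionalWrapper:
    """Hypothetical stand-in for FunctionalTensorWrapper (not its real API)."""
    value: Any                                     # the materialized data
    base: Optional["ToyFunctionalWrapper"] = None  # None means "this is a base"
    view_metas: List[Callable] = field(default_factory=list)  # views applied to the base


def wrap_output_as_view(i: ToyFunctionalWrapper,
                        fout: ToyFunctionalWrapper,
                        out: ToyFunctionalWrapper) -> ToyFunctionalWrapper:
    # i:    the input passed to the compiled graph
    # fout: the output as seen during functionalization (carries view metadata)
    # out:  the concrete output of the compiled graph (right value, no aliasing)
    root = i.base if i.base is not None else i
    return ToyFunctionalWrapper(
        value=out.value,                            # reuse the computed result
        base=root,                                  # alias the input's base
        view_metas=i.view_metas + fout.view_metas,  # full view chain from the base
    )


def take_first_half(x):
    """Hypothetical recorded view function, on plain lists."""
    return x[: len(x) // 2]


inp = ToyFunctionalWrapper(value=list(range(10)))
fout = ToyFunctionalWrapper(value=None, view_metas=[take_first_half])  # value unused here
out = ToyFunctionalWrapper(value=[0, 1, 2, 3, 4])

result = wrap_output_as_view(inp, fout, out)
print(result.base is inp, result.value is out.value, len(result.view_metas))  # True True 1
```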

JackCaoG commented 6 months ago

I guess I am not an expert on FunctionalTensorWrapper. The requirement on the pytorch/xla side is that input data is a DeviceData, not some kind of pending execution (view ops). How is your proposal going to accomplish that? I assume the intention of the logic in aot-autograd is to compute the alias_out, which has a different shape from the alias_base. How can we get the alias_out as data without executing any computation on the alias_base?

ysiraichi commented 6 months ago

Currently, this is the signature of the function (source):

def gen_alias_from_base(aliased_base_tensor, target_meta_tensor, target_requires_grad, target_functional_tensor=None):

1. aliased_base_tensor is our input
2. target_meta_tensor is the concrete output that came out of OpenXLA graph execution
3. target_functional_tensor is the FunctionalTensorWrapper we got in the functionalization step

To answer your questions:

The requirement on the pytorch/xla side is that input data is a DeviceData, not some kind of pending execution (view ops). How is your proposal going to accomplish that?

My proposal comes from the observation that we already have the output we want, so there is no need to execute any view operations.

How can we get the alias_out as data without executing any computation on the alias_base?

Because we have already run the computation, and its result is (2). What PyTorch is doing in this function is trying to reconstruct the aliasing relation between the actual output (2) and the input (1).


Here's an example: suppose we are compiling the following function:

def foo(inp):
    return inp.expand(2, -1)

Why is view reconstruction needed? Assume that inp is a tensor of shape [10]. After going through the dynamo bridge, we get a new tensor of shape [2, 10]. However, without view reconstruction, that output would have no aliasing relation with inp, because the compiled graph doesn't go through the FunctionalTensorWrapper layer, so the tensor view information is lost. For example, we would expect an in-place operation on the output to also change inp. Without the view reconstruction procedure, that's not the case.

Since we know that the output is a view of the input, we can just re-run the view operations we recorded in the functionalization step (outside the compiled graph) on inp. That gives us the correctly aliased output.
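The two paragraphs above can be illustrated without PyTorch. `ToyTensor` below is a hypothetical pure-Python model, not a PyTorch type: a tensor is a list of slots in a shared buffer, a view shares that buffer, and `clone()` materializes fresh memory. The graph output returned without view reconstruction behaves like the clone, while replaying the recorded view operation on the input yields an output that really aliases it. A simple narrowing view stands in for `expand` to keep the toy's in-place update unambiguous.

```python
class ToyTensor:
    """Hypothetical toy tensor: a list of slot indices into a shared buffer."""

    def __init__(self, buffer, slots):
        self.buffer = buffer      # plain Python list acting as device memory
        self.slots = list(slots)  # which buffer slots this tensor covers

    def tolist(self):
        return [self.buffer[s] for s in self.slots]

    def narrow(self, start, length):
        # a view: same buffer, subset of the slots
        return ToyTensor(self.buffer, self.slots[start:start + length])

    def clone(self):
        # materialized copy: fresh buffer, no aliasing with self
        data = self.tolist()
        return ToyTensor(data, range(len(data)))

    def mul_(self, factor):
        for s in self.slots:
            self.buffer[s] *= factor
        return self


# View ops recorded during functionalization (outside the compiled graph).
recorded_views = [lambda t: t.narrow(0, 2)]


def replay(inp, views):
    out = inp
    for view_fn in views:
        out = view_fn(out)
    return out


# Without view reconstruction: the graph hands back disconnected memory.
inp_a = ToyTensor([1, 2, 3, 4], range(4))
graph_out = replay(inp_a, recorded_views).clone()  # stands in for the compiled output
graph_out.mul_(10)
print(inp_a.tolist(), graph_out.tolist())  # [1, 2, 3, 4] [10, 20]

# With view reconstruction: replay the recorded views on the input itself.
inp_b = ToyTensor([1, 2, 3, 4], range(4))
aliased_out = replay(inp_b, recorded_views)
aliased_out.mul_(10)
print(inp_b.tolist(), aliased_out.tolist())  # [10, 20, 3, 4] [10, 20]
```

In the real PyTorch/XLA setting, each replayed view op is an additional lazy operation on the input, which is why this step can trigger the extra graph compilation and execution discussed earlier in the thread.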

Why don't we need to re-run the view operations? Because PyTorch/XLA does not have a notion of tensor views. It only has disjoint tensors wrapped with FunctionalTensorWrapper, which emulates aliasing. Therefore, we don't actually need to reconstruct the view; we only need to wrap the output we got from the compiled graph in a FunctionalTensorWrapper whose base is that of inp. That amounts to saying: "I have this completely unrelated tensor (the output), which should actually be a view of inp".
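A minimal sketch of this idea, assuming a hypothetical `ToyFunctionalWrapper` class (not PyTorch's actual FunctionalTensorWrapper API) in which every wrapper owns disjoint data and aliasing is emulated by writing mutations back to a recorded base. The compiled graph's output is wrapped as-is and declared a view of `inp`; an in-place update on it then shows up in `inp`, even though no view operation was ever executed on `inp`.

```python
class ToyFunctionalWrapper:
    """Hypothetical stand-in for FunctionalTensorWrapper (not the real API).

    Every wrapper owns disjoint data, like an XLA tensor. Aliasing is only
    emulated: a wrapper may point at a base and record which base elements
    it corresponds to.
    """

    def __init__(self, value, base=None, index_map=None):
        self.value = list(value)    # this wrapper's own materialized data
        self.base = base            # wrapper we claim to alias, or None
        self.index_map = index_map  # self.value[j] <-> base.value[index_map[j]]

    def mul_(self, factor):
        self.value = [v * factor for v in self.value]
        if self.base is not None:
            # Emulate aliasing by writing the mutation back into the base.
            # (Real functionalization tracks this differently; the toy
            # simply writes back eagerly.)
            for j, k in enumerate(self.index_map):
                self.base.value[k] = self.value[j]
        return self


inp = ToyFunctionalWrapper([1, 2, 3, 4])

# Output of the compiled graph: disjoint data; no view op was run on inp.
graph_result = [1, 2]

# "This unrelated tensor should actually be a view of inp":
out = ToyFunctionalWrapper(graph_result, base=inp, index_map=[0, 1])

out.mul_(10)
print(inp.value, out.value)  # [10, 20, 3, 4] [10, 20]
```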

ysiraichi commented 6 months ago

@bdhirsh What do you think about this proposal?

bdhirsh commented 6 months ago

@ysiraichi I think the main thing I don't get from your proposal is that there's another reason we re-run views outside of the graph:

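import torch

@torch.compile(backend='aot_eager')
def f(x):
    out = x.view(-1)
    return out  # the graph output aliases the input x

x = torch.ones(2, requires_grad=True).clone()
out = f(x)
out.mul_(2)  # mutating a view of a non-leaf tensor that requires grad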

The above code should work without raising an error. However, if we were to return the view generated inside the graph directly back to the user, the following would happen:

(1) The compiled graph is wrapped in a giant autograd.Function, and autograd.Function.forward() invokes the compiled graph and returns its output.

(2) The autograd.Function will see that it is returning an output that aliases one of the inputs.

(3) Autograd will raise an error if the user later tries to mutate that output. It does this because autograd's handling of aliasing + mutation ordinarily involves view replay: autograd regenerates the "mutated" tensor by replaying all of the views in the view chain. But in this case, one of those views came from an autograd.Function, which might have side effects, so autograd bans this behavior. Instead, we regenerate the view outside of the autograd.Function: this way, autograd can see that a plain view op created the view, so it can properly do the same view replay that would normally run in eager mode.
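The rule in (3) can be sketched as a small pure-Python model. `ToyTensor` and `mutation_allowed` are hypothetical and deliberately simplified (they are not autograd's real machinery, and the error message is invented for the toy): each view remembers how it was created, and an in-place update is refused if any link in its view chain was produced by a custom Function, because that link cannot be replayed. Regenerating the view outside the Function with a plain view op makes the same mutation acceptable.

```python
class ToyTensor:
    """Hypothetical model of how autograd tracks views (not the real machinery)."""

    def __init__(self, base=None, created_by="leaf"):
        self.base = base              # tensor this one is a view of, or None
        self.created_by = created_by  # "leaf", "view_op" or "custom_function"

    def view_chain(self):
        # collect how each link between this tensor and its root was created
        chain, t = [], self
        while t.base is not None:
            chain.append(t.created_by)
            t = t.base
        return chain

    def mul_(self, factor):
        # Autograd would regenerate a mutated view by replaying its view chain.
        # A link created by a custom autograd.Function cannot be replayed safely,
        # so the toy refuses the mutation, mirroring the rule described above.
        if "custom_function" in self.view_chain():
            raise RuntimeError("in-place update on a view returned by a custom Function")
        return self  # the toy tracks no data, so there is nothing to multiply


def mutation_allowed(t):
    try:
        t.mul_(2)
        return True
    except RuntimeError:
        return False


x = ToyTensor()

# Graph output returned directly: the view was produced inside the Function.
returned_directly = ToyTensor(base=x, created_by="custom_function")

# View regenerated outside the Function with a plain view op.
regenerated = ToyTensor(base=x, created_by="view_op")

print(mutation_allowed(regenerated), mutation_allowed(returned_directly))  # True False
```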

I guess an alternative is to skip the view replay outside the graph for XLA and accept that we don't support this view + mutation case, although it would still feel pretty bad to make the view regeneration logic in AOTAutograd conditional, since that increases the testing surface.

ysiraichi commented 6 months ago

@bdhirsh Oh, I didn't know about that. One question, though: what kind of side effects could mutating an output view tensor have? Is this something that never happens when using AOTAutograd?

ysiraichi commented 6 months ago

@bdhirsh After an internal discussion, @JackCaoG suggested that maybe we could do that in contexts where we know we won't use autograd, e.g. no_grad or inference_mode. What do you think?
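A sketch of what such a condition might look like, under stated assumptions: `should_replay_views_outside_graph` is a hypothetical helper, not existing AOTAutograd code, and its inputs are plain booleans. In real code they could, for example, come from `torch.is_grad_enabled()`, `torch.is_inference_mode_enabled()` and the inputs' `requires_grad` flags. The idea is to keep the view replay only on the path where autograd can observe the output, and to wrap the graph output directly otherwise.

```python
def should_replay_views_outside_graph(grad_enabled, inference_mode, any_input_requires_grad):
    """Hypothetical policy sketch, not existing AOTAutograd code.

    View replay outside the graph only exists so that autograd sees a plain
    view op. If autograd cannot be involved, the graph output could be
    wrapped directly instead (no extra view ops on the input).
    """
    if inference_mode or not grad_enabled:
        return False
    # Roughly the "training path": grad mode is on and some input needs grad.
    return any_input_requires_grad


cases = [
    (True, False, True),   # training path
    (False, False, True),  # under torch.no_grad()
    (False, True, True),   # under torch.inference_mode() (grad is also off)
    (True, False, False),  # grad enabled, but no input requires grad
]
for case in cases:
    print(case, should_replay_views_outside_graph(*case))
```

As @bdhirsh notes elsewhere in the thread, the cost of any such condition is a larger testing surface, since both branches of the view regeneration logic then need coverage.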

ysiraichi commented 5 months ago

Oh, I didn't know about that. One question, though: what kind of side effects could mutating an output view tensor have? Is this something that never happens when using AOTAutograd?

@bdhirsh Do you have any thoughts on this?

ysiraichi commented 5 months ago

(2) The autograd.Function will see that it is returning an output that aliases one of the inputs

Will it, though? I'm asking because even though we are returning something computed inside an autograd.Function, it's wrapped in a fresh FunctionalTensorWrapper, created outside the autograd.Function.

What do you think?

bdhirsh commented 5 months ago

@ysiraichi hmm - yeah, I think it would be relatively reasonable to "only replay the view ops outside of the graph in that wrapper" if we are going down the training path of AOTAutograd.

One thing to note: @jamesjwu is doing a lot of refactoring of the runtime wrappers as part of adding a warm cache to AOTAutograd, so there will be some risk of merge conflicts there in the next few weeks.

One question, though: what kind of side effects could mutating an output view tensor have? Is this something that never happens when using AOTAutograd?

It's more just the fact that this is "allowed" today, so if the view op below happens to be compiled, we need it to work:

import torch

a = torch.randn(4, requires_grad=True).clone()
a_view = a.view(-1)
print(a_view.grad_fn)
# prints <ViewBackward0 object at 0x7fd2b331be20>
a_view.mul_(2)
print(a_view.grad_fn)
# prints <AsStridedBackward0 object at 0x7fd2b331be20>