GleasonK opened 1 year ago
Thanks! I took a quick look at the HLO dump, and the `openxla` (which is `aot_torchxla_trace_once` renamed) HLO is longer than the `torchxla_trace_once` one. @wconstab @shunting314 Do you guys know what additional stuff the aot backend does? I thought that for pure inference the fx graph should be the same.
I'm not too familiar with what the aot_ backend does differently. What's the history of that backend (is it based on the dynamo backend shunting wrote or something else?)
`openxla` == `aot_torchxla_trace_once`, and the only difference between `torchxla_trace_once` and `aot_torchxla_trace_once` is the `aot_autograd` wrapper:

aot_torchxla_trace_once = aot_autograd(
    fw_compiler=torchxla_trace_once,
)
https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/backends/torchxla.py#L49-L51
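For context, a minimal sketch of how `aot_autograd` turns a per-graph compiler into a Dynamo backend, mirroring the snippet above (`my_compiler` is a hypothetical stand-in for `torchxla_trace_once`):

```python
import torch
from torch._dynamo.backends.common import aot_autograd

# Hypothetical stand-in compiler: fw_compiler receives the FX graph
# after AOTAutograd has functionalized and decomposed it.
def my_compiler(gm: torch.fx.GraphModule, example_inputs):
    gm.graph.print_tabular()
    return gm.forward  # whatever fw_compiler returns runs the forward graph

my_aot_backend = aot_autograd(fw_compiler=my_compiler)

@torch.compile(backend=my_aot_backend)
def f(x):
    return torch.relu(x) + 1

f(torch.randn(8))
```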
I guess the question is: will `aot_autograd` do anything if only the fwd function of the model is being called?
Maybe due to functionalization: https://github.com/pytorch/pytorch/blob/main/torch/_functorch/aot_autograd.py#L1534 . It's done even for inference.
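A rough sketch of what functionalization does to a graph, using the public `torch.func` API rather than AOTAutograd's internal pass:

```python
import torch
from torch.func import functionalize
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    y = x.clone()
    y.add_(1)  # in-place mutation
    return y

# Tracing through functionalize() replaces the mutation with out-of-place
# ops: the printed graph contains aten.add, not aten.add_.
gm = make_fx(functionalize(f))(torch.randn(3))
print(gm.code)
```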
I will take a look at the fx graph
I dumped the fx graph for resnet 18 using the `openxla` (aot) and `torchxla_trace_once` backends. Their FX graphs look quite different.
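For reference, one way to produce such dumps is a trivial debug backend; any callable taking `(gm, example_inputs)` works as a Dynamo backend (an illustrative sketch, not necessarily how the graphs below were obtained):

```python
import torch
import torchvision

def dump_backend(gm: torch.fx.GraphModule, example_inputs):
    print(gm.code)     # the FX graph Dynamo captured, as Python source
    return gm.forward  # run it eagerly

model = torchvision.models.resnet18().eval()
compiled = torch.compile(model, backend=dump_backend)
compiled(torch.randn(4, 3, 224, 224))
```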
openxla
def forward(self, primals_61, primals_123, primals_1, primals_2, primals_3, primals_63, primals_64, primals_4, primals_5, primals_6, primals_66, primals_67, primals_7, primals_8, primals_9, primals_69, primals_70, primals_10, primals_11, primals_12, primals_72, primals_73, primals_13, primals_14, primals_15, primals_75, primals_76, primals_16, primals_22, primals_17, primals_18, primals_78, primals_79, primals_23, primals_24, primals_84, primals_85, primals_19, primals_20, primals_21, primals_81, primals_82, primals_25, primals_26, primals_27, primals_87, primals_88, primals_28, primals_29, primals_30, primals_90, primals_91, primals_31, primals_37, primals_32, primals_33, primals_93, primals_94, primals_38, primals_39, primals_99, primals_100, primals_34, primals_35, primals_36, primals_96, primals_97, primals_40, primals_41, primals_42, primals_102, primals_103, primals_43, primals_44, primals_45, primals_105, primals_106, primals_46, primals_52, primals_47, primals_48, primals_108, primals_109, primals_53, primals_54, primals_114, primals_115, primals_49, primals_50, primals_51, primals_111, primals_112, primals_55, primals_56, primals_57, primals_117, primals_118, primals_58, primals_59, primals_60, primals_120, primals_121, primals_62):
t = torch.ops.aten.t.default(primals_61); primals_61 = None
convolution = torch.ops.aten.convolution.default(primals_123, primals_1, None, [2, 2], [3, 3], [1, 1], False, [0, 0], 1); primals_123 = primals_1 = None
_native_batch_norm_legit_no_training = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution, primals_2, primals_3, primals_63, primals_64, 0.1, 1e-05); primals_2 = primals_3 = primals_63 = primals_64 = None
getitem = _native_batch_norm_legit_no_training[0]
getitem_1 = _native_batch_norm_legit_no_training[1]
getitem_2 = _native_batch_norm_legit_no_training[2]; _native_batch_norm_legit_no_training = None
relu = torch.ops.aten.relu.default(getitem); getitem = None
max_pool2d_forward = torch.ops.xla.max_pool2d_forward.default(relu, [3, 3], [2, 2], [1, 1])
convolution_1 = torch.ops.aten.convolution.default(max_pool2d_forward, primals_4, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1); primals_4 = None
_native_batch_norm_legit_no_training_1 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_1, primals_5, primals_6, primals_66, primals_67, 0.1, 1e-05); primals_5 = primals_6 = primals_66 = primals_67 = None
getitem_3 = _native_batch_norm_legit_no_training_1[0]
getitem_4 = _native_batch_norm_legit_no_training_1[1]
getitem_5 = _native_batch_norm_legit_no_training_1[2]; _native_batch_norm_legit_no_training_1 = None
relu_1 = torch.ops.aten.relu.default(getitem_3); getitem_3 = None
convolution_2 = torch.ops.aten.convolution.default(relu_1, primals_7, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1); primals_7 = None
_native_batch_norm_legit_no_training_2 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_2, primals_8, primals_9, primals_69, primals_70, 0.1, 1e-05); primals_8 = primals_9 = primals_69 = primals_70 = None
getitem_6 = _native_batch_norm_legit_no_training_2[0]
getitem_7 = _native_batch_norm_legit_no_training_2[1]
getitem_8 = _native_batch_norm_legit_no_training_2[2]; _native_batch_norm_legit_no_training_2 = None
add = torch.ops.aten.add.Tensor(getitem_6, max_pool2d_forward); getitem_6 = None
relu_2 = torch.ops.aten.relu.default(add); add = None
convolution_3 = torch.ops.aten.convolution.default(relu_2, primals_10, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1); primals_10 = None
_native_batch_norm_legit_no_training_3 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_3, primals_11, primals_12, primals_72, primals_73, 0.1, 1e-05); primals_11 = primals_12 = primals_72 = primals_73 = None
getitem_9 = _native_batch_norm_legit_no_training_3[0]
getitem_10 = _native_batch_norm_legit_no_training_3[1]
getitem_11 = _native_batch_norm_legit_no_training_3[2]; _native_batch_norm_legit_no_training_3 = None
relu_3 = torch.ops.aten.relu.default(getitem_9); getitem_9 = None
convolution_4 = torch.ops.aten.convolution.default(relu_3, primals_13, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1); primals_13 = None
_native_batch_norm_legit_no_training_4 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_4, primals_14, primals_15, primals_75, primals_76, 0.1, 1e-05); primals_14 = primals_15 = primals_75 = primals_76 = None
getitem_12 = _native_batch_norm_legit_no_training_4[0]
getitem_13 = _native_batch_norm_legit_no_training_4[1]
getitem_14 = _native_batch_norm_legit_no_training_4[2]; _native_batch_norm_legit_no_training_4 = None
add_1 = torch.ops.aten.add.Tensor(getitem_12, relu_2); getitem_12 = None
relu_4 = torch.ops.aten.relu.default(add_1); add_1 = None
convolution_5 = torch.ops.aten.convolution.default(relu_4, primals_16, None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1); primals_16 = None
convolution_7 = torch.ops.aten.convolution.default(relu_4, primals_22, None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1); primals_22 = None
_native_batch_norm_legit_no_training_5 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_5, primals_17, primals_18, primals_78, primals_79, 0.1, 1e-05); primals_17 = primals_18 = primals_78 = primals_79 = None
_native_batch_norm_legit_no_training_7 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_7, primals_23, primals_24, primals_84, primals_85, 0.1, 1e-05); primals_23 = primals_24 = primals_84 = primals_85 = None
getitem_15 = _native_batch_norm_legit_no_training_5[0]
getitem_16 = _native_batch_norm_legit_no_training_5[1]
getitem_17 = _native_batch_norm_legit_no_training_5[2]; _native_batch_norm_legit_no_training_5 = None
getitem_21 = _native_batch_norm_legit_no_training_7[0]
getitem_22 = _native_batch_norm_legit_no_training_7[1]
getitem_23 = _native_batch_norm_legit_no_training_7[2]; _native_batch_norm_legit_no_training_7 = None
relu_5 = torch.ops.aten.relu.default(getitem_15); getitem_15 = None
convolution_6 = torch.ops.aten.convolution.default(relu_5, primals_19, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1); primals_19 = None
_native_batch_norm_legit_no_training_6 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_6, primals_20, primals_21, primals_81, primals_82, 0.1, 1e-05); primals_20 = primals_21 = primals_81 = primals_82 = None
getitem_18 = _native_batch_norm_legit_no_training_6[0]
getitem_19 = _native_batch_norm_legit_no_training_6[1]
getitem_20 = _native_batch_norm_legit_no_training_6[2]; _native_batch_norm_legit_no_training_6 = None
add_2 = torch.ops.aten.add.Tensor(getitem_18, getitem_21); getitem_18 = getitem_21 = None
relu_6 = torch.ops.aten.relu.default(add_2); add_2 = None
convolution_8 = torch.ops.aten.convolution.default(relu_6, primals_25, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1); primals_25 = None
_native_batch_norm_legit_no_training_8 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_8, primals_26, primals_27, primals_87, primals_88, 0.1, 1e-05); primals_26 = primals_27 = primals_87 = primals_88 = None
getitem_24 = _native_batch_norm_legit_no_training_8[0]
getitem_25 = _native_batch_norm_legit_no_training_8[1]
getitem_26 = _native_batch_norm_legit_no_training_8[2]; _native_batch_norm_legit_no_training_8 = None
relu_7 = torch.ops.aten.relu.default(getitem_24); getitem_24 = None
convolution_9 = torch.ops.aten.convolution.default(relu_7, primals_28, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1); primals_28 = None
_native_batch_norm_legit_no_training_9 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_9, primals_29, primals_30, primals_90, primals_91, 0.1, 1e-05); primals_29 = primals_30 = primals_90 = primals_91 = None
getitem_27 = _native_batch_norm_legit_no_training_9[0]
getitem_28 = _native_batch_norm_legit_no_training_9[1]
getitem_29 = _native_batch_norm_legit_no_training_9[2]; _native_batch_norm_legit_no_training_9 = None
add_3 = torch.ops.aten.add.Tensor(getitem_27, relu_6); getitem_27 = None
relu_8 = torch.ops.aten.relu.default(add_3); add_3 = None
convolution_10 = torch.ops.aten.convolution.default(relu_8, primals_31, None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1); primals_31 = None
convolution_12 = torch.ops.aten.convolution.default(relu_8, primals_37, None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1); primals_37 = None
_native_batch_norm_legit_no_training_10 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_10, primals_32, primals_33, primals_93, primals_94, 0.1, 1e-05); primals_32 = primals_33 = primals_93 = primals_94 = None
_native_batch_norm_legit_no_training_12 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_12, primals_38, primals_39, primals_99, primals_100, 0.1, 1e-05); primals_38 = primals_39 = primals_99 = primals_100 = None
getitem_30 = _native_batch_norm_legit_no_training_10[0]
getitem_31 = _native_batch_norm_legit_no_training_10[1]
getitem_32 = _native_batch_norm_legit_no_training_10[2]; _native_batch_norm_legit_no_training_10 = None
getitem_36 = _native_batch_norm_legit_no_training_12[0]
getitem_37 = _native_batch_norm_legit_no_training_12[1]
getitem_38 = _native_batch_norm_legit_no_training_12[2]; _native_batch_norm_legit_no_training_12 = None
relu_9 = torch.ops.aten.relu.default(getitem_30); getitem_30 = None
convolution_11 = torch.ops.aten.convolution.default(relu_9, primals_34, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1); primals_34 = None
_native_batch_norm_legit_no_training_11 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_11, primals_35, primals_36, primals_96, primals_97, 0.1, 1e-05); primals_35 = primals_36 = primals_96 = primals_97 = None
getitem_33 = _native_batch_norm_legit_no_training_11[0]
getitem_34 = _native_batch_norm_legit_no_training_11[1]
getitem_35 = _native_batch_norm_legit_no_training_11[2]; _native_batch_norm_legit_no_training_11 = None
add_4 = torch.ops.aten.add.Tensor(getitem_33, getitem_36); getitem_33 = getitem_36 = None
relu_10 = torch.ops.aten.relu.default(add_4); add_4 = None
convolution_13 = torch.ops.aten.convolution.default(relu_10, primals_40, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1); primals_40 = None
_native_batch_norm_legit_no_training_13 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_13, primals_41, primals_42, primals_102, primals_103, 0.1, 1e-05); primals_41 = primals_42 = primals_102 = primals_103 = None
getitem_39 = _native_batch_norm_legit_no_training_13[0]
getitem_40 = _native_batch_norm_legit_no_training_13[1]
getitem_41 = _native_batch_norm_legit_no_training_13[2]; _native_batch_norm_legit_no_training_13 = None
relu_11 = torch.ops.aten.relu.default(getitem_39); getitem_39 = None
convolution_14 = torch.ops.aten.convolution.default(relu_11, primals_43, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1); primals_43 = None
_native_batch_norm_legit_no_training_14 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_14, primals_44, primals_45, primals_105, primals_106, 0.1, 1e-05); primals_44 = primals_45 = primals_105 = primals_106 = None
getitem_42 = _native_batch_norm_legit_no_training_14[0]
getitem_43 = _native_batch_norm_legit_no_training_14[1]
getitem_44 = _native_batch_norm_legit_no_training_14[2]; _native_batch_norm_legit_no_training_14 = None
add_5 = torch.ops.aten.add.Tensor(getitem_42, relu_10); getitem_42 = None
relu_12 = torch.ops.aten.relu.default(add_5); add_5 = None
convolution_15 = torch.ops.aten.convolution.default(relu_12, primals_46, None, [2, 2], [1, 1], [1, 1], False, [0, 0], 1); primals_46 = None
convolution_17 = torch.ops.aten.convolution.default(relu_12, primals_52, None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1); primals_52 = None
_native_batch_norm_legit_no_training_15 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_15, primals_47, primals_48, primals_108, primals_109, 0.1, 1e-05); primals_47 = primals_48 = primals_108 = primals_109 = None
_native_batch_norm_legit_no_training_17 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_17, primals_53, primals_54, primals_114, primals_115, 0.1, 1e-05); primals_53 = primals_54 = primals_114 = primals_115 = None
getitem_45 = _native_batch_norm_legit_no_training_15[0]
getitem_46 = _native_batch_norm_legit_no_training_15[1]
getitem_47 = _native_batch_norm_legit_no_training_15[2]; _native_batch_norm_legit_no_training_15 = None
getitem_51 = _native_batch_norm_legit_no_training_17[0]
getitem_52 = _native_batch_norm_legit_no_training_17[1]
getitem_53 = _native_batch_norm_legit_no_training_17[2]; _native_batch_norm_legit_no_training_17 = None
relu_13 = torch.ops.aten.relu.default(getitem_45); getitem_45 = None
convolution_16 = torch.ops.aten.convolution.default(relu_13, primals_49, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1); primals_49 = None
_native_batch_norm_legit_no_training_16 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_16, primals_50, primals_51, primals_111, primals_112, 0.1, 1e-05); primals_50 = primals_51 = primals_111 = primals_112 = None
getitem_48 = _native_batch_norm_legit_no_training_16[0]
getitem_49 = _native_batch_norm_legit_no_training_16[1]
getitem_50 = _native_batch_norm_legit_no_training_16[2]; _native_batch_norm_legit_no_training_16 = None
add_6 = torch.ops.aten.add.Tensor(getitem_48, getitem_51); getitem_48 = getitem_51 = None
relu_14 = torch.ops.aten.relu.default(add_6); add_6 = None
convolution_18 = torch.ops.aten.convolution.default(relu_14, primals_55, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1); primals_55 = None
_native_batch_norm_legit_no_training_18 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_18, primals_56, primals_57, primals_117, primals_118, 0.1, 1e-05); primals_56 = primals_57 = primals_117 = primals_118 = None
getitem_54 = _native_batch_norm_legit_no_training_18[0]
getitem_55 = _native_batch_norm_legit_no_training_18[1]
getitem_56 = _native_batch_norm_legit_no_training_18[2]; _native_batch_norm_legit_no_training_18 = None
relu_15 = torch.ops.aten.relu.default(getitem_54); getitem_54 = None
convolution_19 = torch.ops.aten.convolution.default(relu_15, primals_58, None, [1, 1], [1, 1], [1, 1], False, [0, 0], 1); primals_58 = None
_native_batch_norm_legit_no_training_19 = torch.ops.aten._native_batch_norm_legit_no_training.default(convolution_19, primals_59, primals_60, primals_120, primals_121, 0.1, 1e-05); primals_59 = primals_60 = primals_120 = primals_121 = None
getitem_57 = _native_batch_norm_legit_no_training_19[0]
getitem_58 = _native_batch_norm_legit_no_training_19[1]
getitem_59 = _native_batch_norm_legit_no_training_19[2]; _native_batch_norm_legit_no_training_19 = None
add_7 = torch.ops.aten.add.Tensor(getitem_57, relu_14); getitem_57 = None
relu_16 = torch.ops.aten.relu.default(add_7); add_7 = None
mean = torch.ops.aten.mean.dim(relu_16, [-1, -2], True)
view = torch.ops.aten.view.default(mean, [4, 512]); mean = None
addmm = torch.ops.aten.addmm.default(primals_62, view, t); primals_62 = None
return (t, convolution, getitem_1, getitem_2, relu, max_pool2d_forward, convolution_1, getitem_4, getitem_5, relu_1, convolution_2, getitem_7, getitem_8, relu_2, convolution_3, getitem_10, getitem_11, relu_3, convolution_4, getitem_13, getitem_14, relu_4, convolution_5, convolution_7, getitem_16, getitem_17, getitem_22, getitem_23, relu_5, convolution_6, getitem_19, getitem_20, relu_6, convolution_8, getitem_25, getitem_26, relu_7, convolution_9, getitem_28, getitem_29, relu_8, convolution_10, convolution_12, getitem_31, getitem_32, getitem_37, getitem_38, relu_9, convolution_11, getitem_34, getitem_35, relu_10, convolution_13, getitem_40, getitem_41, relu_11, convolution_14, getitem_43, getitem_44, relu_12, convolution_15, convolution_17, getitem_46, getitem_47, getitem_52, getitem_53, relu_13, convolution_16, getitem_49, getitem_50, relu_14, convolution_18, getitem_55, getitem_56, relu_15, convolution_19, getitem_58, getitem_59, relu_16, view, addmm)
torchxla_trace_once
def forward(self, l_x_ : torch.Tensor):
l__self___conv1 = self.L__self___conv1(l_x_); l_x_ = None
l__self___bn1 = self.L__self___bn1(l__self___conv1); l__self___conv1 = None
l__self___relu = self.L__self___relu(l__self___bn1); l__self___bn1 = None
l__self___maxpool = self.L__self___maxpool(l__self___relu); l__self___relu = None
getattr_l__self___layer1___0___conv1 = self.getattr_L__self___layer1___0___conv1(l__self___maxpool)
getattr_l__self___layer1___0___bn1 = self.getattr_L__self___layer1___0___bn1(getattr_l__self___layer1___0___conv1); getattr_l__self___layer1___0___conv1 = None
getattr_l__self___layer1___0___relu = self.getattr_L__self___layer1___0___relu(getattr_l__self___layer1___0___bn1); getattr_l__self___layer1___0___bn1 = None
getattr_l__self___layer1___0___conv2 = self.getattr_L__self___layer1___0___conv2(getattr_l__self___layer1___0___relu); getattr_l__self___layer1___0___relu = None
getattr_l__self___layer1___0___bn2 = self.getattr_L__self___layer1___0___bn2(getattr_l__self___layer1___0___conv2); getattr_l__self___layer1___0___conv2 = None
getattr_l__self___layer1___0___bn2 += l__self___maxpool; iadd = getattr_l__self___layer1___0___bn2; getattr_l__self___layer1___0___bn2 = l__self___maxpool = None
getattr_l__self___layer1___0___relu_1 = self.getattr_L__self___layer1___0___relu(iadd); iadd = None
getattr_l__self___layer1___1___conv1 = self.getattr_L__self___layer1___1___conv1(getattr_l__self___layer1___0___relu_1)
getattr_l__self___layer1___1___bn1 = self.getattr_L__self___layer1___1___bn1(getattr_l__self___layer1___1___conv1); getattr_l__self___layer1___1___conv1 = None
getattr_l__self___layer1___1___relu = self.getattr_L__self___layer1___1___relu(getattr_l__self___layer1___1___bn1); getattr_l__self___layer1___1___bn1 = None
getattr_l__self___layer1___1___conv2 = self.getattr_L__self___layer1___1___conv2(getattr_l__self___layer1___1___relu); getattr_l__self___layer1___1___relu = None
getattr_l__self___layer1___1___bn2 = self.getattr_L__self___layer1___1___bn2(getattr_l__self___layer1___1___conv2); getattr_l__self___layer1___1___conv2 = None
getattr_l__self___layer1___1___bn2 += getattr_l__self___layer1___0___relu_1; iadd_1 = getattr_l__self___layer1___1___bn2; getattr_l__self___layer1___1___bn2 = getattr_l__self___layer1___0___relu_1 = None
getattr_l__self___layer1___1___relu_1 = self.getattr_L__self___layer1___1___relu(iadd_1); iadd_1 = None
getattr_l__self___layer2___0___conv1 = self.getattr_L__self___layer2___0___conv1(getattr_l__self___layer1___1___relu_1)
getattr_l__self___layer2___0___downsample_0 = self.getattr_L__self___layer2___0___downsample_0(getattr_l__self___layer1___1___relu_1); getattr_l__self___layer1___1___relu_1 = None
getattr_l__self___layer2___0___bn1 = self.getattr_L__self___layer2___0___bn1(getattr_l__self___layer2___0___conv1); getattr_l__self___layer2___0___conv1 = None
getattr_l__self___layer2___0___downsample_1 = self.getattr_L__self___layer2___0___downsample_1(getattr_l__self___layer2___0___downsample_0); getattr_l__self___layer2___0___downsample_0 = None
getattr_l__self___layer2___0___relu = self.getattr_L__self___layer2___0___relu(getattr_l__self___layer2___0___bn1); getattr_l__self___layer2___0___bn1 = None
getattr_l__self___layer2___0___conv2 = self.getattr_L__self___layer2___0___conv2(getattr_l__self___layer2___0___relu); getattr_l__self___layer2___0___relu = None
getattr_l__self___layer2___0___bn2 = self.getattr_L__self___layer2___0___bn2(getattr_l__self___layer2___0___conv2); getattr_l__self___layer2___0___conv2 = None
getattr_l__self___layer2___0___bn2 += getattr_l__self___layer2___0___downsample_1; iadd_2 = getattr_l__self___layer2___0___bn2; getattr_l__self___layer2___0___bn2 = getattr_l__self___layer2___0___downsample_1 = None
getattr_l__self___layer2___0___relu_1 = self.getattr_L__self___layer2___0___relu(iadd_2); iadd_2 = None
getattr_l__self___layer2___1___conv1 = self.getattr_L__self___layer2___1___conv1(getattr_l__self___layer2___0___relu_1)
getattr_l__self___layer2___1___bn1 = self.getattr_L__self___layer2___1___bn1(getattr_l__self___layer2___1___conv1); getattr_l__self___layer2___1___conv1 = None
getattr_l__self___layer2___1___relu = self.getattr_L__self___layer2___1___relu(getattr_l__self___layer2___1___bn1); getattr_l__self___layer2___1___bn1 = None
getattr_l__self___layer2___1___conv2 = self.getattr_L__self___layer2___1___conv2(getattr_l__self___layer2___1___relu); getattr_l__self___layer2___1___relu = None
getattr_l__self___layer2___1___bn2 = self.getattr_L__self___layer2___1___bn2(getattr_l__self___layer2___1___conv2); getattr_l__self___layer2___1___conv2 = None
getattr_l__self___layer2___1___bn2 += getattr_l__self___layer2___0___relu_1; iadd_3 = getattr_l__self___layer2___1___bn2; getattr_l__self___layer2___1___bn2 = getattr_l__self___layer2___0___relu_1 = None
getattr_l__self___layer2___1___relu_1 = self.getattr_L__self___layer2___1___relu(iadd_3); iadd_3 = None
getattr_l__self___layer3___0___conv1 = self.getattr_L__self___layer3___0___conv1(getattr_l__self___layer2___1___relu_1)
getattr_l__self___layer3___0___downsample_0 = self.getattr_L__self___layer3___0___downsample_0(getattr_l__self___layer2___1___relu_1); getattr_l__self___layer2___1___relu_1 = None
getattr_l__self___layer3___0___bn1 = self.getattr_L__self___layer3___0___bn1(getattr_l__self___layer3___0___conv1); getattr_l__self___layer3___0___conv1 = None
getattr_l__self___layer3___0___downsample_1 = self.getattr_L__self___layer3___0___downsample_1(getattr_l__self___layer3___0___downsample_0); getattr_l__self___layer3___0___downsample_0 = None
getattr_l__self___layer3___0___relu = self.getattr_L__self___layer3___0___relu(getattr_l__self___layer3___0___bn1); getattr_l__self___layer3___0___bn1 = None
getattr_l__self___layer3___0___conv2 = self.getattr_L__self___layer3___0___conv2(getattr_l__self___layer3___0___relu); getattr_l__self___layer3___0___relu = None
getattr_l__self___layer3___0___bn2 = self.getattr_L__self___layer3___0___bn2(getattr_l__self___layer3___0___conv2); getattr_l__self___layer3___0___conv2 = None
getattr_l__self___layer3___0___bn2 += getattr_l__self___layer3___0___downsample_1; iadd_4 = getattr_l__self___layer3___0___bn2; getattr_l__self___layer3___0___bn2 = getattr_l__self___layer3___0___downsample_1 = None
getattr_l__self___layer3___0___relu_1 = self.getattr_L__self___layer3___0___relu(iadd_4); iadd_4 = None
getattr_l__self___layer3___1___conv1 = self.getattr_L__self___layer3___1___conv1(getattr_l__self___layer3___0___relu_1)
getattr_l__self___layer3___1___bn1 = self.getattr_L__self___layer3___1___bn1(getattr_l__self___layer3___1___conv1); getattr_l__self___layer3___1___conv1 = None
getattr_l__self___layer3___1___relu = self.getattr_L__self___layer3___1___relu(getattr_l__self___layer3___1___bn1); getattr_l__self___layer3___1___bn1 = None
getattr_l__self___layer3___1___conv2 = self.getattr_L__self___layer3___1___conv2(getattr_l__self___layer3___1___relu); getattr_l__self___layer3___1___relu = None
getattr_l__self___layer3___1___bn2 = self.getattr_L__self___layer3___1___bn2(getattr_l__self___layer3___1___conv2); getattr_l__self___layer3___1___conv2 = None
getattr_l__self___layer3___1___bn2 += getattr_l__self___layer3___0___relu_1; iadd_5 = getattr_l__self___layer3___1___bn2; getattr_l__self___layer3___1___bn2 = getattr_l__self___layer3___0___relu_1 = None
getattr_l__self___layer3___1___relu_1 = self.getattr_L__self___layer3___1___relu(iadd_5); iadd_5 = None
getattr_l__self___layer4___0___conv1 = self.getattr_L__self___layer4___0___conv1(getattr_l__self___layer3___1___relu_1)
getattr_l__self___layer4___0___downsample_0 = self.getattr_L__self___layer4___0___downsample_0(getattr_l__self___layer3___1___relu_1); getattr_l__self___layer3___1___relu_1 = None
getattr_l__self___layer4___0___bn1 = self.getattr_L__self___layer4___0___bn1(getattr_l__self___layer4___0___conv1); getattr_l__self___layer4___0___conv1 = None
getattr_l__self___layer4___0___downsample_1 = self.getattr_L__self___layer4___0___downsample_1(getattr_l__self___layer4___0___downsample_0); getattr_l__self___layer4___0___downsample_0 = None
getattr_l__self___layer4___0___relu = self.getattr_L__self___layer4___0___relu(getattr_l__self___layer4___0___bn1); getattr_l__self___layer4___0___bn1 = None
getattr_l__self___layer4___0___conv2 = self.getattr_L__self___layer4___0___conv2(getattr_l__self___layer4___0___relu); getattr_l__self___layer4___0___relu = None
getattr_l__self___layer4___0___bn2 = self.getattr_L__self___layer4___0___bn2(getattr_l__self___layer4___0___conv2); getattr_l__self___layer4___0___conv2 = None
getattr_l__self___layer4___0___bn2 += getattr_l__self___layer4___0___downsample_1; iadd_6 = getattr_l__self___layer4___0___bn2; getattr_l__self___layer4___0___bn2 = getattr_l__self___layer4___0___downsample_1 = None
getattr_l__self___layer4___0___relu_1 = self.getattr_L__self___layer4___0___relu(iadd_6); iadd_6 = None
getattr_l__self___layer4___1___conv1 = self.getattr_L__self___layer4___1___conv1(getattr_l__self___layer4___0___relu_1)
getattr_l__self___layer4___1___bn1 = self.getattr_L__self___layer4___1___bn1(getattr_l__self___layer4___1___conv1); getattr_l__self___layer4___1___conv1 = None
getattr_l__self___layer4___1___relu = self.getattr_L__self___layer4___1___relu(getattr_l__self___layer4___1___bn1); getattr_l__self___layer4___1___bn1 = None
getattr_l__self___layer4___1___conv2 = self.getattr_L__self___layer4___1___conv2(getattr_l__self___layer4___1___relu); getattr_l__self___layer4___1___relu = None
getattr_l__self___layer4___1___bn2 = self.getattr_L__self___layer4___1___bn2(getattr_l__self___layer4___1___conv2); getattr_l__self___layer4___1___conv2 = None
getattr_l__self___layer4___1___bn2 += getattr_l__self___layer4___0___relu_1; iadd_7 = getattr_l__self___layer4___1___bn2; getattr_l__self___layer4___1___bn2 = getattr_l__self___layer4___0___relu_1 = None
getattr_l__self___layer4___1___relu_1 = self.getattr_L__self___layer4___1___relu(iadd_7); iadd_7 = None
l__self___avgpool = self.L__self___avgpool(getattr_l__self___layer4___1___relu_1); getattr_l__self___layer4___1___relu_1 = None
flatten = torch.flatten(l__self___avgpool, 1); l__self___avgpool = None
l__self___fc = self.L__self___fc(flatten); flatten = None
return l__self___fc
The `openxla` one is also quite a bit longer, though it looks nicer. Let me play with `create_functionalized_graph` a bit.
hmm, doesn't seem like `create_functionalized_graph` is being called.

I was wrong, this function is being called when using the aot backend... Trying to see if there is a way to bypass it.
AOTAutograd will do a few things in the inference path: functionalization, and, if there are input mutations, inserting `copy_()` nodes at the end of the graph.

Are the warnings posted somewhere? I don't see a link to them. On the perf difference though - a bit of a shot in the dark, but I wonder if the `copy_()` input mutations showing up later after the batchnorm call could cause any perf issues with XLA?
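A small probe of the input-mutation case being described, wrapping a printing compiler in `aot_autograd` (hypothetical helper name):

```python
import torch
from torch._dynamo.backends.common import aot_autograd

def show_graph(gm, example_inputs):
    print(gm.code)  # look for a trailing aten.copy_ applying the mutation
    return gm.forward

@torch.compile(backend=aot_autograd(fw_compiler=show_graph))
def f(x):
    x.mul_(2)  # mutates the input
    return x + 1

f(torch.ones(4))
```

Depending on `keep_inference_input_mutations`, the `copy_()` may land in the graph itself (as described above) or be applied by AOTAutograd's runtime wrapper instead.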
@JackCaoG the difference in the graphs is probably due to the `make_fx` call in `create_functionalized_graph`: link .
The warning message should be fixed by https://github.com/pytorch/pytorch/pull/107260. @GleasonK after my pr is merged, you can use `openxla_eval` while I am debugging this aot backend speed regression.
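For anyone following along, a sketch of selecting that backend (assuming a torch_xla build that registers `openxla_eval`):

```python
import torch
import torchvision
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torchvision.models.resnet18().eval().to(device)
x = torch.randn(4, 3, 224, 224, device=device)

# openxla_eval skips the aot_autograd wrapper; openxla (discussed above)
# is the same compiler wrapped in aot_autograd.
compiled = torch.compile(model, backend="openxla_eval")
out = compiled(x)
```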
I was able to compare the IR and counters, and two things stand out. One is that the aot backend uses `xla::_native_batch_norm_legit` instead of `xla::native_batch_norm`. I think this is fine; I vaguely remember I implemented `_native_batch_norm_legit` by calling `native_batch_norm`. The other is the number of `xla::clone` calls. For resnet 18:

aot_backend
Counter: xla::clone
Value: 103
Counter: xla::empty_symint
Value: 80
@bdhirsh do you know where these clones are from?
non-aot
Counter: xla::clone
Value: 1
Counter: xla::empty_symint
Value: 120
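For reference, these counters can be read back through torch_xla's metrics API after running the compiled model:

```python
import torch_xla.debug.metrics as met

# Inspect the counters discussed above...
for name in ("xla::clone", "xla::empty_symint"):
    print(name, met.counter_value(name))

# ...or dump the full report (counters plus timing metrics).
print(met.metrics_report())

# met.clear_counters() resets them between experiments.
```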
IR for openxla (aot backend)
IR {
%0 = f32[64,3,7,7]{0,3,2,1} xla::device_data(), xla_shape=f32[64,3,7,7]{0,3,2,1}, device=TPU:0
%1 = f32[4,3,224,224]{3,2,1,0} xla::device_data(), xla_shape=f32[4,3,224,224]{3,2,1,0}, device=TPU:0
%2 = f32[4,64,112,112]{3,2,1,0} aten::convolution_overrideable(%1, %0), xla_shape=f32[4,64,112,112]{3,2,1,0}, stride=(2, 2), padding=(3, 3), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=0
%3 = f32[1000,512]{1,0} xla::device_data(), xla_shape=f32[1000,512]{1,0}, device=TPU:0
%4 = f32[512,1000]{0,1} aten::permute(%3), xla_shape=f32[512,1000]{0,1}, dims=(1, 0), ROOT=1
%5 = f32[] prim::Constant(), xla_shape=f32[], value=0
%6 = f32[0]{0} aten::expand(%5), xla_shape=f32[0]{0}, size=(0), ROOT=2
%7 = f32[] prim::Constant(), xla_shape=f32[], value=0
%8 = f32[0]{0} aten::expand(%7), xla_shape=f32[0]{0}, size=(0), ROOT=3
%9 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%10 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%11 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%12 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%13 = (f32[4,64,112,112]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}) aten::native_batch_norm(%2, %12, %11, %10, %9), num_outputs=4, xla_shape=(f32[4,64,112,112]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}), training=0, eps=1e-05
%14 = f32[4,64,112,112]{3,2,1,0} aten::relu(%13.0), xla_shape=f32[4,64,112,112]{3,2,1,0}, ROOT=4
%15 = (f32[4,64,56,56]{3,2,1,0}, u32[4,64,56,56]{3,2,1,0}) aten::max_pool2d(%14), num_outputs=2, xla_shape=(f32[4,64,56,56]{3,2,1,0}, u32[4,64,56,56]{3,2,1,0}), spatial_dim_count=2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), ceil_mode=0, ROOT=5
%16 = f32[64,64,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[64,64,3,3]{1,0,3,2}, device=TPU:0
%17 = f32[4,64,56,56]{3,2,1,0} aten::convolution_overrideable(%15.0, %16), xla_shape=f32[4,64,56,56]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=6
%18 = f32[] prim::Constant(), xla_shape=f32[], value=0
%19 = f32[0]{0} aten::expand(%18), xla_shape=f32[0]{0}, size=(0), ROOT=7
%20 = f32[] prim::Constant(), xla_shape=f32[], value=0
%21 = f32[0]{0} aten::expand(%20), xla_shape=f32[0]{0}, size=(0), ROOT=8
%22 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%23 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%24 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%25 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%26 = (f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}) aten::native_batch_norm(%17, %25, %24, %23, %22), num_outputs=4, xla_shape=(f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}), training=0, eps=1e-05
%27 = f32[4,64,56,56]{3,2,1,0} aten::relu(%26.0), xla_shape=f32[4,64,56,56]{3,2,1,0}, ROOT=9
%28 = f32[64,64,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[64,64,3,3]{1,0,3,2}, device=TPU:0
%29 = f32[4,64,56,56]{3,2,1,0} aten::convolution_overrideable(%27, %28), xla_shape=f32[4,64,56,56]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=10
%30 = f32[] prim::Constant(), xla_shape=f32[], value=0
%31 = f32[0]{0} aten::expand(%30), xla_shape=f32[0]{0}, size=(0), ROOT=11
%32 = f32[] prim::Constant(), xla_shape=f32[], value=0
%33 = f32[0]{0} aten::expand(%32), xla_shape=f32[0]{0}, size=(0), ROOT=12
%34 = f32[] prim::Constant(), xla_shape=f32[], value=1
%35 = f32[4,64,56,56]{3,2,1,0} aten::expand(%34), xla_shape=f32[4,64,56,56]{3,2,1,0}, size=(4, 64, 56, 56)
%36 = f32[4,64,56,56]{3,2,1,0} aten::mul(%15.0, %35), xla_shape=f32[4,64,56,56]{3,2,1,0}
%37 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%38 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%39 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%40 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%41 = (f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}) aten::native_batch_norm(%29, %40, %39, %38, %37), num_outputs=4, xla_shape=(f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}), training=0, eps=1e-05
%42 = f32[4,64,56,56]{3,2,1,0} aten::add(%41.0, %36), xla_shape=f32[4,64,56,56]{3,2,1,0}
%43 = f32[4,64,56,56]{3,2,1,0} aten::relu(%42), xla_shape=f32[4,64,56,56]{3,2,1,0}, ROOT=13
%44 = f32[64,64,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[64,64,3,3]{1,0,3,2}, device=TPU:0
%45 = f32[4,64,56,56]{3,2,1,0} aten::convolution_overrideable(%43, %44), xla_shape=f32[4,64,56,56]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=14
%46 = f32[] prim::Constant(), xla_shape=f32[], value=0
%47 = f32[0]{0} aten::expand(%46), xla_shape=f32[0]{0}, size=(0), ROOT=15
%48 = f32[] prim::Constant(), xla_shape=f32[], value=0
%49 = f32[0]{0} aten::expand(%48), xla_shape=f32[0]{0}, size=(0), ROOT=16
%50 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%51 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%52 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%53 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%54 = (f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}) aten::native_batch_norm(%45, %53, %52, %51, %50), num_outputs=4, xla_shape=(f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}), training=0, eps=1e-05
%55 = f32[4,64,56,56]{3,2,1,0} aten::relu(%54.0), xla_shape=f32[4,64,56,56]{3,2,1,0}, ROOT=17
%56 = f32[64,64,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[64,64,3,3]{1,0,3,2}, device=TPU:0
%57 = f32[4,64,56,56]{3,2,1,0} aten::convolution_overrideable(%55, %56), xla_shape=f32[4,64,56,56]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=18
%58 = f32[] prim::Constant(), xla_shape=f32[], value=0
%59 = f32[0]{0} aten::expand(%58), xla_shape=f32[0]{0}, size=(0), ROOT=19
%60 = f32[] prim::Constant(), xla_shape=f32[], value=0
%61 = f32[0]{0} aten::expand(%60), xla_shape=f32[0]{0}, size=(0), ROOT=20
%62 = f32[] prim::Constant(), xla_shape=f32[], value=1
%63 = f32[4,64,56,56]{3,2,1,0} aten::expand(%62), xla_shape=f32[4,64,56,56]{3,2,1,0}, size=(4, 64, 56, 56)
%64 = f32[4,64,56,56]{3,2,1,0} aten::mul(%43, %63), xla_shape=f32[4,64,56,56]{3,2,1,0}
%65 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%66 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%67 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%68 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%69 = (f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}) aten::native_batch_norm(%57, %68, %67, %66, %65), num_outputs=4, xla_shape=(f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}), training=0, eps=1e-05
%70 = f32[4,64,56,56]{3,2,1,0} aten::add(%69.0, %64), xla_shape=f32[4,64,56,56]{3,2,1,0}
%71 = f32[4,64,56,56]{3,2,1,0} aten::relu(%70), xla_shape=f32[4,64,56,56]{3,2,1,0}, ROOT=21
%72 = f32[128,64,3,3]{0,1,3,2} xla::device_data(), xla_shape=f32[128,64,3,3]{0,1,3,2}, device=TPU:0
%73 = f32[4,128,28,28]{3,2,1,0} aten::convolution_overrideable(%71, %72), xla_shape=f32[4,128,28,28]{3,2,1,0}, stride=(2, 2), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=22
%74 = f32[128,64,1,1]{0,1,3,2} xla::device_data(), xla_shape=f32[128,64,1,1]{0,1,3,2}, device=TPU:0
%75 = f32[4,128,28,28]{3,2,1,0} aten::convolution_overrideable(%71, %74), xla_shape=f32[4,128,28,28]{3,2,1,0}, stride=(2, 2), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=23
%76 = f32[] prim::Constant(), xla_shape=f32[], value=0
%77 = f32[0]{0} aten::expand(%76), xla_shape=f32[0]{0}, size=(0), ROOT=24
%78 = f32[] prim::Constant(), xla_shape=f32[], value=0
%79 = f32[0]{0} aten::expand(%78), xla_shape=f32[0]{0}, size=(0), ROOT=25
%80 = f32[] prim::Constant(), xla_shape=f32[], value=0
%81 = f32[0]{0} aten::expand(%80), xla_shape=f32[0]{0}, size=(0), ROOT=26
%82 = f32[] prim::Constant(), xla_shape=f32[], value=0
%83 = f32[0]{0} aten::expand(%82), xla_shape=f32[0]{0}, size=(0), ROOT=27
%84 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%85 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%86 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%87 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%88 = (f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}) aten::native_batch_norm(%73, %87, %86, %85, %84), num_outputs=4, xla_shape=(f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}), training=0, eps=1e-05
%89 = f32[4,128,28,28]{3,2,1,0} aten::relu(%88.0), xla_shape=f32[4,128,28,28]{3,2,1,0}, ROOT=28
%90 = f32[128,128,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[128,128,3,3]{1,0,3,2}, device=TPU:0
%91 = f32[4,128,28,28]{3,2,1,0} aten::convolution_overrideable(%89, %90), xla_shape=f32[4,128,28,28]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=29
%92 = f32[] prim::Constant(), xla_shape=f32[], value=0
%93 = f32[0]{0} aten::expand(%92), xla_shape=f32[0]{0}, size=(0), ROOT=30
%94 = f32[] prim::Constant(), xla_shape=f32[], value=0
%95 = f32[0]{0} aten::expand(%94), xla_shape=f32[0]{0}, size=(0), ROOT=31
%96 = f32[] prim::Constant(), xla_shape=f32[], value=1
%97 = f32[4,128,28,28]{3,2,1,0} aten::expand(%96), xla_shape=f32[4,128,28,28]{3,2,1,0}, size=(4, 128, 28, 28)
%98 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%99 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%100 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%101 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%102 = (f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}) aten::native_batch_norm(%75, %101, %100, %99, %98), num_outputs=4, xla_shape=(f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}), training=0, eps=1e-05
%103 = f32[4,128,28,28]{3,2,1,0} aten::mul(%102.0, %97), xla_shape=f32[4,128,28,28]{3,2,1,0}
%104 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%105 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%106 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%107 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%108 = (f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}) aten::native_batch_norm(%91, %107, %106, %105, %104), num_outputs=4, xla_shape=(f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}), training=0, eps=1e-05
%109 = f32[4,128,28,28]{3,2,1,0} aten::add(%108.0, %103), xla_shape=f32[4,128,28,28]{3,2,1,0}
%110 = f32[4,128,28,28]{3,2,1,0} aten::relu(%109), xla_shape=f32[4,128,28,28]{3,2,1,0}, ROOT=32
%111 = f32[128,128,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[128,128,3,3]{1,0,3,2}, device=TPU:0
%112 = f32[4,128,28,28]{3,2,1,0} aten::convolution_overrideable(%110, %111), xla_shape=f32[4,128,28,28]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=33
%113 = f32[] prim::Constant(), xla_shape=f32[], value=0
%114 = f32[0]{0} aten::expand(%113), xla_shape=f32[0]{0}, size=(0), ROOT=34
%115 = f32[] prim::Constant(), xla_shape=f32[], value=0
%116 = f32[0]{0} aten::expand(%115), xla_shape=f32[0]{0}, size=(0), ROOT=35
%117 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%118 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%119 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%120 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%121 = (f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}) aten::native_batch_norm(%112, %120, %119, %118, %117), num_outputs=4, xla_shape=(f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}), training=0, eps=1e-05
%122 = f32[4,128,28,28]{3,2,1,0} aten::relu(%121.0), xla_shape=f32[4,128,28,28]{3,2,1,0}, ROOT=36
%123 = f32[128,128,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[128,128,3,3]{1,0,3,2}, device=TPU:0
%124 = f32[4,128,28,28]{3,2,1,0} aten::convolution_overrideable(%122, %123), xla_shape=f32[4,128,28,28]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=37
%125 = f32[] prim::Constant(), xla_shape=f32[], value=0
%126 = f32[0]{0} aten::expand(%125), xla_shape=f32[0]{0}, size=(0), ROOT=38
%127 = f32[] prim::Constant(), xla_shape=f32[], value=0
%128 = f32[0]{0} aten::expand(%127), xla_shape=f32[0]{0}, size=(0), ROOT=39
%129 = f32[] prim::Constant(), xla_shape=f32[], value=1
%130 = f32[4,128,28,28]{3,2,1,0} aten::expand(%129), xla_shape=f32[4,128,28,28]{3,2,1,0}, size=(4, 128, 28, 28)
%131 = f32[4,128,28,28]{3,2,1,0} aten::mul(%110, %130), xla_shape=f32[4,128,28,28]{3,2,1,0}
%132 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%133 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%134 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%135 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%136 = (f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}) aten::native_batch_norm(%124, %135, %134, %133, %132), num_outputs=4, xla_shape=(f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}), training=0, eps=1e-05
%137 = f32[4,128,28,28]{3,2,1,0} aten::add(%136.0, %131), xla_shape=f32[4,128,28,28]{3,2,1,0}
%138 = f32[4,128,28,28]{3,2,1,0} aten::relu(%137), xla_shape=f32[4,128,28,28]{3,2,1,0}, ROOT=40
%139 = f32[256,128,3,3]{0,1,3,2} xla::device_data(), xla_shape=f32[256,128,3,3]{0,1,3,2}, device=TPU:0
%140 = f32[4,256,14,14]{3,2,1,0} aten::convolution_overrideable(%138, %139), xla_shape=f32[4,256,14,14]{3,2,1,0}, stride=(2, 2), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=41
%141 = f32[256,128,1,1]{0,1,3,2} xla::device_data(), xla_shape=f32[256,128,1,1]{0,1,3,2}, device=TPU:0
%142 = f32[4,256,14,14]{3,2,1,0} aten::convolution_overrideable(%138, %141), xla_shape=f32[4,256,14,14]{3,2,1,0}, stride=(2, 2), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=42
%143 = f32[] prim::Constant(), xla_shape=f32[], value=0
%144 = f32[0]{0} aten::expand(%143), xla_shape=f32[0]{0}, size=(0), ROOT=43
%145 = f32[] prim::Constant(), xla_shape=f32[], value=0
%146 = f32[0]{0} aten::expand(%145), xla_shape=f32[0]{0}, size=(0), ROOT=44
%147 = f32[] prim::Constant(), xla_shape=f32[], value=0
%148 = f32[0]{0} aten::expand(%147), xla_shape=f32[0]{0}, size=(0), ROOT=45
%149 = f32[] prim::Constant(), xla_shape=f32[], value=0
%150 = f32[0]{0} aten::expand(%149), xla_shape=f32[0]{0}, size=(0), ROOT=46
%151 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%152 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%153 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%154 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%155 = (f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}) aten::native_batch_norm(%140, %154, %153, %152, %151), num_outputs=4, xla_shape=(f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}), training=0, eps=1e-05
%156 = f32[4,256,14,14]{3,2,1,0} aten::relu(%155.0), xla_shape=f32[4,256,14,14]{3,2,1,0}, ROOT=47
%157 = f32[256,256,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[256,256,3,3]{1,0,3,2}, device=TPU:0
%158 = f32[4,256,14,14]{3,2,1,0} aten::convolution_overrideable(%156, %157), xla_shape=f32[4,256,14,14]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=48
%159 = f32[] prim::Constant(), xla_shape=f32[], value=0
%160 = f32[0]{0} aten::expand(%159), xla_shape=f32[0]{0}, size=(0), ROOT=49
%161 = f32[] prim::Constant(), xla_shape=f32[], value=0
%162 = f32[0]{0} aten::expand(%161), xla_shape=f32[0]{0}, size=(0), ROOT=50
%163 = f32[] prim::Constant(), xla_shape=f32[], value=1
%164 = f32[4,256,14,14]{3,2,1,0} aten::expand(%163), xla_shape=f32[4,256,14,14]{3,2,1,0}, size=(4, 256, 14, 14)
%165 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%166 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%167 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%168 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%169 = (f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}) aten::native_batch_norm(%142, %168, %167, %166, %165), num_outputs=4, xla_shape=(f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}), training=0, eps=1e-05
%170 = f32[4,256,14,14]{3,2,1,0} aten::mul(%169.0, %164), xla_shape=f32[4,256,14,14]{3,2,1,0}
%171 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%172 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%173 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%174 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%175 = (f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}) aten::native_batch_norm(%158, %174, %173, %172, %171), num_outputs=4, xla_shape=(f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}), training=0, eps=1e-05
%176 = f32[4,256,14,14]{3,2,1,0} aten::add(%175.0, %170), xla_shape=f32[4,256,14,14]{3,2,1,0}
%177 = f32[4,256,14,14]{3,2,1,0} aten::relu(%176), xla_shape=f32[4,256,14,14]{3,2,1,0}, ROOT=51
%178 = f32[256,256,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[256,256,3,3]{1,0,3,2}, device=TPU:0
%179 = f32[4,256,14,14]{3,2,1,0} aten::convolution_overrideable(%177, %178), xla_shape=f32[4,256,14,14]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=52
%180 = f32[] prim::Constant(), xla_shape=f32[], value=0
%181 = f32[0]{0} aten::expand(%180), xla_shape=f32[0]{0}, size=(0), ROOT=53
%182 = f32[] prim::Constant(), xla_shape=f32[], value=0
%183 = f32[0]{0} aten::expand(%182), xla_shape=f32[0]{0}, size=(0), ROOT=54
%184 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%185 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%186 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%187 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%188 = (f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}) aten::native_batch_norm(%179, %187, %186, %185, %184), num_outputs=4, xla_shape=(f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}), training=0, eps=1e-05
%189 = f32[4,256,14,14]{3,2,1,0} aten::relu(%188.0), xla_shape=f32[4,256,14,14]{3,2,1,0}, ROOT=55
%190 = f32[256,256,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[256,256,3,3]{1,0,3,2}, device=TPU:0
%191 = f32[4,256,14,14]{3,2,1,0} aten::convolution_overrideable(%189, %190), xla_shape=f32[4,256,14,14]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=56
%192 = f32[] prim::Constant(), xla_shape=f32[], value=0
%193 = f32[0]{0} aten::expand(%192), xla_shape=f32[0]{0}, size=(0), ROOT=57
%194 = f32[] prim::Constant(), xla_shape=f32[], value=0
%195 = f32[0]{0} aten::expand(%194), xla_shape=f32[0]{0}, size=(0), ROOT=58
%196 = f32[] prim::Constant(), xla_shape=f32[], value=1
%197 = f32[4,256,14,14]{3,2,1,0} aten::expand(%196), xla_shape=f32[4,256,14,14]{3,2,1,0}, size=(4, 256, 14, 14)
%198 = f32[4,256,14,14]{3,2,1,0} aten::mul(%177, %197), xla_shape=f32[4,256,14,14]{3,2,1,0}
%199 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%200 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%201 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%202 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%203 = (f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}) aten::native_batch_norm(%191, %202, %201, %200, %199), num_outputs=4, xla_shape=(f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}), training=0, eps=1e-05
%204 = f32[4,256,14,14]{3,2,1,0} aten::add(%203.0, %198), xla_shape=f32[4,256,14,14]{3,2,1,0}
%205 = f32[4,256,14,14]{3,2,1,0} aten::relu(%204), xla_shape=f32[4,256,14,14]{3,2,1,0}, ROOT=59
%206 = f32[512,256,3,3]{0,1,3,2} xla::device_data(), xla_shape=f32[512,256,3,3]{0,1,3,2}, device=TPU:0
%207 = f32[4,512,7,7]{3,2,1,0} aten::convolution_overrideable(%205, %206), xla_shape=f32[4,512,7,7]{3,2,1,0}, stride=(2, 2), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=60
%208 = f32[512,256,1,1]{0,1,3,2} xla::device_data(), xla_shape=f32[512,256,1,1]{0,1,3,2}, device=TPU:0
%209 = f32[4,512,7,7]{3,2,1,0} aten::convolution_overrideable(%205, %208), xla_shape=f32[4,512,7,7]{3,2,1,0}, stride=(2, 2), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=61
%210 = f32[] prim::Constant(), xla_shape=f32[], value=0
%211 = f32[0]{0} aten::expand(%210), xla_shape=f32[0]{0}, size=(0), ROOT=62
%212 = f32[] prim::Constant(), xla_shape=f32[], value=0
%213 = f32[0]{0} aten::expand(%212), xla_shape=f32[0]{0}, size=(0), ROOT=63
%214 = f32[] prim::Constant(), xla_shape=f32[], value=0
%215 = f32[0]{0} aten::expand(%214), xla_shape=f32[0]{0}, size=(0), ROOT=64
%216 = f32[] prim::Constant(), xla_shape=f32[], value=0
%217 = f32[0]{0} aten::expand(%216), xla_shape=f32[0]{0}, size=(0), ROOT=65
%218 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%219 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%220 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%221 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%222 = (f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}) aten::native_batch_norm(%207, %221, %220, %219, %218), num_outputs=4, xla_shape=(f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}), training=0, eps=1e-05
%223 = f32[4,512,7,7]{3,2,1,0} aten::relu(%222.0), xla_shape=f32[4,512,7,7]{3,2,1,0}, ROOT=66
%224 = f32[512,512,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[512,512,3,3]{1,0,3,2}, device=TPU:0
%225 = f32[4,512,7,7]{3,2,1,0} aten::convolution_overrideable(%223, %224), xla_shape=f32[4,512,7,7]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=67
%226 = f32[] prim::Constant(), xla_shape=f32[], value=0
%227 = f32[0]{0} aten::expand(%226), xla_shape=f32[0]{0}, size=(0), ROOT=68
%228 = f32[] prim::Constant(), xla_shape=f32[], value=0
%229 = f32[0]{0} aten::expand(%228), xla_shape=f32[0]{0}, size=(0), ROOT=69
%230 = f32[] prim::Constant(), xla_shape=f32[], value=1
%231 = f32[4,512,7,7]{3,2,1,0} aten::expand(%230), xla_shape=f32[4,512,7,7]{3,2,1,0}, size=(4, 512, 7, 7)
%232 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%233 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%234 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%235 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%236 = (f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}) aten::native_batch_norm(%209, %235, %234, %233, %232), num_outputs=4, xla_shape=(f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}), training=0, eps=1e-05
%237 = f32[4,512,7,7]{3,2,1,0} aten::mul(%236.0, %231), xla_shape=f32[4,512,7,7]{3,2,1,0}
%238 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%239 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%240 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%241 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%242 = (f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}) aten::native_batch_norm(%225, %241, %240, %239, %238), num_outputs=4, xla_shape=(f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}), training=0, eps=1e-05
%243 = f32[4,512,7,7]{3,2,1,0} aten::add(%242.0, %237), xla_shape=f32[4,512,7,7]{3,2,1,0}
%244 = f32[4,512,7,7]{3,2,1,0} aten::relu(%243), xla_shape=f32[4,512,7,7]{3,2,1,0}, ROOT=70
%245 = f32[512,512,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[512,512,3,3]{1,0,3,2}, device=TPU:0
%246 = f32[4,512,7,7]{3,2,1,0} aten::convolution_overrideable(%244, %245), xla_shape=f32[4,512,7,7]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=71
%247 = f32[] prim::Constant(), xla_shape=f32[], value=0
%248 = f32[0]{0} aten::expand(%247), xla_shape=f32[0]{0}, size=(0), ROOT=72
%249 = f32[] prim::Constant(), xla_shape=f32[], value=0
%250 = f32[0]{0} aten::expand(%249), xla_shape=f32[0]{0}, size=(0), ROOT=73
%251 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%252 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%253 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%254 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%255 = (f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}) aten::native_batch_norm(%246, %254, %253, %252, %251), num_outputs=4, xla_shape=(f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}), training=0, eps=1e-05
%256 = f32[4,512,7,7]{3,2,1,0} aten::relu(%255.0), xla_shape=f32[4,512,7,7]{3,2,1,0}, ROOT=74
%257 = f32[512,512,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[512,512,3,3]{1,0,3,2}, device=TPU:0
%258 = f32[4,512,7,7]{3,2,1,0} aten::convolution_overrideable(%256, %257), xla_shape=f32[4,512,7,7]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1, ROOT=75
%259 = f32[] prim::Constant(), xla_shape=f32[], value=0
%260 = f32[0]{0} aten::expand(%259), xla_shape=f32[0]{0}, size=(0), ROOT=76
%261 = f32[] prim::Constant(), xla_shape=f32[], value=0
%262 = f32[0]{0} aten::expand(%261), xla_shape=f32[0]{0}, size=(0), ROOT=77
%263 = f32[] prim::Constant(), xla_shape=f32[], value=1
%264 = f32[4,512,7,7]{3,2,1,0} aten::expand(%263), xla_shape=f32[4,512,7,7]{3,2,1,0}, size=(4, 512, 7, 7)
%265 = f32[4,512,7,7]{3,2,1,0} aten::mul(%244, %264), xla_shape=f32[4,512,7,7]{3,2,1,0}
%266 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%267 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%268 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%269 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%270 = (f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}) aten::native_batch_norm(%258, %269, %268, %267, %266), num_outputs=4, xla_shape=(f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}), training=0, eps=1e-05
%271 = f32[4,512,7,7]{3,2,1,0} aten::add(%270.0, %265), xla_shape=f32[4,512,7,7]{3,2,1,0}
%272 = f32[4,512,7,7]{3,2,1,0} aten::relu(%271), xla_shape=f32[4,512,7,7]{3,2,1,0}, ROOT=78
%273 = f32[4,512,1,1]{3,2,1,0} aten::mean(%272), xla_shape=f32[4,512,1,1]{3,2,1,0}, dimensions=(3, 2), keep_reduced_dimensions=1, dtype=-1
%274 = f32[4,512]{1,0} aten::view(%273), xla_shape=f32[4,512]{1,0}, output_size=(4, 512), ROOT=79
%275 = f32[1000]{0} xla::device_data(), xla_shape=f32[1000]{0}, device=TPU:0
%276 = f32[4,1000]{1,0} aten::addmm(%274, %4, %275), xla_shape=f32[4,1000]{1,0}, ROOT=80
}
non-aot
IR {
%0 = f32[1000]{0} xla::device_data(), xla_shape=f32[1000]{0}, device=TPU:0
%1 = f32[1000,512]{1,0} xla::device_data(), xla_shape=f32[1000,512]{1,0}, device=TPU:0
%2 = f32[512,1000]{0,1} aten::permute(%1), xla_shape=f32[512,1000]{0,1}, dims=(1, 0)
%3 = f32[] prim::Constant(), xla_shape=f32[], value=1
%4 = f32[4,512,7,7]{3,2,1,0} aten::expand(%3), xla_shape=f32[4,512,7,7]{3,2,1,0}, size=(4, 512, 7, 7)
%5 = f32[] prim::Constant(), xla_shape=f32[], value=1
%6 = f32[4,512,7,7]{3,2,1,0} aten::expand(%5), xla_shape=f32[4,512,7,7]{3,2,1,0}, size=(4, 512, 7, 7)
%7 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%8 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%9 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%10 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%11 = f32[512,256,1,1]{0,1,3,2} xla::device_data(), xla_shape=f32[512,256,1,1]{0,1,3,2}, device=TPU:0
%12 = f32[] prim::Constant(), xla_shape=f32[], value=1
%13 = f32[4,256,14,14]{3,2,1,0} aten::expand(%12), xla_shape=f32[4,256,14,14]{3,2,1,0}, size=(4, 256, 14, 14)
%14 = f32[] prim::Constant(), xla_shape=f32[], value=1
%15 = f32[4,256,14,14]{3,2,1,0} aten::expand(%14), xla_shape=f32[4,256,14,14]{3,2,1,0}, size=(4, 256, 14, 14)
%16 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%17 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%18 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%19 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%20 = f32[256,128,1,1]{0,1,3,2} xla::device_data(), xla_shape=f32[256,128,1,1]{0,1,3,2}, device=TPU:0
%21 = f32[] prim::Constant(), xla_shape=f32[], value=1
%22 = f32[4,128,28,28]{3,2,1,0} aten::expand(%21), xla_shape=f32[4,128,28,28]{3,2,1,0}, size=(4, 128, 28, 28)
%23 = f32[] prim::Constant(), xla_shape=f32[], value=1
%24 = f32[4,128,28,28]{3,2,1,0} aten::expand(%23), xla_shape=f32[4,128,28,28]{3,2,1,0}, size=(4, 128, 28, 28)
%25 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%26 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%27 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%28 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%29 = f32[128,64,1,1]{0,1,3,2} xla::device_data(), xla_shape=f32[128,64,1,1]{0,1,3,2}, device=TPU:0
%30 = f32[] prim::Constant(), xla_shape=f32[], value=1
%31 = f32[4,64,56,56]{3,2,1,0} aten::expand(%30), xla_shape=f32[4,64,56,56]{3,2,1,0}, size=(4, 64, 56, 56)
%32 = f32[] prim::Constant(), xla_shape=f32[], value=1
%33 = f32[4,64,56,56]{3,2,1,0} aten::expand(%32), xla_shape=f32[4,64,56,56]{3,2,1,0}, size=(4, 64, 56, 56)
%34 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%35 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%36 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%37 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%38 = f32[64,3,7,7]{0,3,2,1} xla::device_data(), xla_shape=f32[64,3,7,7]{0,3,2,1}, device=TPU:0
%39 = f32[4,3,224,224]{3,2,1,0} xla::device_data(), xla_shape=f32[4,3,224,224]{3,2,1,0}, device=TPU:0
%40 = f32[4,64,112,112]{3,2,1,0} aten::convolution_overrideable(%39, %38), xla_shape=f32[4,64,112,112]{3,2,1,0}, stride=(2, 2), padding=(3, 3), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
%41 = (f32[4,64,112,112]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}) aten::native_batch_norm(%40, %37, %36, %35, %34), num_outputs=4, xla_shape=(f32[4,64,112,112]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}), training=0, eps=1e-05
%42 = f32[4,64,112,112]{3,2,1,0} aten::relu(%41.0), xla_shape=f32[4,64,112,112]{3,2,1,0}
%43 = (f32[4,64,56,56]{3,2,1,0}, u32[4,64,56,56]{3,2,1,0}) aten::max_pool2d(%42), num_outputs=2, xla_shape=(f32[4,64,56,56]{3,2,1,0}, u32[4,64,56,56]{3,2,1,0}), spatial_dim_count=2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), ceil_mode=0
%44 = f32[4,64,56,56]{3,2,1,0} aten::mul(%43.0, %33), xla_shape=f32[4,64,56,56]{3,2,1,0}
%45 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%46 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%47 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%48 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%49 = f32[64,64,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[64,64,3,3]{1,0,3,2}, device=TPU:0
%50 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%51 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%52 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%53 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%54 = f32[64,64,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[64,64,3,3]{1,0,3,2}, device=TPU:0
%55 = f32[4,64,56,56]{3,2,1,0} aten::convolution_overrideable(%43.0, %54), xla_shape=f32[4,64,56,56]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
%56 = (f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}) aten::native_batch_norm(%55, %53, %52, %51, %50), num_outputs=4, xla_shape=(f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}), training=0, eps=1e-05
%57 = f32[4,64,56,56]{3,2,1,0} aten::relu(%56.0), xla_shape=f32[4,64,56,56]{3,2,1,0}
%58 = f32[4,64,56,56]{3,2,1,0} aten::convolution_overrideable(%57, %49), xla_shape=f32[4,64,56,56]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
%59 = (f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}) aten::native_batch_norm(%58, %48, %47, %46, %45), num_outputs=4, xla_shape=(f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}), training=0, eps=1e-05
%60 = f32[4,64,56,56]{3,2,1,0} aten::add(%59.0, %44), xla_shape=f32[4,64,56,56]{3,2,1,0}
%61 = f32[4,64,56,56]{3,2,1,0} aten::relu(%60), xla_shape=f32[4,64,56,56]{3,2,1,0}
%62 = f32[4,64,56,56]{3,2,1,0} aten::mul(%61, %31), xla_shape=f32[4,64,56,56]{3,2,1,0}
%63 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%64 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%65 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%66 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%67 = f32[64,64,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[64,64,3,3]{1,0,3,2}, device=TPU:0
%68 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%69 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%70 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%71 = f32[64]{0} xla::device_data(), xla_shape=f32[64]{0}, device=TPU:0
%72 = f32[64,64,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[64,64,3,3]{1,0,3,2}, device=TPU:0
%73 = f32[4,64,56,56]{3,2,1,0} aten::convolution_overrideable(%61, %72), xla_shape=f32[4,64,56,56]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
%74 = (f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}) aten::native_batch_norm(%73, %71, %70, %69, %68), num_outputs=4, xla_shape=(f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}), training=0, eps=1e-05
%75 = f32[4,64,56,56]{3,2,1,0} aten::relu(%74.0), xla_shape=f32[4,64,56,56]{3,2,1,0}
%76 = f32[4,64,56,56]{3,2,1,0} aten::convolution_overrideable(%75, %67), xla_shape=f32[4,64,56,56]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
%77 = (f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}) aten::native_batch_norm(%76, %66, %65, %64, %63), num_outputs=4, xla_shape=(f32[4,64,56,56]{3,2,1,0}, f32[64]{0}, f32[64]{0}, f32[64]{0}), training=0, eps=1e-05
%78 = f32[4,64,56,56]{3,2,1,0} aten::add(%77.0, %62), xla_shape=f32[4,64,56,56]{3,2,1,0}
%79 = f32[4,64,56,56]{3,2,1,0} aten::relu(%78), xla_shape=f32[4,64,56,56]{3,2,1,0}
%80 = f32[4,128,28,28]{3,2,1,0} aten::convolution_overrideable(%79, %29), xla_shape=f32[4,128,28,28]{3,2,1,0}, stride=(2, 2), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
%81 = (f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}) aten::native_batch_norm(%80, %28, %27, %26, %25), num_outputs=4, xla_shape=(f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}), training=0, eps=1e-05
%82 = f32[4,128,28,28]{3,2,1,0} aten::mul(%81.0, %24), xla_shape=f32[4,128,28,28]{3,2,1,0}
%83 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%84 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%85 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%86 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%87 = f32[128,128,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[128,128,3,3]{1,0,3,2}, device=TPU:0
%88 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%89 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%90 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%91 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%92 = f32[128,64,3,3]{0,1,3,2} xla::device_data(), xla_shape=f32[128,64,3,3]{0,1,3,2}, device=TPU:0
%93 = f32[4,128,28,28]{3,2,1,0} aten::convolution_overrideable(%79, %92), xla_shape=f32[4,128,28,28]{3,2,1,0}, stride=(2, 2), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
%94 = (f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}) aten::native_batch_norm(%93, %91, %90, %89, %88), num_outputs=4, xla_shape=(f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}), training=0, eps=1e-05
%95 = f32[4,128,28,28]{3,2,1,0} aten::relu(%94.0), xla_shape=f32[4,128,28,28]{3,2,1,0}
%96 = f32[4,128,28,28]{3,2,1,0} aten::convolution_overrideable(%95, %87), xla_shape=f32[4,128,28,28]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
%97 = (f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}) aten::native_batch_norm(%96, %86, %85, %84, %83), num_outputs=4, xla_shape=(f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}), training=0, eps=1e-05
%98 = f32[4,128,28,28]{3,2,1,0} aten::add(%97.0, %82), xla_shape=f32[4,128,28,28]{3,2,1,0}
%99 = f32[4,128,28,28]{3,2,1,0} aten::relu(%98), xla_shape=f32[4,128,28,28]{3,2,1,0}
%100 = f32[4,128,28,28]{3,2,1,0} aten::mul(%99, %22), xla_shape=f32[4,128,28,28]{3,2,1,0}
%101 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%102 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%103 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%104 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%105 = f32[128,128,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[128,128,3,3]{1,0,3,2}, device=TPU:0
%106 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%107 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%108 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%109 = f32[128]{0} xla::device_data(), xla_shape=f32[128]{0}, device=TPU:0
%110 = f32[128,128,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[128,128,3,3]{1,0,3,2}, device=TPU:0
%111 = f32[4,128,28,28]{3,2,1,0} aten::convolution_overrideable(%99, %110), xla_shape=f32[4,128,28,28]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
%112 = (f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}) aten::native_batch_norm(%111, %109, %108, %107, %106), num_outputs=4, xla_shape=(f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}), training=0, eps=1e-05
%113 = f32[4,128,28,28]{3,2,1,0} aten::relu(%112.0), xla_shape=f32[4,128,28,28]{3,2,1,0}
%114 = f32[4,128,28,28]{3,2,1,0} aten::convolution_overrideable(%113, %105), xla_shape=f32[4,128,28,28]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
%115 = (f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}) aten::native_batch_norm(%114, %104, %103, %102, %101), num_outputs=4, xla_shape=(f32[4,128,28,28]{3,2,1,0}, f32[128]{0}, f32[128]{0}, f32[128]{0}), training=0, eps=1e-05
%116 = f32[4,128,28,28]{3,2,1,0} aten::add(%115.0, %100), xla_shape=f32[4,128,28,28]{3,2,1,0}
%117 = f32[4,128,28,28]{3,2,1,0} aten::relu(%116), xla_shape=f32[4,128,28,28]{3,2,1,0}
%118 = f32[4,256,14,14]{3,2,1,0} aten::convolution_overrideable(%117, %20), xla_shape=f32[4,256,14,14]{3,2,1,0}, stride=(2, 2), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
%119 = (f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}) aten::native_batch_norm(%118, %19, %18, %17, %16), num_outputs=4, xla_shape=(f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}), training=0, eps=1e-05
%120 = f32[4,256,14,14]{3,2,1,0} aten::mul(%119.0, %15), xla_shape=f32[4,256,14,14]{3,2,1,0}
%121 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%122 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%123 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%124 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%125 = f32[256,256,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[256,256,3,3]{1,0,3,2}, device=TPU:0
%126 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%127 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%128 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%129 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%130 = f32[256,128,3,3]{0,1,3,2} xla::device_data(), xla_shape=f32[256,128,3,3]{0,1,3,2}, device=TPU:0
%131 = f32[4,256,14,14]{3,2,1,0} aten::convolution_overrideable(%117, %130), xla_shape=f32[4,256,14,14]{3,2,1,0}, stride=(2, 2), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
%132 = (f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}) aten::native_batch_norm(%131, %129, %128, %127, %126), num_outputs=4, xla_shape=(f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}), training=0, eps=1e-05
%133 = f32[4,256,14,14]{3,2,1,0} aten::relu(%132.0), xla_shape=f32[4,256,14,14]{3,2,1,0}
%134 = f32[4,256,14,14]{3,2,1,0} aten::convolution_overrideable(%133, %125), xla_shape=f32[4,256,14,14]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
%135 = (f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}) aten::native_batch_norm(%134, %124, %123, %122, %121), num_outputs=4, xla_shape=(f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}), training=0, eps=1e-05
%136 = f32[4,256,14,14]{3,2,1,0} aten::add(%135.0, %120), xla_shape=f32[4,256,14,14]{3,2,1,0}
%137 = f32[4,256,14,14]{3,2,1,0} aten::relu(%136), xla_shape=f32[4,256,14,14]{3,2,1,0}
%138 = f32[4,256,14,14]{3,2,1,0} aten::mul(%137, %13), xla_shape=f32[4,256,14,14]{3,2,1,0}
%139 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%140 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%141 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%142 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%143 = f32[256,256,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[256,256,3,3]{1,0,3,2}, device=TPU:0
%144 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%145 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%146 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%147 = f32[256]{0} xla::device_data(), xla_shape=f32[256]{0}, device=TPU:0
%148 = f32[256,256,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[256,256,3,3]{1,0,3,2}, device=TPU:0
%149 = f32[4,256,14,14]{3,2,1,0} aten::convolution_overrideable(%137, %148), xla_shape=f32[4,256,14,14]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
%150 = (f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}) aten::native_batch_norm(%149, %147, %146, %145, %144), num_outputs=4, xla_shape=(f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}), training=0, eps=1e-05
%151 = f32[4,256,14,14]{3,2,1,0} aten::relu(%150.0), xla_shape=f32[4,256,14,14]{3,2,1,0}
%152 = f32[4,256,14,14]{3,2,1,0} aten::convolution_overrideable(%151, %143), xla_shape=f32[4,256,14,14]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
%153 = (f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}) aten::native_batch_norm(%152, %142, %141, %140, %139), num_outputs=4, xla_shape=(f32[4,256,14,14]{3,2,1,0}, f32[256]{0}, f32[256]{0}, f32[256]{0}), training=0, eps=1e-05
%154 = f32[4,256,14,14]{3,2,1,0} aten::add(%153.0, %138), xla_shape=f32[4,256,14,14]{3,2,1,0}
%155 = f32[4,256,14,14]{3,2,1,0} aten::relu(%154), xla_shape=f32[4,256,14,14]{3,2,1,0}
%156 = f32[4,512,7,7]{3,2,1,0} aten::convolution_overrideable(%155, %11), xla_shape=f32[4,512,7,7]{3,2,1,0}, stride=(2, 2), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
%157 = (f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}) aten::native_batch_norm(%156, %10, %9, %8, %7), num_outputs=4, xla_shape=(f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}), training=0, eps=1e-05
%158 = f32[4,512,7,7]{3,2,1,0} aten::mul(%157.0, %6), xla_shape=f32[4,512,7,7]{3,2,1,0}
%159 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%160 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%161 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%162 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%163 = f32[512,512,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[512,512,3,3]{1,0,3,2}, device=TPU:0
%164 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%165 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%166 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%167 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%168 = f32[512,256,3,3]{0,1,3,2} xla::device_data(), xla_shape=f32[512,256,3,3]{0,1,3,2}, device=TPU:0
%169 = f32[4,512,7,7]{3,2,1,0} aten::convolution_overrideable(%155, %168), xla_shape=f32[4,512,7,7]{3,2,1,0}, stride=(2, 2), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
%170 = (f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}) aten::native_batch_norm(%169, %167, %166, %165, %164), num_outputs=4, xla_shape=(f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}), training=0, eps=1e-05
%171 = f32[4,512,7,7]{3,2,1,0} aten::relu(%170.0), xla_shape=f32[4,512,7,7]{3,2,1,0}
%172 = f32[4,512,7,7]{3,2,1,0} aten::convolution_overrideable(%171, %163), xla_shape=f32[4,512,7,7]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
%173 = (f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}) aten::native_batch_norm(%172, %162, %161, %160, %159), num_outputs=4, xla_shape=(f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}), training=0, eps=1e-05
%174 = f32[4,512,7,7]{3,2,1,0} aten::add(%173.0, %158), xla_shape=f32[4,512,7,7]{3,2,1,0}
%175 = f32[4,512,7,7]{3,2,1,0} aten::relu(%174), xla_shape=f32[4,512,7,7]{3,2,1,0}
%176 = f32[4,512,7,7]{3,2,1,0} aten::mul(%175, %4), xla_shape=f32[4,512,7,7]{3,2,1,0}
%177 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%178 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%179 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%180 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%181 = f32[512,512,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[512,512,3,3]{1,0,3,2}, device=TPU:0
%182 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%183 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%184 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%185 = f32[512]{0} xla::device_data(), xla_shape=f32[512]{0}, device=TPU:0
%186 = f32[512,512,3,3]{1,0,3,2} xla::device_data(), xla_shape=f32[512,512,3,3]{1,0,3,2}, device=TPU:0
%187 = f32[4,512,7,7]{3,2,1,0} aten::convolution_overrideable(%175, %186), xla_shape=f32[4,512,7,7]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
%188 = (f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}) aten::native_batch_norm(%187, %185, %184, %183, %182), num_outputs=4, xla_shape=(f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}), training=0, eps=1e-05
%189 = f32[4,512,7,7]{3,2,1,0} aten::relu(%188.0), xla_shape=f32[4,512,7,7]{3,2,1,0}
%190 = f32[4,512,7,7]{3,2,1,0} aten::convolution_overrideable(%189, %181), xla_shape=f32[4,512,7,7]{3,2,1,0}, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
%191 = (f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}) aten::native_batch_norm(%190, %180, %179, %178, %177), num_outputs=4, xla_shape=(f32[4,512,7,7]{3,2,1,0}, f32[512]{0}, f32[512]{0}, f32[512]{0}), training=0, eps=1e-05
%192 = f32[4,512,7,7]{3,2,1,0} aten::add(%191.0, %176), xla_shape=f32[4,512,7,7]{3,2,1,0}
%193 = f32[4,512,7,7]{3,2,1,0} aten::relu(%192), xla_shape=f32[4,512,7,7]{3,2,1,0}
%194 = f32[4,512,1,1]{3,2,1,0} aten::mean(%193), xla_shape=f32[4,512,1,1]{3,2,1,0}, dimensions=(3, 2), keep_reduced_dimensions=1, dtype=-1
%195 = f32[4,512]{1,0} aten::view(%194), xla_shape=f32[4,512]{1,0}, output_size=(4, 512)
%196 = f32[4,1000]{1,0} aten::addmm(%195, %2, %0), xla_shape=f32[4,1000]{1,0}, ROOT=0
}
another thing I noticed is that there are a lot (40 of them) of
%18 = f32[] prim::Constant(), xla_shape=f32[], value=0
%19 = f32[0]{0} aten::expand(%18), xla_shape=f32[0]{0}, size=(0), ROOT=7
%20 = f32[] prim::Constant(), xla_shape=f32[], value=0
%21 = f32[0]{0} aten::expand(%20), xla_shape=f32[0]{0}, size=(0), ROOT=8
in the IR, and they are not used by any other op; each one just expands a constant into a zero-sized tensor.
yea I can confirm that in https://github.com/pytorch/xla/blob/941a4c489add37db9d33d8f9dc3f263a47c3feed/torch_xla/core/dynamo_bridge.py#L240
where we execute the fx graph passed down by dynamo, the aot backend has 81 results while the non-aot backend has 4 results. Not sure what these additional results are that the aot backend requires.
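For anyone reproducing this, one quick way to count the FX graph's outputs at that point in the bridge is to inspect the graph's output node. This is a hypothetical debugging snippet; the GraphModule name xla_model is an assumption, not the actual variable name in dynamo_bridge.py:
# Hypothetical debugging snippet: count how many values the FX graph
# handed down by dynamo returns (81 for the aot backend vs 4 for non-aot).
output_node = next(n for n in xla_model.graph.nodes if n.op == 'output')
print(len(output_node.args[0]))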
hmm @JackCaoG is it possible to figure out which aten ops in the graph the extra xla::clone calls are coming from? 103 seems like a lot. Going off of my batch norm theory, I would have only expected an extra at::copy_() for every running_mean and running_var buffer, and I'm pretty sure there aren't 100 of them.
another thing I noticed is that there are a lot (40 of them) of
%18 = f32[] prim::Constant(), xla_shape=f32[], value=0
%19 = f32[0]{0} aten::expand(%18), xla_shape=f32[0]{0}, size=(0), ROOT=7
%20 = f32[] prim::Constant(), xla_shape=f32[], value=0
%21 = f32[0]{0} aten::expand(%20), xla_shape=f32[0]{0}, size=(0), ROOT=8
hmm... that's creating a bunch of zero-sized tensors. It's not clear to me at all where that would be coming from through AOTAutograd. Is it possible to figure out which aten op LazyTensor dispatched on that eventually desugared into those ops showing up in the graph?
yea, let me do a debug build and try to figure out what happened.
ok I think I know what those tensors are. It seems like when doing the fwd call of resnet18, there is one real output, but aot autograd needs to save a bunch of stuff for the bwd
(Pdb) len(tensors_saved_for_backwards)
161
(Pdb) len(fw_outs)
162
that being said, I still don't know why we return a bunch of expanded values..
@bdhirsh I was able to confirm that the expands of 0-size tensors come from empty_symint and that they are being captured as tensors_saved_for_backwards for aot-autograd.
@JackCaoG AOTAutograd tries to figure out if you're "doing inference", and if it detects that, then it will run an inference code path (that should involve not saving any tensors for backward). The inference check lives here. As long as your code is run under a no_grad(), AOTAutograd should compile an inference-only graph.
Can we test that out in the repro?
Wrapping things in with torch.no_grad(): solved my issue; the number of outputs becomes 4 again and the HLO becomes much shorter. @GleasonK I don't know if there is an easy way to run the model inside no_grad on your end
@bdhirsh what I did was wrapping everything:
with torch.no_grad():
    resnet18 = torchvision.models.resnet18()
    resnet18.eval()
    xla_resnet18 = torchvision.models.resnet18()
    xla_resnet18.load_state_dict(resnet18.state_dict())
    xla_resnet18.to(device)
    xla_resnet18.eval()
    # materialize the fake data for test purpose
    xm.mark_step()
    xm.wait_device_ops()
    met.clear_all()
    for data, _ in loader:
        dynamo_resnet18 = torch.compile(xla_resnet18, backend='openxla')
do we only need to wrap it on execution (dynamo_resnet18 = torch.compile(xla_resnet18, backend='openxla')), the init
resnet18 = torchvision.models.resnet18()
resnet18.eval()
xla_resnet18 = torchvision.models.resnet18()
xla_resnet18.load_state_dict(resnet18.state_dict())
xla_resnet18.to(device)
or both?
I was able to confirm that just
dynamo_resnet18 = torch.compile(xla_resnet18, backend='openxla')
with torch.no_grad():
    output = dynamo_resnet18(data)
works
@GleasonK Let me know if you can verify that the regression is gone, then we can close this issue
Won't be able to verify for a while (OOO). I'm confident this should work though. Happy to close and re-open if I find issues in the future. Thanks for the quick investigation!
Also, I think it's worth more discussion on whether this should be the recommended solution to the problem. It doesn't have to happen in this ticket though
sg, I will close this issue for now.
actually I found that even with torch.no_grad the non-aot backend is still faster.. let me look into why..
I was looking at the llama profile for openxla and openxla_eval, and I found that openxla triggered an additional computation for mark_step. What it looks like is that it is trying to extract the value of an s64 integer.. looking.
looking at the IR.. hmm this might be a llama thing.. it seems to be a bool.
[ScheduleSyncTensorsGraph]
TensorsGraphInfo:
generate (/src/llama/llama/generation.py:328)
decorate_context (/src/pytorch/torch/utils/_contextlib.py:115)
text_completion (/src/llama/llama/generation.py:367)
main (example_text_completion.py:70)
mp_main (example_text_completion.py:124)
_CallAndUpdateTrace (/usr/local/lib/python3.8/site-packages/fire/core.py:691)
_Fire (/usr/local/lib/python3.8/site-packages/fire/core.py:475)
Fire (/usr/local/lib/python3.8/site-packages/fire/core.py:141)
<module> (example_text_completion.py:130)
Hashes: (19aedcce2035cd2efc8e53e8a3a2b99b)
## BEGIN_GRAPH
IR {
%0 = pred[1]{0} xla::device_data(), xla_shape=pred[1]{0}, device=SPMD:0
%1 = pred[1]{0} xla::generic_slice(%0), xla_shape=pred[1]{0}, base_indices=(0), sizes=(1)
%2 = pred[] aten::view(%1), xla_shape=pred[], output_size=(), ROOT=0
}
ok I found the problem.. but I am still trying to figure out what happened here
the issue is caused by
if all(eos_reached):
    break
and eos_reached is a tensor with
IR {
%0 = pred[1]{0} xla::device_data(), xla_shape=pred[1]{0}, device=SPMD:0, ROOT=0
}
(Pdb) eos_reached.size()
torch.Size([1])
(Pdb) eos_reached.dtype
torch.bool
what I didn't expect is that all(eos_reached) actually triggered an execution.. What I don't understand is how using the aot_backend of dynamo has anything to do with this extra execution...
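A minimal sketch of why that call triggers an execution, assuming a standard torch_xla setup: all() needs concrete Python bools, so the pending lazy graph must run before the values can be read on the host.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
eos_reached = torch.zeros(1, dtype=torch.bool, device=device)
# Iterating the tensor and converting each element to a Python bool forces
# torch_xla to schedule and run the pending graph, like in the trace above.
done = all(eos_reached)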
ok nvm, the real issue is from somewhere else
[ScheduleSyncTensorsGraph]
TensorsGraphInfo:
mark_step (/src/pytorch/xla/torch_xla/core/xla_model.py:815)
optimized_mod (/src/pytorch/xla/torch_xla/core/dynamo_bridge.py:315)
forward (<eval_with_key>.11:5)
_call_impl (/src/pytorch/torch/nn/modules/module.py:1527)
_wrapped_call_impl (/src/pytorch/torch/nn/modules/module.py:1518)
__call__ (/src/pytorch/torch/fx/graph_module.py:274)
call_wrapped (/src/pytorch/torch/fx/graph_module.py:678)
fwd (/src/pytorch/torch/_dynamo/backends/torchxla.py:46)
g (/src/pytorch/torch/_functorch/aot_autograd.py:1483)
rng_functionalization_wrapper (/src/pytorch/torch/_functorch/aot_autograd.py:1595)
call_func_with_args (/src/pytorch/torch/_functorch/aot_autograd.py:1507)
runtime_wrapper (/src/pytorch/torch/_functorch/aot_autograd.py:2441)
g (/src/pytorch/torch/_functorch/aot_autograd.py:1483)
forward (/src/pytorch/torch/_functorch/aot_autograd.py:3810)
inner (/src/pytorch/torch/_dynamo/external_utils.py:17)
_fn (/src/pytorch/torch/_dynamo/eval_frame.py:321)
_generate_one_token (/src/llama/llama/generation.py:197)
_fn (/src/pytorch/torch/_dynamo/eval_frame.py:321)
generate (/src/llama/llama/generation.py:319)
decorate_context (/src/pytorch/torch/utils/_contextlib.py:115)
text_completion (/src/llama/llama/generation.py:373)
main (example_text_completion.py:70)
mp_main (example_text_completion.py:124)
_CallAndUpdateTrace (/usr/local/lib/python3.8/site-packages/fire/core.py:691)
_Fire (/usr/local/lib/python3.8/site-packages/fire/core.py:475)
Fire (/usr/local/lib/python3.8/site-packages/fire/core.py:141)
<module> (example_text_completion.py:130)
Hashes: (8cd46c5be00db4f982f612eb9ccdddc3)
## BEGIN_GRAPH
IR {
%0 = s64[] xla::device_data(), xla_shape=s64[], device=SPMD:0
%1 = s64[1]{0} aten::as_strided(%0), xla_shape=s64[1]{0}, size=(1), stride=(1), storage_offset=0, ROOT=0
}
Ah ok found it, the additional execution is caused by
aliased_out = aliased_base_tensor.as_strided(size, stride, storage_offset)
HloModule IrToHlo.4, entry_computation_layout={(s64[])->(s64[1]{0})}
ENTRY %IrToHlo.4 (p0.1: s64[]) -> (s64[1]) {
%p0.1 = s64[] parameter(0), sharding={replicated}, metadata={op_type="xla__device_data" op_name="xla__device_data" source_file="<listcomp>@dynamo_bridge.py" source_line=409}
%reshape.2 = s64[1]{0} reshape(s64[] %p0.1), metadata={op_type="aten__as_strided" op_name="aten__as_strided" source_file="gen_alias_from_base@aot_autograd.py" source_line=669}
ROOT %tuple.3 = (s64[1]{0}) tuple(s64[1]{0} %reshape.2)
}
@bdhirsh these as_strided calls are problematic for xla.. I guess both for inference and training. Do you know what they are for?
The fundamental issue here is that pytorch assumes it can modify tensors between graph executions, but for torch_xla any small modification results in a graph execution (since we can't just in-place update anything on an XLA device, we always execute a new computation and assign the value to the original C++ tensor).
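A minimal sketch of that execution model, assuming the usual torch_xla imports:
import torch
import torch_xla.core.xla_model as xm

t = torch.zeros(4, device=xm.xla_device())
# An "in-place" op on an XLA tensor is recorded as a new node in the lazy IR;
# device memory is never mutated directly.
t.add_(1)
# Executing the pending graph produces a fresh buffer, which is then assigned
# back to the same underlying C++ tensor.
xm.mark_step()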
So, there are a handful of (hopefully rare) places where as_strided() is called in AOTAutograd.
In the limit, ideally we'd probably want to replace every as_strided() call with a corresponding chain of "simple" view ops that can get the same behavior. But this will be a decent chunk of work, and I'm not sure that anyone will have time to try to do this any time soon. In this particular case, it looks like AOTAutograd made a best-effort attempt to re-use autograd's existing view-replay mechanism (code), but wasn't able to (probably because this condition failed?)
@JackCaoG I'm wondering, as an alternative to unblock - if you inspect that as_strided() call and look at the sizes/strides passed in, they're probably pretty simple. Could XLA update their lowering for as_strided() to detect when the size/stride passed in corresponds to a "simple" view op (e.g. the size is the same as the input, and the strides are just contiguous strides), and manually call the view lowering?
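A rough illustration of that check follows. This is a sketch only, not XLA's actual lowering, and the helper name is made up:
def is_simple_view(input_numel, size, stride, storage_offset):
    # A "simple" as_strided: no storage offset, same element count as the
    # input, and row-major contiguous strides for the requested size.
    if storage_offset:
        return False
    expected, acc = [], 1
    for s in reversed(size):
        expected.append(acc)
        acc *= s
    numel = 1
    for s in size:
        numel *= s
    return numel == input_numel and list(stride) == list(reversed(expected))
If this returns True, the lowering could emit a plain reshape/view instead of a general as_strided.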
Thanks Brian. The issue is not necessarily that the as_strided call is complex, but that this is an op not captured in the fx graph (it is called directly on the input tensor). This forces us to execute an additional graph to materialize the input before running the dynamo graph, which slows down inference.
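One way to observe that extra execution is torch_xla's debug metrics. This is a sketch, where dynamo_resnet18 and data refer to the compiled callable and batch from the repro above:
import torch_xla.debug.metrics as met

met.clear_all()
output = dynamo_resnet18(data)
# If an epilogue op such as as_strided ran outside the dynamo graph, the
# counters show an extra graph execution beyond the expected single launch.
print(met.metrics_report())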
Ah I see. Unfortunately, AOTAutograd can't make any guarantees that every op will be part of the FX graph that we send to fw_compiler/bw_compiler - to ensure that all edge cases can be handled properly, AOTAutograd will sometimes generate a "runtime epilogue" that gets run after the compiled graph.
There are some details here, but one example is that if an output aliases an input, we can't just blindly put it in the graph - in order to satisfy the autograd engine, we need to re-run the view outside of the custom autograd.Function object that we generated.
Just brainstorming, but I wonder - If XLA sees that the "second" graph that they trace out only consists of view ops (as it sounds like is the case in this example) - would it be possible to skip all the expensive compilation steps and just immediately return the correct view? Since no real compute is happening if all you have is a bunch of view ops in the graph.
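A sketch of what that check could look like on the bridge side; the op set here is illustrative, not exhaustive:
import torch

# Illustrative set of "no real compute" ops; a real implementation would
# need a much more complete list.
VIEW_OPS = {
    torch.ops.aten.view.default,
    torch.ops.aten.expand.default,
    torch.ops.aten.permute.default,
    torch.ops.aten.as_strided.default,
}

def is_view_only_graph(gm: torch.fx.GraphModule) -> bool:
    return all(
        node.op != 'call_function' or node.target in VIEW_OPS
        for node in gm.graph.nodes
    )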
@JackCaoG @bdhirsh I actually have thought of something that could be a solution for this issue.
After https://github.com/pytorch/pytorch/pull/121007 got merged, we actually have access to the chain of view functions that should be applied to the input. PyTorch currently applies that to the input. As @JackCaoG mentioned, that would still cause graph compilation+execution.
You are probably aware that PyTorch/XLA doesn't actually have the notion of tensor views. It just materializes an entirely new tensor. In order to support views, FunctionalTensorWrapper was created. Currently, we apply the view operations on the input, so as to get an output FunctionalTensorWrapper that has input (or whatever is the base of input) as its base.
Here's my idea: we don't actually need to re-run the view operations again. We have everything we need for creating a new FunctionalTensorWrapper (i.e. XLA tensor) while maintaining the aliasing relation I mentioned above. Let i be our input, out be the concrete output with incorrect aliasing information, and fout be the FunctionalTensorWrapper for the output we got in the functionalization step:
FunctionalTensorWrapper(
base=i.base,
view_metas=i.view_metas + fout.view_metas,
value=out.value,
)
Let me know what you think.
I guess I am not an expert on FunctionalTensorWrapper. The requirement on the pytorch/xla side is that the input data is a DeviceData, not some kind of pending execution (view ops). How is your proposal going to accomplish that? I assume the intention of the logic in aot-autograd is to compute the alias_out, which has a different shape than the alias_base; how can we get the alias_out as data without executing any computation on the alias_base?
Currently, this is the signature of the function (source):
def gen_alias_from_base(aliased_base_tensor, target_meta_tensor, target_requires_grad, target_functional_tensor=None):
- aliased_base_tensor is our input
- target_meta_tensor is the concrete output that came out of OpenXLA graph execution
- target_functional_tensor is the FunctionalTensorWrapper we got in the functionalization step

To answer your questions:
Requirement on pytorch/xla side is that the input data is a DeviceData, not some kind of pending execution (view ops). How is your proposal going to accomplish that?
My proposal comes from the observation that we've already got the output we want. So, no need to execute any view operations.
how can we get the alias_out as data without executing any computation on the alias_base?
Because we have already run the computation. And the result of that is (2). What PyTorch is doing in this function is trying to reconstruct the aliasing relation between the actual output and (1).
Here's an example: suppose we are compiling the following function:
def foo(inp):
    return inp.expand(2, -1)
Why is view reconstruction needed?
Now, assume that inp is a tensor of shape [10]. After going through the dynamo bridge, we would have a new tensor of shape [2, 10]. However, without the view reconstruction, that output would have no aliasing relation with inp. That's because the compiled graph doesn't go through the FunctionalTensorWrapper layer, thus losing the tensor view information. For example, we would expect that an in-place operation on the output would also change inp. Without the view reconstruction procedure, that's not the case.
Since we know that the output is a view of the input, we can just re-run the view operations we recorded in the functionalization step (outside the compiled graph) onto inp. That will give us the correct aliased output.
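For reference, this is the eager-mode aliasing behavior the reconstruction is trying to preserve - plain PyTorch, no XLA involved:
import torch

inp = torch.zeros(10)
out = inp.expand(2, -1)       # out is a view of inp, shape [2, 10]
inp.add_(1)                   # mutating the base...
assert out[0, 0].item() == 1  # ...is visible through the view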
Why don't we need to re-run the view operations?
Because PyTorch/XLA does not have the notion of tensor views. It only has disjoint tensors wrapped with FunctionalTensorWrapper, which emulates aliasing. Therefore, we don't actually need to reconstruct the view. We need to wrap the output we got from the compiled graph into a FunctionalTensorWrapper, so that its base is that of inp. That would be us saying: "I have this completely unrelated tensor (the output), which should actually be a view of inp".
@bdhirsh What do you think about this proposal?
@ysiraichi I think the main thing I don't get from your proposal is that there's another reason we re-run views outside of the graph:
@torch.compile(backend='aot_eager')
def f(x):
    out = x.view(-1)
    return out

x = torch.ones(2, requires_grad=True).clone()
out = f(x)
out.mul_(2)
The above code should work, without raising an error. However, if we were to return the view generated inside of the graph directly back to the user, we would end up with the following:
(1) the compiled graph is wrapped in a giant autograd.Function, and the autograd.Function.forward() invokes the compiled graph and returns its output
(2) The autograd.Function will see that it is returning an output that aliases one of the inputs
(3) Autograd will raise an error if the user later tries to mutate that output. It does this because autograd's handling for aliasing + mutation ordinarily involves view-replay: autograd will regenerate the "mutated" tensor by replaying all of the views in the view chain. But in this case, one of those views came from an autograd.Function, which might have side effects, so autograd bans this behavior. Instead, we regenerate the view outside of the autograd.Function: this way autograd can see that it is a plain view op that creates the view, so it can properly do the same view-replay that would normally run in eager mode.
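Concretely, this is the kind of eager behavior being protected. A sketch: mutating a view that was created inside a custom autograd.Function raises at runtime.
import torch

class F(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.view(-1)

    @staticmethod
    def backward(ctx, grad):
        return grad

x = torch.ones(2, requires_grad=True).clone()
out = F.apply(x)
# autograd cannot view-replay through the custom Function, so this in-place
# mutation raises a RuntimeError in eager PyTorch.
out.mul_(2)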
I guess an alternative is that we don't do the view-replay outside the graph for XLA and we say that we're ok with not supporting this view + mutation case, although it would still feel pretty bad to make the view regeneration logic in AOTAutograd conditional, since that increases the testing surface.
@bdhirsh Oh, I didn't know about that. One question, though: what kind of side effects could mutating an output view tensor have? Is this something that never happens when using AOTAutograd?
@bdhirsh After internal discussion with @JackCaoG, he mentioned that maybe we could do that in contexts where we know we won't use autograd, e.g. no_grad or inference_mode. What do you think?
Oh, I didn't know about that. One question, though: what kind of side effects could mutating an output view tensor have? Is this something that never happens when using AOTAutograd?
@bdhirsh Do you have any thoughts on this?
(2) The autograd.Function will see that it is returning an output that aliases one of the inputs
Will it, though? I'm asking because even though we are returning something computed inside an autograd.Function, it's wrapped in a fresh FunctionalTensorWrapper, created outside the autograd.Function.
What do you think?
@ysiraichi hmm - yeah, I think it would be relatively reasonable to "only replay the view ops outside of the graph in that wrapper" if we are going down the training path of AOTAutograd.
one thing to note - @jamesjwu is doing a lot of refactoring of the runtime wrappers as part of adding a warm cache to AOTAutograd, so there will be some risk of merge conflicts in the next few weeks there.
One question, though: what kind of side effects could mutating an output view tensor have? Is this something that never happens when using AOTAutograd?
It's more just the fact that this is "allowed" today, so if the view op below happens to be compiled, we need it to work:
a = torch.randn(4, requires_grad=True).clone()
a_view = a.view(-1)
print(a_view.grad_fn)
# prints <ViewBackward0 object at 0x7fd2b331be20>
a_view.mul_(2)
print(a_view.grad_fn)
# prints <AsStridedBackward0 object at 0x7fd2b331be20>
🐛 Bug
It looks like torchxla_trace_once is deprecated in favor of openxla, but when I tried to make that migration in some benchmark testing I saw a new warning message and some performance regressions. This was found when running an inference benchmark from openxla-benchmark - ResNet on GPU.

To Reproduce
Colab repro.
Steps to reproduce the behavior:
Hopefully that provides enough information to be useful, if not I am happy to help further.
Expected behavior
On-par performance and HLO graph generation between the two backends (openxla and torchxla_trace_once).

Environment
Additional context
Output traces: save_ir.zip