yuanyao-nv opened 1 month ago
I also tried to narrow down which part of Dynamo/onnxscript might be responsible for this. If I run
import torch
scripted_model = torch.jit.script(model)  # model: the MONAI SegResNet instance
print(scripted_model.graph)
I get this:
graph(%self : __torch__.monai.networks.nets.segresnet.SegResNet,
      %x.1 : Tensor):
  %3 : (Tensor, Tensor[]) = prim::CallMethod[name="encode"](%self, %x.1) # /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:180:20
  %x.5 : Tensor, %down_x.1 : Tensor[] = prim::TupleUnpack(%3)
   = aten::reverse(%down_x.1) # /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:181:8
  %x.9 : Tensor = prim::CallMethod[name="decode"](%self, %x.5, %down_x.1) # /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:183:12
  return (%x.9)
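The TorchScript graph makes the top-level data flow explicit: `encode` returns the bottleneck tensor together with a list of skip features, that list is reversed in place (`aten::reverse`), and `decode` consumes it. As a rough illustration only, here is a pure-Python sketch of that flow that tracks a single spatial dimension. The helper names are hypothetical, and the layer hyperparameters (four encoder levels, kernel 3 / stride 2 / padding 1 for each downsampling conv, 2x upsampling in the decoder) are assumptions read off the exported graph later in this thread, not taken from MONAI's source.

```python
def conv_out(n: int, k: int = 3, s: int = 2, p: int = 1) -> int:
    # Convolution output-size formula with dilation 1.
    return (n + 2 * p - k) // s + 1


def encode(size: int, levels: int = 4) -> tuple[int, list[int]]:
    # Assumption: level 0 keeps the size; every later level starts with a
    # stride-2 conv. Each level's output is saved as a skip feature.
    down_x = []
    for level in range(levels):
        if level > 0:
            size = conv_out(size)
        down_x.append(size)
    return size, down_x


def decode(size: int, down_x: list[int]) -> int:
    # Mirrors `x = up(x) + down_x[i + 1]`: upsample by 2, then add the skip.
    # Real tensors would broadcast size-1 dims; this sketch ignores that.
    for i in range(len(down_x) - 1):
        size *= 2
        if size != down_x[i + 1]:
            raise ValueError(f"skip mismatch at i={i}: {size} vs {down_x[i + 1]}")
    return size


def forward(size: int) -> int:
    x, down_x = encode(size)
    down_x.reverse()  # in-place list reverse, like aten::reverse in the graph
    return decode(x, down_x)


print(encode(224))   # (28, [224, 112, 56, 28])
print(forward(224))  # 224
```

For an input size of 224 the skip sizes come out as 224, 112, 56 and 28, which matches the shapes in the exported graph; an input size such as 100 raises a skip mismatch in this sketch.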
If I run
import torch
import torch._dynamo
from torch.fx.experimental.proxy_tensor import make_fx
gm, _ = torch._dynamo.export(model)(data)
gm = make_fx(torch.func.functionalize(gm))(data)
gm.print_readable()
I get an error:
Traceback (most recent call last):
  File "/ws/dynamo/0501/export_SegResNet.py", line 31, in <module>
    gm, _ = torch._dynamo.export(model)(data)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1282, in inner
    dim_constraints.solve()
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 1772, in solve
    tmp_name = f"_{self._dcp.source_name_to_debug_name[self._dcp.symbol_to_source[s][0].name()]}"
KeyError: "L['x'].size()[4]"
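The key in the `KeyError`, `L['x'].size()[4]`, refers to dimension 4 of the input `x`; in the exported graph further down, that dimension is 128. The sketch below does not explain the error itself (it is raised while looking up a name in `source_name_to_debug_name` inside `dim_constraints.solve()`). It only illustrates what the decoder's skip additions require of any one spatial dimension, assuming three stride-2 downsampling convs (kernel 3, padding 1) and 2x upsampling as read off the exported graph. `skip_adds_match` is a hypothetical helper, and broadcasting of size-1 dims is ignored.

```python
def conv_out(n: int, k: int = 3, s: int = 2, p: int = 1) -> int:
    # Convolution output-size formula with dilation 1.
    return (n + 2 * p - k) // s + 1


def skip_adds_match(n: int, downsamples: int = 3) -> bool:
    # Sizes after level 0 and after each stride-2 downsampling stage.
    sizes = [n]
    for _ in range(downsamples):
        sizes.append(conv_out(sizes[-1]))
    # Decoder: start from the deepest size, upsample by 2, and require an
    # exact match with each skip (size-1 broadcasting is ignored here).
    x = sizes[-1]
    for skip in reversed(sizes[:-1]):
        x *= 2
        if x != skip:
            return False
    return True


print(skip_adds_match(128), skip_adds_match(224))  # True True
print([n for n in range(1, 65) if skip_adds_match(n)])
# [8, 16, 24, 32, 40, 48, 56, 64]
```

In this sketch, with three stride-2 stages the skip additions line up only when the size is a multiple of 8, which both 224 and 128 are.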
Thanks for catching this. Very intriguing. Will take a look!
cc @xiaowuhu @fatcat-z
@yuanyao-nv could you obtain the graph module from torch.export.export and post it here?
@justinchuby Is this what you mean?
import torch
exported_program = torch.export.export(model, (data,))
exported_program.graph_module.print_readable()
which gives
class GraphModule(torch.nn.Module):
def forward(self, p_convinit_conv_weight: "f32[16, 4, 3, 3, 3]", p_getattr_l__self___down_layers_0___1___norm1_weight: "f32[16]", p_getattr_l__self___down_layers_0___1___norm1_bias: "f32[16]", p_getattr_l__self___down_layers_0___1___conv1_conv_weight: "f32[16, 16, 3, 3, 3]", p_getattr_l__self___down_layers_0___1___norm2_weight: "f32[16]", p_getattr_l__self___down_layers_0___1___norm2_bias: "f32[16]", p_getattr_l__self___down_layers_0___1___conv2_conv_weight: "f32[16, 16, 3, 3, 3]", p_down_layers_1_0_conv_weight: "f32[32, 16, 3, 3, 3]", p_getattr_l__self___down_layers_1___1___norm1_weight: "f32[32]", p_getattr_l__self___down_layers_1___1___norm1_bias: "f32[32]", p_getattr_l__self___down_layers_1___1___conv1_conv_weight: "f32[32, 32, 3, 3, 3]", p_getattr_l__self___down_layers_1___1___norm2_weight: "f32[32]", p_getattr_l__self___down_layers_1___1___norm2_bias: "f32[32]", p_getattr_l__self___down_layers_1___1___conv2_conv_weight: "f32[32, 32, 3, 3, 3]", p_getattr_l__self___down_layers_1___2___norm1_weight: "f32[32]", p_getattr_l__self___down_layers_1___2___norm1_bias: "f32[32]", p_getattr_l__self___down_layers_1___2___conv1_conv_weight: "f32[32, 32, 3, 3, 3]", p_getattr_l__self___down_layers_1___2___norm2_weight: "f32[32]", p_getattr_l__self___down_layers_1___2___norm2_bias: "f32[32]", p_getattr_l__self___down_layers_1___2___conv2_conv_weight: "f32[32, 32, 3, 3, 3]", p_down_layers_2_0_conv_weight: "f32[64, 32, 3, 3, 3]", p_getattr_l__self___down_layers_2___1___norm1_weight: "f32[64]", p_getattr_l__self___down_layers_2___1___norm1_bias: "f32[64]", p_getattr_l__self___down_layers_2___1___conv1_conv_weight: "f32[64, 64, 3, 3, 3]", p_getattr_l__self___down_layers_2___1___norm2_weight: "f32[64]", p_getattr_l__self___down_layers_2___1___norm2_bias: "f32[64]", p_getattr_l__self___down_layers_2___1___conv2_conv_weight: "f32[64, 64, 3, 3, 3]", p_getattr_l__self___down_layers_2___2___norm1_weight: "f32[64]", p_getattr_l__self___down_layers_2___2___norm1_bias: "f32[64]", 
p_getattr_l__self___down_layers_2___2___conv1_conv_weight: "f32[64, 64, 3, 3, 3]", p_getattr_l__self___down_layers_2___2___norm2_weight: "f32[64]", p_getattr_l__self___down_layers_2___2___norm2_bias: "f32[64]", p_getattr_l__self___down_layers_2___2___conv2_conv_weight: "f32[64, 64, 3, 3, 3]", p_down_layers_3_0_conv_weight: "f32[128, 64, 3, 3, 3]", p_getattr_l__self___down_layers_3___1___norm1_weight: "f32[128]", p_getattr_l__self___down_layers_3___1___norm1_bias: "f32[128]", p_getattr_l__self___down_layers_3___1___conv1_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___1___norm2_weight: "f32[128]", p_getattr_l__self___down_layers_3___1___norm2_bias: "f32[128]", p_getattr_l__self___down_layers_3___1___conv2_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___2___norm1_weight: "f32[128]", p_getattr_l__self___down_layers_3___2___norm1_bias: "f32[128]", p_getattr_l__self___down_layers_3___2___conv1_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___2___norm2_weight: "f32[128]", p_getattr_l__self___down_layers_3___2___norm2_bias: "f32[128]", p_getattr_l__self___down_layers_3___2___conv2_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___3___norm1_weight: "f32[128]", p_getattr_l__self___down_layers_3___3___norm1_bias: "f32[128]", p_getattr_l__self___down_layers_3___3___conv1_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___3___norm2_weight: "f32[128]", p_getattr_l__self___down_layers_3___3___norm2_bias: "f32[128]", p_getattr_l__self___down_layers_3___3___conv2_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___4___norm1_weight: "f32[128]", p_getattr_l__self___down_layers_3___4___norm1_bias: "f32[128]", p_getattr_l__self___down_layers_3___4___conv1_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___4___norm2_weight: "f32[128]", p_getattr_l__self___down_layers_3___4___norm2_bias: "f32[128]", 
p_getattr_l__self___down_layers_3___4___conv2_conv_weight: "f32[128, 128, 3, 3, 3]", p_up_samples_0_0_conv_weight: "f32[64, 128, 1, 1, 1]", p_getattr_l__self___up_layers_0___0___norm1_weight: "f32[64]", p_getattr_l__self___up_layers_0___0___norm1_bias: "f32[64]", p_getattr_l__self___up_layers_0___0___conv1_conv_weight: "f32[64, 64, 3, 3, 3]", p_getattr_l__self___up_layers_0___0___norm2_weight: "f32[64]", p_getattr_l__self___up_layers_0___0___norm2_bias: "f32[64]", p_getattr_l__self___up_layers_0___0___conv2_conv_weight: "f32[64, 64, 3, 3, 3]", p_up_samples_1_0_conv_weight: "f32[32, 64, 1, 1, 1]", p_getattr_l__self___up_layers_1___0___norm1_weight: "f32[32]", p_getattr_l__self___up_layers_1___0___norm1_bias: "f32[32]", p_getattr_l__self___up_layers_1___0___conv1_conv_weight: "f32[32, 32, 3, 3, 3]", p_getattr_l__self___up_layers_1___0___norm2_weight: "f32[32]", p_getattr_l__self___up_layers_1___0___norm2_bias: "f32[32]", p_getattr_l__self___up_layers_1___0___conv2_conv_weight: "f32[32, 32, 3, 3, 3]", p_up_samples_2_0_conv_weight: "f32[16, 32, 1, 1, 1]", p_getattr_l__self___up_layers_2___0___norm1_weight: "f32[16]", p_getattr_l__self___up_layers_2___0___norm1_bias: "f32[16]", p_getattr_l__self___up_layers_2___0___conv1_conv_weight: "f32[16, 16, 3, 3, 3]", p_getattr_l__self___up_layers_2___0___norm2_weight: "f32[16]", p_getattr_l__self___up_layers_2___0___norm2_bias: "f32[16]", p_getattr_l__self___up_layers_2___0___conv2_conv_weight: "f32[16, 16, 3, 3, 3]", p_conv_final_0_weight: "f32[16]", p_conv_final_0_bias: "f32[16]", p_conv_final_2_conv_weight: "f32[3, 16, 1, 1, 1]", p_conv_final_2_conv_bias: "f32[3]", x: "f32[1, 4, 224, 224, 128]"):
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:157 in encode, code: x = self.convInit(x)
conv3d: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.conv3d.default(x, p_convinit_conv_weight, None, [1, 1, 1], [1, 1, 1]); x = p_convinit_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:159 in encode, code: x = self.dropout(x)
feature_dropout: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.feature_dropout.default(conv3d, 0.2, False); conv3d = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
group_norm: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.group_norm.default(feature_dropout, 8, p_getattr_l__self___down_layers_0___1___norm1_weight, p_getattr_l__self___down_layers_0___1___norm1_bias); p_getattr_l__self___down_layers_0___1___norm1_weight = p_getattr_l__self___down_layers_0___1___norm1_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
relu: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.relu.default(group_norm); group_norm = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
conv3d_1: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.conv3d.default(relu, p_getattr_l__self___down_layers_0___1___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu = p_getattr_l__self___down_layers_0___1___conv1_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
group_norm_1: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.group_norm.default(conv3d_1, 8, p_getattr_l__self___down_layers_0___1___norm2_weight, p_getattr_l__self___down_layers_0___1___norm2_bias); conv3d_1 = p_getattr_l__self___down_layers_0___1___norm2_weight = p_getattr_l__self___down_layers_0___1___norm2_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
relu_1: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.relu.default(group_norm_1); group_norm_1 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
conv3d_2: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.conv3d.default(relu_1, p_getattr_l__self___down_layers_0___1___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_1 = p_getattr_l__self___down_layers_0___1___conv2_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
add: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.add.Tensor(conv3d_2, feature_dropout); conv3d_2 = feature_dropout = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:164 in encode, code: x = down(x)
conv3d_3: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(add, p_down_layers_1_0_conv_weight, None, [2, 2, 2], [1, 1, 1]); p_down_layers_1_0_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
group_norm_2: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(conv3d_3, 8, p_getattr_l__self___down_layers_1___1___norm1_weight, p_getattr_l__self___down_layers_1___1___norm1_bias); p_getattr_l__self___down_layers_1___1___norm1_weight = p_getattr_l__self___down_layers_1___1___norm1_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
relu_2: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_2); group_norm_2 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
conv3d_4: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_2, p_getattr_l__self___down_layers_1___1___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_2 = p_getattr_l__self___down_layers_1___1___conv1_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
group_norm_3: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(conv3d_4, 8, p_getattr_l__self___down_layers_1___1___norm2_weight, p_getattr_l__self___down_layers_1___1___norm2_bias); conv3d_4 = p_getattr_l__self___down_layers_1___1___norm2_weight = p_getattr_l__self___down_layers_1___1___norm2_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
relu_3: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_3); group_norm_3 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
conv3d_5: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_3, p_getattr_l__self___down_layers_1___1___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_3 = p_getattr_l__self___down_layers_1___1___conv2_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
add_1: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.add.Tensor(conv3d_5, conv3d_3); conv3d_5 = conv3d_3 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
group_norm_4: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(add_1, 8, p_getattr_l__self___down_layers_1___2___norm1_weight, p_getattr_l__self___down_layers_1___2___norm1_bias); p_getattr_l__self___down_layers_1___2___norm1_weight = p_getattr_l__self___down_layers_1___2___norm1_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
relu_4: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_4); group_norm_4 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
conv3d_6: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_4, p_getattr_l__self___down_layers_1___2___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_4 = p_getattr_l__self___down_layers_1___2___conv1_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
group_norm_5: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(conv3d_6, 8, p_getattr_l__self___down_layers_1___2___norm2_weight, p_getattr_l__self___down_layers_1___2___norm2_bias); conv3d_6 = p_getattr_l__self___down_layers_1___2___norm2_weight = p_getattr_l__self___down_layers_1___2___norm2_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
relu_5: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_5); group_norm_5 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
conv3d_7: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_5, p_getattr_l__self___down_layers_1___2___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_5 = p_getattr_l__self___down_layers_1___2___conv2_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
add_2: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.add.Tensor(conv3d_7, add_1); conv3d_7 = add_1 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:164 in encode, code: x = down(x)
conv3d_8: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(add_2, p_down_layers_2_0_conv_weight, None, [2, 2, 2], [1, 1, 1]); p_down_layers_2_0_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
group_norm_6: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(conv3d_8, 8, p_getattr_l__self___down_layers_2___1___norm1_weight, p_getattr_l__self___down_layers_2___1___norm1_bias); p_getattr_l__self___down_layers_2___1___norm1_weight = p_getattr_l__self___down_layers_2___1___norm1_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
relu_6: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_6); group_norm_6 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
conv3d_9: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_6, p_getattr_l__self___down_layers_2___1___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_6 = p_getattr_l__self___down_layers_2___1___conv1_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
group_norm_7: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(conv3d_9, 8, p_getattr_l__self___down_layers_2___1___norm2_weight, p_getattr_l__self___down_layers_2___1___norm2_bias); conv3d_9 = p_getattr_l__self___down_layers_2___1___norm2_weight = p_getattr_l__self___down_layers_2___1___norm2_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
relu_7: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_7); group_norm_7 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
conv3d_10: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_7, p_getattr_l__self___down_layers_2___1___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_7 = p_getattr_l__self___down_layers_2___1___conv2_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
add_3: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.add.Tensor(conv3d_10, conv3d_8); conv3d_10 = conv3d_8 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
group_norm_8: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(add_3, 8, p_getattr_l__self___down_layers_2___2___norm1_weight, p_getattr_l__self___down_layers_2___2___norm1_bias); p_getattr_l__self___down_layers_2___2___norm1_weight = p_getattr_l__self___down_layers_2___2___norm1_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
relu_8: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_8); group_norm_8 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
conv3d_11: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_8, p_getattr_l__self___down_layers_2___2___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_8 = p_getattr_l__self___down_layers_2___2___conv1_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
group_norm_9: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(conv3d_11, 8, p_getattr_l__self___down_layers_2___2___norm2_weight, p_getattr_l__self___down_layers_2___2___norm2_bias); conv3d_11 = p_getattr_l__self___down_layers_2___2___norm2_weight = p_getattr_l__self___down_layers_2___2___norm2_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
relu_9: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_9); group_norm_9 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
conv3d_12: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_9, p_getattr_l__self___down_layers_2___2___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_9 = p_getattr_l__self___down_layers_2___2___conv2_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
add_4: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.add.Tensor(conv3d_12, add_3); conv3d_12 = add_3 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:164 in encode, code: x = down(x)
conv3d_13: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(add_4, p_down_layers_3_0_conv_weight, None, [2, 2, 2], [1, 1, 1]); p_down_layers_3_0_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
group_norm_10: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(conv3d_13, 8, p_getattr_l__self___down_layers_3___1___norm1_weight, p_getattr_l__self___down_layers_3___1___norm1_bias); p_getattr_l__self___down_layers_3___1___norm1_weight = p_getattr_l__self___down_layers_3___1___norm1_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
relu_10: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_10); group_norm_10 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
conv3d_14: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_10, p_getattr_l__self___down_layers_3___1___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_10 = p_getattr_l__self___down_layers_3___1___conv1_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
group_norm_11: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(conv3d_14, 8, p_getattr_l__self___down_layers_3___1___norm2_weight, p_getattr_l__self___down_layers_3___1___norm2_bias); conv3d_14 = p_getattr_l__self___down_layers_3___1___norm2_weight = p_getattr_l__self___down_layers_3___1___norm2_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
relu_11: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_11); group_norm_11 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
conv3d_15: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_11, p_getattr_l__self___down_layers_3___1___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_11 = p_getattr_l__self___down_layers_3___1___conv2_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
add_5: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.add.Tensor(conv3d_15, conv3d_13); conv3d_15 = conv3d_13 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
group_norm_12: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(add_5, 8, p_getattr_l__self___down_layers_3___2___norm1_weight, p_getattr_l__self___down_layers_3___2___norm1_bias); p_getattr_l__self___down_layers_3___2___norm1_weight = p_getattr_l__self___down_layers_3___2___norm1_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
relu_12: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_12); group_norm_12 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
conv3d_16: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_12, p_getattr_l__self___down_layers_3___2___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_12 = p_getattr_l__self___down_layers_3___2___conv1_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
group_norm_13: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(conv3d_16, 8, p_getattr_l__self___down_layers_3___2___norm2_weight, p_getattr_l__self___down_layers_3___2___norm2_bias); conv3d_16 = p_getattr_l__self___down_layers_3___2___norm2_weight = p_getattr_l__self___down_layers_3___2___norm2_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
relu_13: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_13); group_norm_13 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
conv3d_17: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_13, p_getattr_l__self___down_layers_3___2___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_13 = p_getattr_l__self___down_layers_3___2___conv2_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
add_6: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.add.Tensor(conv3d_17, add_5); conv3d_17 = add_5 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
group_norm_14: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(add_6, 8, p_getattr_l__self___down_layers_3___3___norm1_weight, p_getattr_l__self___down_layers_3___3___norm1_bias); p_getattr_l__self___down_layers_3___3___norm1_weight = p_getattr_l__self___down_layers_3___3___norm1_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
relu_14: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_14); group_norm_14 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
conv3d_18: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_14, p_getattr_l__self___down_layers_3___3___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_14 = p_getattr_l__self___down_layers_3___3___conv1_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
group_norm_15: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(conv3d_18, 8, p_getattr_l__self___down_layers_3___3___norm2_weight, p_getattr_l__self___down_layers_3___3___norm2_bias); conv3d_18 = p_getattr_l__self___down_layers_3___3___norm2_weight = p_getattr_l__self___down_layers_3___3___norm2_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
relu_15: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_15); group_norm_15 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
conv3d_19: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_15, p_getattr_l__self___down_layers_3___3___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_15 = p_getattr_l__self___down_layers_3___3___conv2_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
add_7: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.add.Tensor(conv3d_19, add_6); conv3d_19 = add_6 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
group_norm_16: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(add_7, 8, p_getattr_l__self___down_layers_3___4___norm1_weight, p_getattr_l__self___down_layers_3___4___norm1_bias); p_getattr_l__self___down_layers_3___4___norm1_weight = p_getattr_l__self___down_layers_3___4___norm1_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
relu_16: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_16); group_norm_16 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
conv3d_20: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_16, p_getattr_l__self___down_layers_3___4___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_16 = p_getattr_l__self___down_layers_3___4___conv1_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
group_norm_17: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(conv3d_20, 8, p_getattr_l__self___down_layers_3___4___norm2_weight, p_getattr_l__self___down_layers_3___4___norm2_bias); conv3d_20 = p_getattr_l__self___down_layers_3___4___norm2_weight = p_getattr_l__self___down_layers_3___4___norm2_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
relu_17: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_17); group_norm_17 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
conv3d_21: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_17, p_getattr_l__self___down_layers_3___4___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_17 = p_getattr_l__self___down_layers_3___4___conv2_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
add_8: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.add.Tensor(conv3d_21, add_7); conv3d_21 = add_7 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:171 in decode, code: x = up(x) + down_x[i + 1]
conv3d_22: "f32[1, 64, 28, 28, 16]" = torch.ops.aten.conv3d.default(add_8, p_up_samples_0_0_conv_weight); add_8 = p_up_samples_0_0_conv_weight = None
upsample_trilinear3d: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.upsample_trilinear3d.vec(conv3d_22, None, False, [2.0, 2.0, 2.0]); conv3d_22 = None
add_9: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.add.Tensor(upsample_trilinear3d, add_4); upsample_trilinear3d = add_4 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
group_norm_18: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(add_9, 8, p_getattr_l__self___up_layers_0___0___norm1_weight, p_getattr_l__self___up_layers_0___0___norm1_bias); p_getattr_l__self___up_layers_0___0___norm1_weight = p_getattr_l__self___up_layers_0___0___norm1_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
relu_18: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_18); group_norm_18 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
conv3d_23: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_18, p_getattr_l__self___up_layers_0___0___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_18 = p_getattr_l__self___up_layers_0___0___conv1_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
group_norm_19: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(conv3d_23, 8, p_getattr_l__self___up_layers_0___0___norm2_weight, p_getattr_l__self___up_layers_0___0___norm2_bias); conv3d_23 = p_getattr_l__self___up_layers_0___0___norm2_weight = p_getattr_l__self___up_layers_0___0___norm2_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
relu_19: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_19); group_norm_19 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
conv3d_24: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_19, p_getattr_l__self___up_layers_0___0___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_19 = p_getattr_l__self___up_layers_0___0___conv2_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
add_10: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.add.Tensor(conv3d_24, add_9); conv3d_24 = add_9 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:171 in decode, code: x = up(x) + down_x[i + 1]
conv3d_25: "f32[1, 32, 56, 56, 32]" = torch.ops.aten.conv3d.default(add_10, p_up_samples_1_0_conv_weight); add_10 = p_up_samples_1_0_conv_weight = None
upsample_trilinear3d_1: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.upsample_trilinear3d.vec(conv3d_25, None, False, [2.0, 2.0, 2.0]); conv3d_25 = None
add_11: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.add.Tensor(upsample_trilinear3d_1, add_2); upsample_trilinear3d_1 = add_2 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
group_norm_20: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(add_11, 8, p_getattr_l__self___up_layers_1___0___norm1_weight, p_getattr_l__self___up_layers_1___0___norm1_bias); p_getattr_l__self___up_layers_1___0___norm1_weight = p_getattr_l__self___up_layers_1___0___norm1_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
relu_20: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_20); group_norm_20 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
conv3d_26: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_20, p_getattr_l__self___up_layers_1___0___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_20 = p_getattr_l__self___up_layers_1___0___conv1_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
group_norm_21: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(conv3d_26, 8, p_getattr_l__self___up_layers_1___0___norm2_weight, p_getattr_l__self___up_layers_1___0___norm2_bias); conv3d_26 = p_getattr_l__self___up_layers_1___0___norm2_weight = p_getattr_l__self___up_layers_1___0___norm2_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
relu_21: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_21); group_norm_21 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
conv3d_27: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_21, p_getattr_l__self___up_layers_1___0___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_21 = p_getattr_l__self___up_layers_1___0___conv2_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
add_12: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.add.Tensor(conv3d_27, add_11); conv3d_27 = add_11 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:171 in decode, code: x = up(x) + down_x[i + 1]
conv3d_28: "f32[1, 16, 112, 112, 64]" = torch.ops.aten.conv3d.default(add_12, p_up_samples_2_0_conv_weight); add_12 = p_up_samples_2_0_conv_weight = None
upsample_trilinear3d_2: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.upsample_trilinear3d.vec(conv3d_28, None, False, [2.0, 2.0, 2.0]); conv3d_28 = None
add_13: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.add.Tensor(upsample_trilinear3d_2, add); upsample_trilinear3d_2 = add = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
group_norm_22: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.group_norm.default(add_13, 8, p_getattr_l__self___up_layers_2___0___norm1_weight, p_getattr_l__self___up_layers_2___0___norm1_bias); p_getattr_l__self___up_layers_2___0___norm1_weight = p_getattr_l__self___up_layers_2___0___norm1_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
relu_22: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.relu.default(group_norm_22); group_norm_22 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
conv3d_29: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.conv3d.default(relu_22, p_getattr_l__self___up_layers_2___0___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_22 = p_getattr_l__self___up_layers_2___0___conv1_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
group_norm_23: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.group_norm.default(conv3d_29, 8, p_getattr_l__self___up_layers_2___0___norm2_weight, p_getattr_l__self___up_layers_2___0___norm2_bias); conv3d_29 = p_getattr_l__self___up_layers_2___0___norm2_weight = p_getattr_l__self___up_layers_2___0___norm2_bias = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
relu_23: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.relu.default(group_norm_23); group_norm_23 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
conv3d_30: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.conv3d.default(relu_23, p_getattr_l__self___up_layers_2___0___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]); relu_23 = p_getattr_l__self___up_layers_2___0___conv2_conv_weight = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
add_14: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.add.Tensor(conv3d_30, add_13); conv3d_30 = add_13 = None
# File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:175 in decode, code: x = self.conv_final(x)
group_norm_24: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.group_norm.default(add_14, 8, p_conv_final_0_weight, p_conv_final_0_bias); add_14 = p_conv_final_0_weight = p_conv_final_0_bias = None
relu_24: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.relu.default(group_norm_24); group_norm_24 = None
conv3d_31: "f32[1, 3, 224, 224, 128]" = torch.ops.aten.conv3d.default(relu_24, p_conv_final_2_conv_weight, p_conv_final_2_conv_bias); relu_24 = p_conv_final_2_conv_weight = p_conv_final_2_conv_bias = None
return (conv3d_31,)
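The shapes annotated in the decoder part of this graph follow a simple pattern. Each up_samples stage applies a 1x1x1 convolution with no padding, so only the channel count changes. It then applies trilinear upsampling with scale factors [2.0, 2.0, 2.0]. The pure-Python sketch below recomputes those shapes. It assumes PyTorch's floor(size * scale) rule for scale-factor upsampling, and the helper names are hypothetical:

```python
import math

def conv1x1_shape(shape, out_channels):
    """Output shape of a 1x1x1, stride-1, unpadded Conv3d: only the channel dim changes."""
    n, _, *spatial = shape
    return (n, out_channels, *spatial)

def upsample_shape(shape, scale):
    """Output shape of scale-factor upsampling: each spatial size becomes floor(size * scale)."""
    n, c, *spatial = shape
    return (n, c, *(math.floor(s * scale) for s in spatial))

shape = (1, 64, 56, 56, 32)      # add_10 in the printout, the input to up_samples[1]
trace = []
for out_channels in (32, 16):    # output channels of conv3d_25 and conv3d_28
    shape = upsample_shape(conv1x1_shape(shape, out_channels), 2.0)
    trace.append(shape)
print(trace)  # [(1, 32, 112, 112, 64), (1, 16, 224, 224, 128)]
```

These match upsample_trilinear3d_1 and upsample_trilinear3d_2 above. At this stage, each upsampling is still a single aten op.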
Yes, thank you
That's very strange. If you run torch.onnx.dynamo_export(exported_program, ...), do you get the same graph?
@justinchuby Is this the procedure you're suggesting?
exported_program = torch.export.export(model, (data,))
exported_program._graph_module.print_readable()
export_output = torch.onnx.dynamo_export(exported_program, data)
export_output.save('Clara_SegResNet_dynamo1.onnx')
The exported Upsample module looks about the same as before: it is still a very big graph. In addition, the model weights appear as extra inputs, giving rise to tens of extra model inputs. This is similar to this bug: https://github.com/pytorch/pytorch/issues/126071
I don't see any Resize ops, which is puzzling. Could you share the ONNX model? You may remove the weights if it is too big.
I was expecting to see this function: https://github.com/microsoft/onnxscript/blob/2b6dc27b34f2e4e9fc4c3ad73635c5b157a4c714/onnxscript/function_libs/torch_lib/ops/nn.py#L2215
@justinchuby I uploaded the two versions of the model here: https://drive.google.com/drive/folders/1s1lhKRuG6fOZmD4IjZvN_zlWfIxPB_8w?usp=sharing
It's possible that the upsample op was somehow decomposed by PyTorch. I will look deeper.
I have found that, in the general case, one has to run exported_program.run_decompositions() before applying dynamo_export(). That may in fact fold some operations. @yuanyao-nv, can you try that?
Thanks. We will be creating a series of changes to the exporter to support ExportedPrograms properly, including handling of the weights.
@borisfom I tried running run_decompositions(), but it didn't do anything for this particular subgraph:
exported_program = torch.export.export(model, args=(data,))
exported_program.run_decompositions()
export_output = torch.onnx.dynamo_export(exported_program, data)
export_output.save('Clara_SegResNet_dynamo.onnx')
There are two issues here:
1. upsample_trilinear_vec op: #1592
Hi @yuanyao-nv,
This one should be fixed when you call torch.onnx.dynamo_export with an nn.Module. However, if you call torch.export.export first, the op is decomposed into the big subgraph you had. This decomposition is forced by dynamo for some reason. Feel free to open an issue like https://github.com/pytorch/pytorch/issues/115883.
cc @gramalingam @justinchuby @xadupre This forced decomposition may require us to rewrite these subgraphs as patterns. It will come back to us once we rely on torch.export.export.
@titaiwangms Thanks for the update.
What's a good way to test it out?
If I rerun the export script in the description using the latest torch nightly build (2.5.0.dev20240617+cu121), I actually hit another error:
Traceback (most recent call last):
File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/exporter.py", line 1509, in dynamo_export
).export()
File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/exporter.py", line 1236, in export
graph_module = self.options.fx_tracer.generate_fx(
File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 214, in generate_fx
graph_module, graph_guard = torch._dynamo.export(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1379, in inner
result_traced = opt_f(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 169, in wrapped
return output_adapter.apply(model_func(*args, **kwargs), model=model)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__
return self._torchdynamo_orig_callable(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
return _compile(
File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 84, in wrapper_function
return StrobelightCompileTimeProfiler.profile_compile_time(
File "/usr/local/lib/python3.10/dist-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
return func(*args, **kwargs)
File "/usr/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
r = func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner
out_code = transform_code_object(code, transform)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object
transformations(instructions, code_options)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 582, in transform
tracer.run()
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2462, in run
super().run()
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 904, in run
while self.step():
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 816, in step
self.dispatch_table[inst.opcode](self, inst)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 510, in wrapper
return inner_fn(self, inst)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1470, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 754, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 356, in call_function
return super().call_function(tx, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 298, in call_function
return super().call_function(tx, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 95, in call_function
return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 760, in inline_user_function_return
return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2677, in inline_call
return cls.inline_call_(parent, func, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2793, in inline_call_
tracer.run()
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 904, in run
while self.step():
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 816, in step
self.dispatch_table[inst.opcode](self, inst)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 234, in impl
self.push(fn_var.call_function(self, self.popn(nargs), {}))
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 962, in call_function
return handler(tx, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 941, in _handle_insert_op_in_graph
return wrap_fx_proxy(tx, proxy)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 1759, in wrap_fx_proxy
return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 1846, in wrap_fx_proxy_cls
example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1853, in get_fake_value
raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1785, in get_fake_value
ret_val = wrap_fake_exception(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1300, in wrap_fake_exception
return fn()
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1786, in <lambda>
lambda: run_node(tx.output, node, args, kwargs, nnmodule)
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1921, in run_node
raise RuntimeError(make_error_message(e)).with_traceback(
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1903, in run_node
return node.target(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 21, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1061, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1450, in dispatch
return self._cached_dispatch_impl(func, types, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1145, in _cached_dispatch_impl
output = self._dispatch_impl(func, types, args, kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1757, in _dispatch_impl
r = func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 670, in __call__
return self_._op(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_prims_common/wrappers.py", line 266, in _fn
result = fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/_prims_common/wrappers.py", line 138, in _fn
result = fn(**bound.arguments)
File "/usr/local/lib/python3.10/dist-packages/torch/_refs/__init__.py", line 1080, in add
a, b = _maybe_broadcast(a, b)
File "/usr/local/lib/python3.10/dist-packages/torch/_refs/__init__.py", line 419, in _maybe_broadcast
common_shape = _broadcast_shapes(
File "/usr/local/lib/python3.10/dist-packages/torch/_refs/__init__.py", line 408, in _broadcast_shapes
raise RuntimeError(
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function add>(*(FakeTensor(..., device='cuda:0', size=(1, 64, 28, 56, 56, 32),
grad_fn=<WarnNotImplemented>), FakeTensor(..., device='cuda:0', size=(1, 64, 56, 56, 32),
grad_fn=<AddBackward0>)), **{}):
Attempting to broadcast a dimension of length 64 at -4! Mismatching argument at index 1 had torch.Size([1, 64, 56, 56, 32]); but expected shape should be broadcastable to [1, 64, 28, 56, 56, 32]
from user code:
File "/usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py", line 183, in forward
x = self.decode(x, down_x)
File "/usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py", line 171, in decode
x = up(x) + down_x[i + 1]
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
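The failure itself is an ordinary broadcasting mismatch. The left operand, the fake output of up(x), has six dimensions instead of five. Shapes are aligned from the trailing dimension, so 28 ends up paired with 64 at dimension -4. The pure-Python sketch below illustrates that rule; it is not PyTorch's implementation. It reproduces where the two shapes from the message stop being compatible:

```python
from itertools import zip_longest

def broadcast_shapes(a, b):
    """Broadcast two shapes using NumPy/PyTorch rules (illustrative sketch).

    Shapes are aligned from the trailing dimension; a missing leading dim counts as 1.
    Raises ValueError naming the first incompatible dimension as a negative index.
    """
    result = []
    for i, (x, y) in enumerate(zip_longest(reversed(a), reversed(b), fillvalue=1), start=1):
        if x == y or y == 1:
            result.append(x)
        elif x == 1:
            result.append(y)
        else:
            raise ValueError(f"cannot broadcast {y} against {x} at dim -{i}")
    return tuple(reversed(result))

# Shapes taken from the TorchRuntimeError above: up(x) is 6-D, down_x[i + 1] is 5-D.
try:
    broadcast_shapes((1, 64, 28, 56, 56, 32), (1, 64, 56, 56, 32))
except ValueError as err:
    print(err)  # cannot broadcast 64 against 28 at dim -4
```

With the expected 5-D output of up(x), the two shapes would be identical and the add would succeed.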
Do you know why fake tensors are being used in the latest torch version?
I also tried exporting just an nn.Upsample function:
import torch
def f(x):
    # Linear upsampling of a 3-D (N, C, L) input along L, from 5 to 10 samples
    m = torch.nn.Upsample(size=10, mode='linear')
    return m(x)
x = torch.randn(2, 5, 5)
export_output = torch.onnx.dynamo_export(f, x)
export_output.save('Upsample.onnx')
The exported graph looks reasonable. Is this what you'd expect?
Filed a separate issue to track the above fake tensor broadcast error https://github.com/pytorch/pytorch/issues/129534
I'm looking at some models in MONAI that involve torch.nn.Upsample. I notice that TorchScript exports the Upsample module to a Resize node, but dynamo exports it to a very big graph, which has a performance impact. An example is the SegResNet model.
[Screenshots not preserved: TorchScript export; dynamo export; expanded dynamo subgraph]
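For comparison, a single Resize node in linear mode with a half-pixel coordinate transformation roughly computes the following: each output sample is a weighted average of its two nearest input samples. The pure-Python sketch below shows the 1-D case under PyTorch's align_corners=False convention. It assumes a size-based resize; with an integer scale factor such as 2 that divides evenly, the ratio is the same. Trilinear mode applies the same rule along each of the three spatial axes in turn:

```python
import math

def linear_resize_1d(values, out_len):
    """1-D linear resize, PyTorch align_corners=False convention (illustrative sketch)."""
    in_len = len(values)
    ratio = in_len / out_len
    out = []
    for i in range(out_len):
        # Map the output sample's center back to input coordinates; clamp below at 0.
        src = max((i + 0.5) * ratio - 0.5, 0.0)
        i0 = min(math.floor(src), in_len - 1)
        i1 = min(i0 + 1, in_len - 1)  # repeat the edge sample at the right border
        w = src - i0                  # weight of the right neighbour
        out.append(values[i0] * (1.0 - w) + values[i1] * w)
    return out

print(linear_resize_1d([0.0, 1.0], 4))  # [0.0, 0.25, 0.75, 1.0]
```

This per-axis arithmetic is what a single Resize node encapsulates. The decomposed graph spells it out as many separate index and arithmetic ops.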
Here's the export script:
Relevant versions: onnx==1.16.0 onnxscript==0.1.0.dev20240513 torch==2.4.0.dev20240513+cu121 monai==1.3.0