microsoft / onnxscript

ONNX Script enables developers to naturally author ONNX functions and models using a subset of Python.
https://onnxscript.ai/
MIT License

Upsample exports to a very big subgraph in dynamo #1533

Open: yuanyao-nv opened this issue 1 month ago

yuanyao-nv commented 1 month ago

I'm looking at some MONAI models that use torch.nn.Upsample. I noticed that the TorchScript-based exporter (torch.onnx.export) exports the Upsample module as a single Resize node, while the dynamo exporter (torch.onnx.dynamo_export) turns it into a very large subgraph, which has a performance impact.

An example is the SegResNet model.
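For reference, the computation both exports have to reproduce is easy to state. In the exported program posted later in this thread, SegResNet's decoder calls `aten.upsample_trilinear3d.vec` with `align_corners=False` and scale factors `[2.0, 2.0, 2.0]`. Trilinear interpolation is separable, so it can be written as a 1-D linear resize applied along width, height and depth in turn, where output index `i` reads the input at `(i + 0.5) / scale - 0.5` (the half-pixel rule). The pure-Python sketch below is an illustrative reference under those assumptions, not PyTorch's or ONNX Runtime's implementation; `resize_linear_1d` and `upsample_trilinear` are hypothetical names. A single ONNX `Resize` node with `mode="linear"` and `coordinate_transformation_mode="half_pixel"` expresses all of this, whereas a graph assembled from primitive ops has to spell out the coordinate arithmetic, clamping, index gathers and weighted blends for each axis, which is one plausible reason the dynamo subgraph is so large.

```python
import math
from typing import List

Volume = List[List[List[float]]]  # nested D x H x W list


def resize_linear_1d(values: List[float], scale: float) -> List[float]:
    """Linear resize of one axis with half-pixel source coordinates
    (the align_corners=False convention); output length is floor(n * scale)."""
    in_size = len(values)
    out = []
    for dst in range(math.floor(in_size * scale)):
        # Map the output index back to a fractional input coordinate and clamp
        # it to the valid range [0, in_size - 1].
        src = min(max((dst + 0.5) / scale - 0.5, 0.0), float(in_size - 1))
        i0 = math.floor(src)
        i1 = min(i0 + 1, in_size - 1)
        weight = src - i0
        # Blend the two neighbouring samples.
        out.append((1.0 - weight) * values[i0] + weight * values[i1])
    return out


def _transpose(matrix: List[List[float]]) -> List[List[float]]:
    return [list(column) for column in zip(*matrix)]


def upsample_trilinear(volume: Volume, scale: float) -> Volume:
    """Trilinear upsampling as three separable 1-D passes: width, height, depth."""
    # Width: resize every innermost row.
    volume = [[resize_linear_1d(row, scale) for row in plane] for plane in volume]
    # Height: resize the columns of every depth slice.
    volume = [
        _transpose([resize_linear_1d(col, scale) for col in _transpose(plane)])
        for plane in volume
    ]
    # Depth: resize along the outermost axis at every (h, w) position.
    depth, height, width = len(volume), len(volume[0]), len(volume[0][0])
    lines = [
        [resize_linear_1d([volume[d][h][w] for d in range(depth)], scale)
         for w in range(width)]
        for h in range(height)
    ]
    new_depth = len(lines[0][0])
    return [
        [[lines[h][w][d] for w in range(width)] for h in range(height)]
        for d in range(new_depth)
    ]
```

For example, `resize_linear_1d([0.0, 1.0], 2.0)` gives `[0.0, 0.25, 0.75, 1.0]`, and with scale 2 a `28 x 28 x 16` volume comes out as `56 x 56 x 32`, matching the shapes around `upsample_trilinear3d` in the exported program.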

TorchScript export:

[screenshot of the exported graph]

Dynamo export:

[screenshot of the exported graph]

Expanded dynamo subgraph:

[screenshot of the expanded subgraph]

Here's the export script:

import torch
from monai.networks.nets import SegResNet

# SegResNet configuration used for the comparison.
model = SegResNet(
    blocks_down=(1, 2, 2, 4),
    blocks_up=(1, 1, 1),
    init_filters=16,
    in_channels=4,
    out_channels=3,
    dropout_prob=0.2,
).eval().to("cuda")
data = torch.randn(1, 4, 224, 224, 128).to("cuda")

# Toggle between the dynamo-based exporter and the TorchScript-based exporter.
dynamo_export = True
if dynamo_export:
    export_output = torch.onnx.dynamo_export(model, data)
    export_output.save("Clara_SegResNet_dynamo.onnx")
else:
    torch.onnx.export(model, (data,), "Clara_SegResNet.onnx")

Relevant versions: onnx==1.16.0 onnxscript==0.1.0.dev20240513 torch==2.4.0.dev20240513+cu121 monai 1.3.0
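To put a number on "very big", one can compare operator histograms of the two files the script writes. The sketch below uses a hypothetical helper, `count_op_types`, which is not part of `onnx` or `onnxscript`. It reads only standard `ModelProto` fields (`graph.node`, `op_type`, `attribute`, `functions`) and also walks model-local functions, on the assumption that the dynamo exporter may place computation there rather than in the main graph.

```python
from collections import Counter


def count_op_types(model) -> Counter:
    """Count node op_types in an onnx.ModelProto: the main graph, subgraphs
    carried in node attributes, and model-local functions (duck-typed, so any
    object exposing the same fields works)."""
    counts: Counter = Counter()

    def visit_nodes(nodes) -> None:
        for node in nodes:
            counts[node.op_type] += 1
            for attr in node.attribute:
                # Control-flow ops (If/Loop/Scan) keep subgraphs in attributes;
                # on a real AttributeProto an unset `g` is an empty graph.
                sub = getattr(attr, "g", None)
                if sub is not None:
                    visit_nodes(sub.node)
                for graph in getattr(attr, "graphs", []):
                    visit_nodes(graph.node)

    visit_nodes(model.graph.node)
    # Each function body is counted once, however many call sites use it;
    # the call sites appear as nodes whose op_type is the function name.
    for function in getattr(model, "functions", []):
        visit_nodes(function.node)
    return counts
```

For example, compare `count_op_types(onnx.load("Clara_SegResNet.onnx")).most_common(10)` with the same call on `Clara_SegResNet_dynamo.onnx`. Because function bodies are counted once, this measures graph size rather than the amount of work executed.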

yuanyao-nv commented 1 month ago

I also tried to get more info on which part of dynamo/onnxscript might be responsible for this. If I run

scripted_model = torch.jit.script(model)
print(scripted_model.graph)

I get this:

graph(%self : __torch__.monai.networks.nets.segresnet.SegResNet,
      %x.1 : Tensor):
  %3 : (Tensor, Tensor[]) = prim::CallMethod[name="encode"](%self, %x.1) # /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:180:20
  %x.5 : Tensor, %down_x.1 : Tensor[] = prim::TupleUnpack(%3)
   = aten::reverse(%down_x.1) # /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:181:8
  %x.9 : Tensor = prim::CallMethod[name="decode"](%self, %x.5, %down_x.1) # /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:183:12
  return (%x.9)

If I run

gm, _ = torch._dynamo.export(model)(data)
gm = torch.fx.experimental.proxy_tensor.make_fx(torch.func.functionalize(gm))(data)
gm.print_readable()

I get an error:

Traceback (most recent call last):
  File "/ws/dynamo/0501/export_SegResNet.py", line 31, in <module>
    gm, _ = torch._dynamo.export(model)(data)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1282, in inner
    dim_constraints.solve()
  File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 1772, in solve
    tmp_name = f"_{self._dcp.source_name_to_debug_name[self._dcp.symbol_to_source[s][0].name()]}"
KeyError: "L['x'].size()[4]"

justinchuby commented 1 month ago

Thanks for catching this. Very intriguing. Will take a look!

justinchuby commented 1 month ago

cc @xiaowuhu @fatcat-z

justinchuby commented 1 month ago

@yuanyao-nv could you obtain the graph module from torch.export.export and post it here?
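When the exported graph is as long as the one posted below, it helps to pull out only the calls of interest. The sketch below uses a hypothetical helper, `find_call_nodes` (not a PyTorch API); it relies only on the `op`, `name` and `target` attributes of `torch.fx.Node` and matches on `str(node.target)`, whose exact spelling may differ between PyTorch versions.

```python
from typing import Iterable, List, Tuple


def find_call_nodes(nodes: Iterable, pattern: str) -> List[Tuple[str, str]]:
    """Return (node name, target string) for every call_function node whose
    target's string form contains `pattern`."""
    matches = []
    for node in nodes:
        target = str(node.target)
        # Placeholders, get_attr and output nodes are skipped.
        if node.op == "call_function" and pattern in target:
            matches.append((node.name, target))
    return matches
```

With an `ExportedProgram` in hand, `find_call_nodes(exported_program.graph.nodes, "upsample")` should list the upsampling calls, such as the `upsample_trilinear3d` node, without printing every parameter.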

yuanyao-nv commented 1 month ago

@justinchuby Is this what you mean?

exported_program = torch.export.export(model, (data,))
exported_program.graph_module.print_readable()

which gives

class GraphModule(torch.nn.Module):
    def forward(self, p_convinit_conv_weight: "f32[16, 4, 3, 3, 3]", p_getattr_l__self___down_layers_0___1___norm1_weight: "f32[16]", p_getattr_l__self___down_layers_0___1___norm1_bias: "f32[16]", p_getattr_l__self___down_layers_0___1___conv1_conv_weight: "f32[16, 16, 3, 3, 3]", p_getattr_l__self___down_layers_0___1___norm2_weight: "f32[16]", p_getattr_l__self___down_layers_0___1___norm2_bias: "f32[16]", p_getattr_l__self___down_layers_0___1___conv2_conv_weight: "f32[16, 16, 3, 3, 3]", p_down_layers_1_0_conv_weight: "f32[32, 16, 3, 3, 3]", p_getattr_l__self___down_layers_1___1___norm1_weight: "f32[32]", p_getattr_l__self___down_layers_1___1___norm1_bias: "f32[32]", p_getattr_l__self___down_layers_1___1___conv1_conv_weight: "f32[32, 32, 3, 3, 3]", p_getattr_l__self___down_layers_1___1___norm2_weight: "f32[32]", p_getattr_l__self___down_layers_1___1___norm2_bias: "f32[32]", p_getattr_l__self___down_layers_1___1___conv2_conv_weight: "f32[32, 32, 3, 3, 3]", p_getattr_l__self___down_layers_1___2___norm1_weight: "f32[32]", p_getattr_l__self___down_layers_1___2___norm1_bias: "f32[32]", p_getattr_l__self___down_layers_1___2___conv1_conv_weight: "f32[32, 32, 3, 3, 3]", p_getattr_l__self___down_layers_1___2___norm2_weight: "f32[32]", p_getattr_l__self___down_layers_1___2___norm2_bias: "f32[32]", p_getattr_l__self___down_layers_1___2___conv2_conv_weight: "f32[32, 32, 3, 3, 3]", p_down_layers_2_0_conv_weight: "f32[64, 32, 3, 3, 3]", p_getattr_l__self___down_layers_2___1___norm1_weight: "f32[64]", p_getattr_l__self___down_layers_2___1___norm1_bias: "f32[64]", p_getattr_l__self___down_layers_2___1___conv1_conv_weight: "f32[64, 64, 3, 3, 3]", p_getattr_l__self___down_layers_2___1___norm2_weight: "f32[64]", p_getattr_l__self___down_layers_2___1___norm2_bias: "f32[64]", p_getattr_l__self___down_layers_2___1___conv2_conv_weight: "f32[64, 64, 3, 3, 3]", p_getattr_l__self___down_layers_2___2___norm1_weight: "f32[64]", p_getattr_l__self___down_layers_2___2___norm1_bias: "f32[64]", 
p_getattr_l__self___down_layers_2___2___conv1_conv_weight: "f32[64, 64, 3, 3, 3]", p_getattr_l__self___down_layers_2___2___norm2_weight: "f32[64]", p_getattr_l__self___down_layers_2___2___norm2_bias: "f32[64]", p_getattr_l__self___down_layers_2___2___conv2_conv_weight: "f32[64, 64, 3, 3, 3]", p_down_layers_3_0_conv_weight: "f32[128, 64, 3, 3, 3]", p_getattr_l__self___down_layers_3___1___norm1_weight: "f32[128]", p_getattr_l__self___down_layers_3___1___norm1_bias: "f32[128]", p_getattr_l__self___down_layers_3___1___conv1_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___1___norm2_weight: "f32[128]", p_getattr_l__self___down_layers_3___1___norm2_bias: "f32[128]", p_getattr_l__self___down_layers_3___1___conv2_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___2___norm1_weight: "f32[128]", p_getattr_l__self___down_layers_3___2___norm1_bias: "f32[128]", p_getattr_l__self___down_layers_3___2___conv1_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___2___norm2_weight: "f32[128]", p_getattr_l__self___down_layers_3___2___norm2_bias: "f32[128]", p_getattr_l__self___down_layers_3___2___conv2_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___3___norm1_weight: "f32[128]", p_getattr_l__self___down_layers_3___3___norm1_bias: "f32[128]", p_getattr_l__self___down_layers_3___3___conv1_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___3___norm2_weight: "f32[128]", p_getattr_l__self___down_layers_3___3___norm2_bias: "f32[128]", p_getattr_l__self___down_layers_3___3___conv2_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___4___norm1_weight: "f32[128]", p_getattr_l__self___down_layers_3___4___norm1_bias: "f32[128]", p_getattr_l__self___down_layers_3___4___conv1_conv_weight: "f32[128, 128, 3, 3, 3]", p_getattr_l__self___down_layers_3___4___norm2_weight: "f32[128]", p_getattr_l__self___down_layers_3___4___norm2_bias: "f32[128]", 
p_getattr_l__self___down_layers_3___4___conv2_conv_weight: "f32[128, 128, 3, 3, 3]", p_up_samples_0_0_conv_weight: "f32[64, 128, 1, 1, 1]", p_getattr_l__self___up_layers_0___0___norm1_weight: "f32[64]", p_getattr_l__self___up_layers_0___0___norm1_bias: "f32[64]", p_getattr_l__self___up_layers_0___0___conv1_conv_weight: "f32[64, 64, 3, 3, 3]", p_getattr_l__self___up_layers_0___0___norm2_weight: "f32[64]", p_getattr_l__self___up_layers_0___0___norm2_bias: "f32[64]", p_getattr_l__self___up_layers_0___0___conv2_conv_weight: "f32[64, 64, 3, 3, 3]", p_up_samples_1_0_conv_weight: "f32[32, 64, 1, 1, 1]", p_getattr_l__self___up_layers_1___0___norm1_weight: "f32[32]", p_getattr_l__self___up_layers_1___0___norm1_bias: "f32[32]", p_getattr_l__self___up_layers_1___0___conv1_conv_weight: "f32[32, 32, 3, 3, 3]", p_getattr_l__self___up_layers_1___0___norm2_weight: "f32[32]", p_getattr_l__self___up_layers_1___0___norm2_bias: "f32[32]", p_getattr_l__self___up_layers_1___0___conv2_conv_weight: "f32[32, 32, 3, 3, 3]", p_up_samples_2_0_conv_weight: "f32[16, 32, 1, 1, 1]", p_getattr_l__self___up_layers_2___0___norm1_weight: "f32[16]", p_getattr_l__self___up_layers_2___0___norm1_bias: "f32[16]", p_getattr_l__self___up_layers_2___0___conv1_conv_weight: "f32[16, 16, 3, 3, 3]", p_getattr_l__self___up_layers_2___0___norm2_weight: "f32[16]", p_getattr_l__self___up_layers_2___0___norm2_bias: "f32[16]", p_getattr_l__self___up_layers_2___0___conv2_conv_weight: "f32[16, 16, 3, 3, 3]", p_conv_final_0_weight: "f32[16]", p_conv_final_0_bias: "f32[16]", p_conv_final_2_conv_weight: "f32[3, 16, 1, 1, 1]", p_conv_final_2_conv_bias: "f32[3]", x: "f32[1, 4, 224, 224, 128]"):
        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:157 in encode, code: x = self.convInit(x)
        conv3d: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.conv3d.default(x, p_convinit_conv_weight, None, [1, 1, 1], [1, 1, 1]);  x = p_convinit_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:159 in encode, code: x = self.dropout(x)
        feature_dropout: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.feature_dropout.default(conv3d, 0.2, False);  conv3d = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.group_norm.default(feature_dropout, 8, p_getattr_l__self___down_layers_0___1___norm1_weight, p_getattr_l__self___down_layers_0___1___norm1_bias);  p_getattr_l__self___down_layers_0___1___norm1_weight = p_getattr_l__self___down_layers_0___1___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.relu.default(group_norm);  group_norm = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_1: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.conv3d.default(relu, p_getattr_l__self___down_layers_0___1___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu = p_getattr_l__self___down_layers_0___1___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_1: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.group_norm.default(conv3d_1, 8, p_getattr_l__self___down_layers_0___1___norm2_weight, p_getattr_l__self___down_layers_0___1___norm2_bias);  conv3d_1 = p_getattr_l__self___down_layers_0___1___norm2_weight = p_getattr_l__self___down_layers_0___1___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_1: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.relu.default(group_norm_1);  group_norm_1 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_2: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.conv3d.default(relu_1, p_getattr_l__self___down_layers_0___1___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_1 = p_getattr_l__self___down_layers_0___1___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.add.Tensor(conv3d_2, feature_dropout);  conv3d_2 = feature_dropout = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:164 in encode, code: x = down(x)
        conv3d_3: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(add, p_down_layers_1_0_conv_weight, None, [2, 2, 2], [1, 1, 1]);  p_down_layers_1_0_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_2: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(conv3d_3, 8, p_getattr_l__self___down_layers_1___1___norm1_weight, p_getattr_l__self___down_layers_1___1___norm1_bias);  p_getattr_l__self___down_layers_1___1___norm1_weight = p_getattr_l__self___down_layers_1___1___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_2: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_2);  group_norm_2 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_4: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_2, p_getattr_l__self___down_layers_1___1___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_2 = p_getattr_l__self___down_layers_1___1___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_3: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(conv3d_4, 8, p_getattr_l__self___down_layers_1___1___norm2_weight, p_getattr_l__self___down_layers_1___1___norm2_bias);  conv3d_4 = p_getattr_l__self___down_layers_1___1___norm2_weight = p_getattr_l__self___down_layers_1___1___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_3: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_3);  group_norm_3 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_5: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_3, p_getattr_l__self___down_layers_1___1___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_3 = p_getattr_l__self___down_layers_1___1___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_1: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.add.Tensor(conv3d_5, conv3d_3);  conv3d_5 = conv3d_3 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_4: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(add_1, 8, p_getattr_l__self___down_layers_1___2___norm1_weight, p_getattr_l__self___down_layers_1___2___norm1_bias);  p_getattr_l__self___down_layers_1___2___norm1_weight = p_getattr_l__self___down_layers_1___2___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_4: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_4);  group_norm_4 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_6: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_4, p_getattr_l__self___down_layers_1___2___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_4 = p_getattr_l__self___down_layers_1___2___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_5: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(conv3d_6, 8, p_getattr_l__self___down_layers_1___2___norm2_weight, p_getattr_l__self___down_layers_1___2___norm2_bias);  conv3d_6 = p_getattr_l__self___down_layers_1___2___norm2_weight = p_getattr_l__self___down_layers_1___2___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_5: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_5);  group_norm_5 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_7: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_5, p_getattr_l__self___down_layers_1___2___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_5 = p_getattr_l__self___down_layers_1___2___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_2: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.add.Tensor(conv3d_7, add_1);  conv3d_7 = add_1 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:164 in encode, code: x = down(x)
        conv3d_8: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(add_2, p_down_layers_2_0_conv_weight, None, [2, 2, 2], [1, 1, 1]);  p_down_layers_2_0_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_6: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(conv3d_8, 8, p_getattr_l__self___down_layers_2___1___norm1_weight, p_getattr_l__self___down_layers_2___1___norm1_bias);  p_getattr_l__self___down_layers_2___1___norm1_weight = p_getattr_l__self___down_layers_2___1___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_6: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_6);  group_norm_6 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_9: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_6, p_getattr_l__self___down_layers_2___1___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_6 = p_getattr_l__self___down_layers_2___1___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_7: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(conv3d_9, 8, p_getattr_l__self___down_layers_2___1___norm2_weight, p_getattr_l__self___down_layers_2___1___norm2_bias);  conv3d_9 = p_getattr_l__self___down_layers_2___1___norm2_weight = p_getattr_l__self___down_layers_2___1___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_7: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_7);  group_norm_7 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_10: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_7, p_getattr_l__self___down_layers_2___1___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_7 = p_getattr_l__self___down_layers_2___1___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_3: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.add.Tensor(conv3d_10, conv3d_8);  conv3d_10 = conv3d_8 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_8: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(add_3, 8, p_getattr_l__self___down_layers_2___2___norm1_weight, p_getattr_l__self___down_layers_2___2___norm1_bias);  p_getattr_l__self___down_layers_2___2___norm1_weight = p_getattr_l__self___down_layers_2___2___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_8: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_8);  group_norm_8 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_11: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_8, p_getattr_l__self___down_layers_2___2___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_8 = p_getattr_l__self___down_layers_2___2___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_9: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(conv3d_11, 8, p_getattr_l__self___down_layers_2___2___norm2_weight, p_getattr_l__self___down_layers_2___2___norm2_bias);  conv3d_11 = p_getattr_l__self___down_layers_2___2___norm2_weight = p_getattr_l__self___down_layers_2___2___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_9: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_9);  group_norm_9 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_12: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_9, p_getattr_l__self___down_layers_2___2___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_9 = p_getattr_l__self___down_layers_2___2___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_4: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.add.Tensor(conv3d_12, add_3);  conv3d_12 = add_3 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:164 in encode, code: x = down(x)
        conv3d_13: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(add_4, p_down_layers_3_0_conv_weight, None, [2, 2, 2], [1, 1, 1]);  p_down_layers_3_0_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_10: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(conv3d_13, 8, p_getattr_l__self___down_layers_3___1___norm1_weight, p_getattr_l__self___down_layers_3___1___norm1_bias);  p_getattr_l__self___down_layers_3___1___norm1_weight = p_getattr_l__self___down_layers_3___1___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_10: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_10);  group_norm_10 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_14: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_10, p_getattr_l__self___down_layers_3___1___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_10 = p_getattr_l__self___down_layers_3___1___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_11: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(conv3d_14, 8, p_getattr_l__self___down_layers_3___1___norm2_weight, p_getattr_l__self___down_layers_3___1___norm2_bias);  conv3d_14 = p_getattr_l__self___down_layers_3___1___norm2_weight = p_getattr_l__self___down_layers_3___1___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_11: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_11);  group_norm_11 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_15: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_11, p_getattr_l__self___down_layers_3___1___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_11 = p_getattr_l__self___down_layers_3___1___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_5: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.add.Tensor(conv3d_15, conv3d_13);  conv3d_15 = conv3d_13 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_12: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(add_5, 8, p_getattr_l__self___down_layers_3___2___norm1_weight, p_getattr_l__self___down_layers_3___2___norm1_bias);  p_getattr_l__self___down_layers_3___2___norm1_weight = p_getattr_l__self___down_layers_3___2___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_12: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_12);  group_norm_12 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_16: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_12, p_getattr_l__self___down_layers_3___2___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_12 = p_getattr_l__self___down_layers_3___2___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_13: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(conv3d_16, 8, p_getattr_l__self___down_layers_3___2___norm2_weight, p_getattr_l__self___down_layers_3___2___norm2_bias);  conv3d_16 = p_getattr_l__self___down_layers_3___2___norm2_weight = p_getattr_l__self___down_layers_3___2___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_13: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_13);  group_norm_13 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_17: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_13, p_getattr_l__self___down_layers_3___2___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_13 = p_getattr_l__self___down_layers_3___2___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_6: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.add.Tensor(conv3d_17, add_5);  conv3d_17 = add_5 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_14: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(add_6, 8, p_getattr_l__self___down_layers_3___3___norm1_weight, p_getattr_l__self___down_layers_3___3___norm1_bias);  p_getattr_l__self___down_layers_3___3___norm1_weight = p_getattr_l__self___down_layers_3___3___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_14: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_14);  group_norm_14 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_18: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_14, p_getattr_l__self___down_layers_3___3___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_14 = p_getattr_l__self___down_layers_3___3___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_15: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(conv3d_18, 8, p_getattr_l__self___down_layers_3___3___norm2_weight, p_getattr_l__self___down_layers_3___3___norm2_bias);  conv3d_18 = p_getattr_l__self___down_layers_3___3___norm2_weight = p_getattr_l__self___down_layers_3___3___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_15: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_15);  group_norm_15 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_19: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_15, p_getattr_l__self___down_layers_3___3___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_15 = p_getattr_l__self___down_layers_3___3___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_7: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.add.Tensor(conv3d_19, add_6);  conv3d_19 = add_6 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_16: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(add_7, 8, p_getattr_l__self___down_layers_3___4___norm1_weight, p_getattr_l__self___down_layers_3___4___norm1_bias);  p_getattr_l__self___down_layers_3___4___norm1_weight = p_getattr_l__self___down_layers_3___4___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_16: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_16);  group_norm_16 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_20: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_16, p_getattr_l__self___down_layers_3___4___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_16 = p_getattr_l__self___down_layers_3___4___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_17: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.group_norm.default(conv3d_20, 8, p_getattr_l__self___down_layers_3___4___norm2_weight, p_getattr_l__self___down_layers_3___4___norm2_bias);  conv3d_20 = p_getattr_l__self___down_layers_3___4___norm2_weight = p_getattr_l__self___down_layers_3___4___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_17: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.relu.default(group_norm_17);  group_norm_17 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_21: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.conv3d.default(relu_17, p_getattr_l__self___down_layers_3___4___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_17 = p_getattr_l__self___down_layers_3___4___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_8: "f32[1, 128, 28, 28, 16]" = torch.ops.aten.add.Tensor(conv3d_21, add_7);  conv3d_21 = add_7 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:171 in decode, code: x = up(x) + down_x[i + 1]
        conv3d_22: "f32[1, 64, 28, 28, 16]" = torch.ops.aten.conv3d.default(add_8, p_up_samples_0_0_conv_weight);  add_8 = p_up_samples_0_0_conv_weight = None
        upsample_trilinear3d: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.upsample_trilinear3d.vec(conv3d_22, None, False, [2.0, 2.0, 2.0]);  conv3d_22 = None
        add_9: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.add.Tensor(upsample_trilinear3d, add_4);  upsample_trilinear3d = add_4 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_18: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(add_9, 8, p_getattr_l__self___up_layers_0___0___norm1_weight, p_getattr_l__self___up_layers_0___0___norm1_bias);  p_getattr_l__self___up_layers_0___0___norm1_weight = p_getattr_l__self___up_layers_0___0___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_18: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_18);  group_norm_18 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_23: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_18, p_getattr_l__self___up_layers_0___0___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_18 = p_getattr_l__self___up_layers_0___0___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_19: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.group_norm.default(conv3d_23, 8, p_getattr_l__self___up_layers_0___0___norm2_weight, p_getattr_l__self___up_layers_0___0___norm2_bias);  conv3d_23 = p_getattr_l__self___up_layers_0___0___norm2_weight = p_getattr_l__self___up_layers_0___0___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_19: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.relu.default(group_norm_19);  group_norm_19 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_24: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.conv3d.default(relu_19, p_getattr_l__self___up_layers_0___0___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_19 = p_getattr_l__self___up_layers_0___0___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_10: "f32[1, 64, 56, 56, 32]" = torch.ops.aten.add.Tensor(conv3d_24, add_9);  conv3d_24 = add_9 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:171 in decode, code: x = up(x) + down_x[i + 1]
        conv3d_25: "f32[1, 32, 56, 56, 32]" = torch.ops.aten.conv3d.default(add_10, p_up_samples_1_0_conv_weight);  add_10 = p_up_samples_1_0_conv_weight = None
        upsample_trilinear3d_1: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.upsample_trilinear3d.vec(conv3d_25, None, False, [2.0, 2.0, 2.0]);  conv3d_25 = None
        add_11: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.add.Tensor(upsample_trilinear3d_1, add_2);  upsample_trilinear3d_1 = add_2 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_20: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(add_11, 8, p_getattr_l__self___up_layers_1___0___norm1_weight, p_getattr_l__self___up_layers_1___0___norm1_bias);  p_getattr_l__self___up_layers_1___0___norm1_weight = p_getattr_l__self___up_layers_1___0___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_20: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_20);  group_norm_20 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_26: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_20, p_getattr_l__self___up_layers_1___0___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_20 = p_getattr_l__self___up_layers_1___0___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_21: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.group_norm.default(conv3d_26, 8, p_getattr_l__self___up_layers_1___0___norm2_weight, p_getattr_l__self___up_layers_1___0___norm2_bias);  conv3d_26 = p_getattr_l__self___up_layers_1___0___norm2_weight = p_getattr_l__self___up_layers_1___0___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_21: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.relu.default(group_norm_21);  group_norm_21 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_27: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.conv3d.default(relu_21, p_getattr_l__self___up_layers_1___0___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_21 = p_getattr_l__self___up_layers_1___0___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_12: "f32[1, 32, 112, 112, 64]" = torch.ops.aten.add.Tensor(conv3d_27, add_11);  conv3d_27 = add_11 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:171 in decode, code: x = up(x) + down_x[i + 1]
        conv3d_28: "f32[1, 16, 112, 112, 64]" = torch.ops.aten.conv3d.default(add_12, p_up_samples_2_0_conv_weight);  add_12 = p_up_samples_2_0_conv_weight = None
        upsample_trilinear3d_2: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.upsample_trilinear3d.vec(conv3d_28, None, False, [2.0, 2.0, 2.0]);  conv3d_28 = None
        add_13: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.add.Tensor(upsample_trilinear3d_2, add);  upsample_trilinear3d_2 = add = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:86 in forward, code: x = self.norm1(x)
        group_norm_22: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.group_norm.default(add_13, 8, p_getattr_l__self___up_layers_2___0___norm1_weight, p_getattr_l__self___up_layers_2___0___norm1_bias);  p_getattr_l__self___up_layers_2___0___norm1_weight = p_getattr_l__self___up_layers_2___0___norm1_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:87 in forward, code: x = self.act(x)
        relu_22: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.relu.default(group_norm_22);  group_norm_22 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:88 in forward, code: x = self.conv1(x)
        conv3d_29: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.conv3d.default(relu_22, p_getattr_l__self___up_layers_2___0___conv1_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_22 = p_getattr_l__self___up_layers_2___0___conv1_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:90 in forward, code: x = self.norm2(x)
        group_norm_23: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.group_norm.default(conv3d_29, 8, p_getattr_l__self___up_layers_2___0___norm2_weight, p_getattr_l__self___up_layers_2___0___norm2_bias);  conv3d_29 = p_getattr_l__self___up_layers_2___0___norm2_weight = p_getattr_l__self___up_layers_2___0___norm2_bias = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:91 in forward, code: x = self.act(x)
        relu_23: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.relu.default(group_norm_23);  group_norm_23 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:92 in forward, code: x = self.conv2(x)
        conv3d_30: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.conv3d.default(relu_23, p_getattr_l__self___up_layers_2___0___conv2_conv_weight, None, [1, 1, 1], [1, 1, 1]);  relu_23 = p_getattr_l__self___up_layers_2___0___conv2_conv_weight = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/blocks/segresnet_block.py:94 in forward, code: x += identity
        add_14: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.add.Tensor(conv3d_30, add_13);  conv3d_30 = add_13 = None

        # File: /usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py:175 in decode, code: x = self.conv_final(x)
        group_norm_24: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.group_norm.default(add_14, 8, p_conv_final_0_weight, p_conv_final_0_bias);  add_14 = p_conv_final_0_weight = p_conv_final_0_bias = None
        relu_24: "f32[1, 16, 224, 224, 128]" = torch.ops.aten.relu.default(group_norm_24);  group_norm_24 = None
        conv3d_31: "f32[1, 3, 224, 224, 128]" = torch.ops.aten.conv3d.default(relu_24, p_conv_final_2_conv_weight, p_conv_final_2_conv_bias);  relu_24 = p_conv_final_2_conv_weight = p_conv_final_2_conv_bias = None
        return (conv3d_31,)
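In the exported graph above, every `upsample_trilinear3d.vec` call is passed `output_size=None`, `align_corners=False` and scale factors `[2.0, 2.0, 2.0]`. Each annotated output doubles the three spatial dimensions: 28x28x16 becomes 56x56x32, then 112x112x64, then 224x224x128. Below is a minimal stdlib-only sketch of that shape rule. It assumes an (N, C, D, H, W) layout and that each output size is floor(input_size * scale). `upsample_output_shape` is a hypothetical helper, not a PyTorch API.

```python
import math
from typing import Sequence, Tuple


def upsample_output_shape(
    input_shape: Sequence[int], scale_factors: Sequence[float]
) -> Tuple[int, ...]:
    """Shape after upsampling the spatial axes of an (N, C, *spatial) tensor.

    Assumes each output size is floor(input_size * scale), one scale per axis.
    """
    leading = tuple(input_shape[:2])  # batch and channel dims are unchanged
    spatial = tuple(input_shape[2:])
    if len(spatial) != len(scale_factors):
        raise ValueError("expected one scale factor per spatial axis")
    return leading + tuple(
        math.floor(size * scale) for size, scale in zip(spatial, scale_factors)
    )


# Inputs to the three upsample calls in the decoder, each with scale 2.0 per axis
print(upsample_output_shape((1, 64, 28, 28, 16), [2.0, 2.0, 2.0]))    # (1, 64, 56, 56, 32)
print(upsample_output_shape((1, 32, 56, 56, 32), [2.0, 2.0, 2.0]))    # (1, 32, 112, 112, 64)
print(upsample_output_shape((1, 16, 112, 112, 64), [2.0, 2.0, 2.0]))  # (1, 16, 224, 224, 128)
```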
justinchuby commented 1 month ago

Yes, thank you

justinchuby commented 1 month ago

That's very strange. If you run torch.onnx.dynamo_export(exported_program, ...), do you get the same graph?

yuanyao-nv commented 1 month ago

@justinchuby Is this the procedure you're suggesting?

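    import torch

    # model and data are the SegResNet instance and input from the export script above
    exported_program = torch.export.export(model, (data,))
    exported_program.graph_module.print_readable()  # public accessor for the traced FX graph

    # Export the ExportedProgram (instead of the nn.Module) to ONNX
    export_output = torch.onnx.dynamo_export(exported_program, data)
    export_output.save('Clara_SegResNet_dynamo1.onnx')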

The exported Upsample module looks about the same as before: it is still a very big graph. In addition, the model weights appear as extra inputs, resulting in tens of extra model inputs. This is similar to this bug: https://github.com/pytorch/pytorch/issues/126071

justinchuby commented 1 month ago

I don't see any Resize ops, which is puzzling. Could you share the ONNX model? You may remove the weights if it is too big.

justinchuby commented 1 month ago

I was expecting to see this function: https://github.com/microsoft/onnxscript/blob/2b6dc27b34f2e4e9fc4c3ad73635c5b157a4c714/onnxscript/function_libs/torch_lib/ops/nn.py#L2215

yuanyao-nv commented 1 month ago

@justinchuby I uploaded the two versions of the model here: https://drive.google.com/drive/folders/1s1lhKRuG6fOZmD4IjZvN_zlWfIxPB_8w?usp=sharing

justinchuby commented 1 month ago

It's possible that the upsample op was somehow decomposed by PyTorch. I will look deeper.

borisfom commented 1 month ago

I have found that, in the general case, one has to run exported_program.run_decompositions() before applying dynamo_export(). That may in fact fold some operations. @yuanyao-nv, can you try that?

justinchuby commented 1 month ago

Thanks. We will be creating a series of changes to the exporter to support ExportedPrograms properly, including handling of the weights.

yuanyao-nv commented 1 month ago

@borisfom I tried running run_decompositions() but it didn't do anything for this particular subgraph.

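    import torch
    exported_program = torch.export.export(model, args=(data,))
    # run_decompositions() returns a new ExportedProgram rather than modifying
    # the existing one in place, so the result has to be reassigned
    exported_program = exported_program.run_decompositions()
    export_output = torch.onnx.dynamo_export(exported_program, data)
    export_output.save('Clara_SegResNet_dynamo.onnx')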
titaiwangms commented 4 weeks ago

There are two issues here:

  1. Unsupported upsample_trilinear_vec op: #1592
  2. Dynamo forces the decomposition of upsample-related ops for some reason (related issues: https://github.com/pytorch/pytorch/issues/115883 and https://github.com/pytorch/pytorch/issues/116684). Related PR: https://github.com/pytorch/pytorch/pull/128259
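Some context on why the decomposed version grows so large: a single Resize node hides all of the per-axis index and weight arithmetic, while a decomposition, roughly speaking, has to spell that arithmetic out with explicit ops. The sketch below shows the arithmetic in 1-D for `mode='linear'` with `align_corners=False`. It follows PyTorch's half-pixel source-coordinate rule, src = (i + 0.5) / scale - 0.5, clamped at 0 for linear modes. Trilinear upsampling is separable, so it amounts to applying the same step along each of the three spatial axes. `upsample_linear_1d` is a hypothetical, illustrative helper, not how onnxscript or PyTorch actually implement the op.

```python
import math
from typing import List, Sequence


def upsample_linear_1d(values: Sequence[float], scale: float) -> List[float]:
    """1-D linear upsampling with align_corners=False (half-pixel) semantics."""
    in_size = len(values)
    out_size = math.floor(in_size * scale)
    out: List[float] = []
    for i in range(out_size):
        # Map the output index back to a fractional source coordinate
        src = max((i + 0.5) / scale - 0.5, 0.0)  # negative coords clamp to 0
        i0 = min(int(src), in_size - 1)          # left neighbour (floor, since src >= 0)
        i1 = min(i0 + 1, in_size - 1)            # right neighbour, clamped at the edge
        w = src - i0                             # weight given to the right neighbour
        out.append((1.0 - w) * values[i0] + w * values[i1])
    return out


print(upsample_linear_1d([0.0, 1.0, 2.0, 3.0], 2.0))
# [0.0, 0.25, 0.75, 1.25, 1.75, 2.25, 2.75, 3.0]
```

A length-5 input with scale 2.0 yields 10 outputs, the same ratio as the `size=10` Upsample example discussed later in this thread.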
titaiwangms commented 2 weeks ago

Hi @yuanyao-nv,

This one should be fixed when you call torch.onnx.dynamo_export with an nn.Module. However, if you call torch.export.export first, the op will still be decomposed into the big subgraph you saw. This decomposition is forced by Dynamo for some reason. Feel free to open an issue like https://github.com/pytorch/pytorch/issues/115883.

cc @gramalingam @justinchuby @xadupre This forced decomposition may require us to rewrite the decomposed subgraphs back into single ops using pattern-based rewrites. We will have to deal with it once we rely on torch.export.export.
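One way to read "rewriting them as patterns" is as a graph rewrite that matches the decomposed subgraph and replaces it with a single fused node such as Resize. The toy sketch below illustrates only that idea. The `Node` class, `fuse_chain`, and the placeholder op names `OpA`/`OpB`/`OpC` are hypothetical and unrelated to onnxscript's actual rewriter API, and the real decomposed subgraph is not a simple linear chain.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class Node:
    """Hypothetical minimal graph node: an op type plus named inputs/outputs."""
    op_type: str
    inputs: List[str]
    outputs: List[str]


def fuse_chain(nodes: List[Node], pattern: Sequence[str], fused_op: str) -> List[Node]:
    """Replace each linear chain of nodes whose op types match `pattern`.

    Nodes must be in topological order. A chain matches when the op types
    line up and each node's first input is the previous node's output.
    The fused node takes the chain head's first input and the tail's outputs.
    """
    if not pattern:
        raise ValueError("pattern must not be empty")
    result: List[Node] = []
    i = 0
    while i < len(nodes):
        window = nodes[i : i + len(pattern)]
        if (
            len(window) == len(pattern)
            and all(n.op_type == op for n, op in zip(window, pattern))
            and all(nxt.inputs[:1] == prev.outputs for prev, nxt in zip(window, window[1:]))
        ):
            result.append(Node(fused_op, window[0].inputs[:1], window[-1].outputs))
            i += len(pattern)  # skip the whole matched chain
        else:
            result.append(nodes[i])  # no match here: keep the node unchanged
            i += 1
    return result


graph = [
    Node("Conv", ["x", "w"], ["c"]),
    Node("OpA", ["c"], ["a"]),  # OpA/OpB/OpC stand in for a decomposed upsample
    Node("OpB", ["a"], ["b"]),
    Node("OpC", ["b"], ["up"]),
    Node("Add", ["up", "skip"], ["y"]),
]
print([n.op_type for n in fuse_chain(graph, ["OpA", "OpB", "OpC"], "Resize")])
# ['Conv', 'Resize', 'Add']
```

A real rewriter would also check that intermediate values have no other consumers and would carry over inputs such as scales and attributes; this sketch skips both for brevity.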

yuanyao-nv commented 2 weeks ago

@titaiwangms Thanks for the update.

What's a good way to test it out? If I rerun the export script from the description with the latest torch nightly build (2.5.0.dev20240617+cu121), I hit a different error:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/exporter.py", line 1509, in dynamo_export
    ).export()
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/exporter.py", line 1236, in export
    graph_module = self.options.fx_tracer.generate_fx(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 214, in generate_fx
    graph_module, graph_guard = torch._dynamo.export(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1379, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 433, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/fx/dynamo_graph_extractor.py", line 169, in wrapped
    return output_adapter.apply(model_func(*args, **kwargs), model=model)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1116, in __call__
    return self._torchdynamo_orig_callable(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
    return _compile(
  File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 84, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
  File "/usr/local/lib/python3.10/dist-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
  File "/usr/lib/python3.10/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object
    transformations(instructions, code_options)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 582, in transform
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2462, in run
    super().run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 904, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 816, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 510, in wrapper
    return inner_fn(self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1470, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 754, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 356, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 298, in call_function
    return super().call_function(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/functions.py", line 95, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 760, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2677, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2793, in inline_call_
    tracer.run()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 904, in run
    while self.step():
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 816, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 234, in impl
    self.push(fn_var.call_function(self, self.popn(nargs), {}))
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 962, in call_function
    return handler(tx, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builtin.py", line 941, in _handle_insert_op_in_graph
    return wrap_fx_proxy(tx, proxy)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 1759, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 1846, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1853, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1785, in get_fake_value
    ret_val = wrap_fake_exception(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1300, in wrap_fake_exception
    return fn()
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1786, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1921, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 1903, in run_node
    return node.target(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1061, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1450, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1145, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1757, in _dispatch_impl
    r = func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 670, in __call__
    return self_._op(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_prims_common/wrappers.py", line 266, in _fn
    result = fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_prims_common/wrappers.py", line 138, in _fn
    result = fn(**bound.arguments)
  File "/usr/local/lib/python3.10/dist-packages/torch/_refs/__init__.py", line 1080, in add
    a, b = _maybe_broadcast(a, b)
  File "/usr/local/lib/python3.10/dist-packages/torch/_refs/__init__.py", line 419, in _maybe_broadcast
    common_shape = _broadcast_shapes(
  File "/usr/local/lib/python3.10/dist-packages/torch/_refs/__init__.py", line 408, in _broadcast_shapes
    raise RuntimeError(
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function add>(*(FakeTensor(..., device='cuda:0', size=(1, 64, 28, 56, 56, 32),
           grad_fn=<WarnNotImplemented>), FakeTensor(..., device='cuda:0', size=(1, 64, 56, 56, 32),
           grad_fn=<AddBackward0>)), **{}):
Attempting to broadcast a dimension of length 64 at -4! Mismatching argument at index 1 had torch.Size([1, 64, 56, 56, 32]); but expected shape should be broadcastable to [1, 64, 28, 56, 56, 32]

from user code:
   File "/usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py", line 183, in forward
    x = self.decode(x, down_x)
  File "/usr/local/lib/python3.10/dist-packages/monai/networks/nets/segresnet.py", line 171, in decode
    x = up(x) + down_x[i + 1]

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
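The error itself is a plain broadcasting failure. The upsampled left operand is 6-D, (1, 64, 28, 56, 56, 32), while the skip connection `down_x[i + 1]` is 5-D, (1, 64, 56, 56, 32). Aligning the shapes from the right, dimension -4 pairs 28 with 64, and since neither is 1 the add cannot broadcast, which matches the message. Below is a stdlib-only sketch of that rule. PyTorch exposes the same check as `torch.broadcast_shapes`; the `broadcast_shapes` function here is a hypothetical re-implementation for illustration.

```python
from itertools import zip_longest
from typing import Sequence, Tuple


def broadcast_shapes(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Right-aligned broadcasting: paired dims must be equal or one must be 1."""
    result = []
    # Walk from the last dim; a missing leading dim behaves like size 1.
    pairs = zip_longest(reversed(a), reversed(b), fillvalue=1)
    for offset, (da, db) in enumerate(pairs, start=1):
        if da != db and da != 1 and db != 1:
            raise ValueError(f"dimension {-offset}: {da} vs {db} is not broadcastable")
        result.append(db if da == 1 else da)
    return tuple(reversed(result))


skip = (1, 64, 56, 56, 32)            # down_x[i + 1] in the traceback
upsampled = (1, 64, 28, 56, 56, 32)   # unexpected 6-D output of up(x)
print(broadcast_shapes((1, 64, 56, 56, 32), skip))  # (1, 64, 56, 56, 32)
try:
    broadcast_shapes(upsampled, skip)
except ValueError as err:
    print(err)  # dimension -4: 28 vs 64 is not broadcastable
```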

Do you know why fake tensors are being used in the latest torch version?

I also tried exporting just an nn.Upsample module:

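import torch

def f(x):
    # Linear upsampling of the last axis from length 5 to length 10
    m = torch.nn.Upsample(size=10, mode='linear')
    return m(x)

x = torch.randn(2, 5, 5)  # (batch, channels, length); mode='linear' expects 3-D input
export_output = torch.onnx.dynamo_export(f, x)
export_output.save('Upsample.onnx')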

The exported graph looks reasonable. Is this what you'd expect?

image

yuanyao-nv commented 1 week ago

Filed a separate issue to track the fake tensor broadcast error above: https://github.com/pytorch/pytorch/issues/129534