pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Exporting to ONNX doesn't yield any errors, but also doesn't generate an output file #5656

Closed: Wantcha closed this issue 1 year ago

Wantcha commented 1 year ago

🐛 Describe the bug

I'm running a pretty standard GNN model and I'm trying to export it to ONNX using:

input_values = (g.x, g.edge_index, g.edge_attr)
input_names = ['node_attr', 'edge_index', 'edge_attr']
torch.onnx.export(model, input_values, "physics_model.onnx", opset_version=16, input_names=input_names,
                        output_names=['coords'], dynamic_axes={'node_attr':{0:'num_nodes'}, 'edge_index':{1:'num_edges'}, 'edge_attr':{0:'num_edges'}}, verbose=True)
print('done')

But 'done' is never printed, and no physics_model.onnx file is created. The console prints some tensors, then the program suddenly stops without an error, and nothing after the export call runs. I've run the model on its own before and it works fine. What could be the reason? I found this similar issue: https://github.com/ultralytics/yolov5/issues/9630, but I couldn't understand how to apply the author's solution.

Environment

JiaxuanYou commented 1 year ago

Could you provide a minimal example to reproduce your bug? Thank you! cc @rusty1s

Wantcha commented 1 year ago


from collections import OrderedDict
import numpy as np
import torch as th
import torch.nn as nn
import torch.onnx
import torch_geometric.data as tgd
from torch_geometric.nn import MessagePassing

class MLP(nn.Module):
    '''
    Multilayer Perceptron.
    '''
    def __init__(self, hidden_size: int, num_hidden_layers: int, output_size: int):
        super(MLP, self).__init__()
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.output_size = output_size

        self.initialized = False

    def _initialize(self, inputs : th.Tensor):
        if not self.initialized:
            input_size = inputs.shape[1]

            l = OrderedDict()
            l['input'] = nn.Linear(input_size, self.hidden_size)
            l['relu_in'] = nn.ReLU()
            for i in range(self.num_hidden_layers):
                l['h%d' % i] = nn.Linear(self.hidden_size, self.hidden_size)
                l['relu%d' % i] = nn.ReLU()
            l['out'] = nn.Linear(self.hidden_size, self.output_size)

            self.layers = nn.Sequential(l)
            self.initialized = True
            print("INITIALIZED MLP")

    def forward(self, x):
        self._initialize(x)

        return self.layers(x)

def build_mlp_with_layer_norm(hidden_size: int, num_hidden_layers: int, output_size: int) -> th.nn.Module:
    mlp = MLP(hidden_size, num_hidden_layers, output_size)
    return th.nn.Sequential( mlp, th.nn.LayerNorm(output_size) )

class InteractionNetworkModule(MessagePassing):
    def __init__(self, node_model, edge_model):
        super(InteractionNetworkModule, self).__init__(aggr = 'add')
        self.node_model = node_model
        self.edge_model = edge_model

    def forward(self, x : th.Tensor, edge_index: th.Tensor, edge_feats: th.Tensor):
        src, dst = edge_index
        collectedEdgeFeats = th.concat([edge_feats, x[src]], dim=1)

        new_edge_feats = self.edge_model(collectedEdgeFeats)

        num_nodes = x.size(0)

        out_nodes = self.propagate(edge_index=edge_index, size=(num_nodes, num_nodes), new_edge_feats=new_edge_feats, x=x)

        return out_nodes, new_edge_feats

    def message(self, new_edge_feats: th.Tensor) -> th.Tensor:
        # new_edge_feats already holds one row per edge, so pass it through
        # unchanged; propagate() sums these messages into each target node.
        return new_edge_feats

    def update(self, aggr_out: th.Tensor, x: th.Tensor) -> th.Tensor:
        collected_node_feats = th.concat([x, aggr_out], dim=1)
        new_nodes: th.Tensor = self.node_model(collected_node_feats)

        return new_nodes

class GraphIndependentModule(nn.Module):
    def __init__(self, node_model, edge_model):
        super(GraphIndependentModule, self).__init__()
        self.node_model = node_model
        self.edge_model = edge_model

    def forward(self, x : th.Tensor, edge_feats: th.Tensor):
        x = self.node_model(x)
        edge_feats = self.edge_model(edge_feats)
        return x, edge_feats

class GraphNetwork(nn.Module):
    def __init__(self):
        super(GraphNetwork, self).__init__()

        mlp_hidden_size = 128
        mlp_num_hidden_layers = 2
        mlp_latent_size = 128
        message_passing_steps = 5

        edge_encode_model = build_mlp_with_layer_norm(mlp_hidden_size, mlp_num_hidden_layers, mlp_latent_size)
        node_encode_model = build_mlp_with_layer_norm(mlp_hidden_size, mlp_num_hidden_layers, mlp_latent_size)
        self.encoder_network = GraphIndependentModule(node_encode_model, edge_encode_model)

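        # NOTE: a plain Python list does not register these modules with nn.Module;
        # this turns out to be the cause of the failed export (see the end of the thread).
        self.processor_networks = []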
        for _ in range(message_passing_steps):
            self.processor_networks.append(
                InteractionNetworkModule(build_mlp_with_layer_norm(mlp_hidden_size, mlp_num_hidden_layers, mlp_latent_size),
                                            build_mlp_with_layer_norm(mlp_hidden_size, mlp_num_hidden_layers, mlp_latent_size)))

    def forward(self, x: th.Tensor, edge_index: th.Tensor, edge_attr : th.Tensor) -> th.Tensor:
        node_feats = x.clone().detach()
        edge_feats = edge_attr.clone().detach()

        node_feats, edge_feats = self.encoder_network(node_feats, edge_feats)

        for processor_network in self.processor_networks:
            processed_node_feats, processed_edge_feats = processor_network(node_feats, edge_index, edge_feats)
            node_feats = node_feats + processed_node_feats
            edge_feats = edge_feats + processed_edge_feats

        return node_feats

if __name__ == "__main__":

    num_edges = np.random.randint(1000, 5000)
    num_nodes = 300
    edge_index = th.randint(0, num_nodes, (2, num_edges))

    edge_attr = th.rand(num_edges, 1)

    x = th.rand(num_nodes, 9)
    pinned_points = th.randint(0, 1, (num_nodes,))

    model = GraphNetwork()
    input_values = (x, edge_index, edge_attr)
    input_names = ['node_attr', 'edge_index', 'edge_attr']

    # Run the model once so the lazily initialized MLP layers exist before export.
    result = model(x, edge_index, edge_attr).detach().numpy()

    torch.onnx.export(model, input_values, "physics_model.onnx", opset_version=16, input_names=input_names,
                        output_names=['coords'], dynamic_axes={'node_attr':{0:'num_nodes'}, 'edge_index':{1:'num_edges'}, 'edge_attr':{0:'num_edges'}}, verbose=True)

    print('done')

Apologies for the lengthy code; I wasn't sure how to shorten it further while keeping the bug intact. While testing, I noticed that the bug seems to go away when I use a single processor network instead of a list of them. With only one, the exporter does report an error, but it still exports the file successfully.

rusty1s commented 1 year ago

Thanks for the example. We do not support ONNX officially, but it looks like it is generally supported, see https://github.com/pyg-team/pytorch_geometric/issues/728. Is it possible to identify which child module currently breaks ONNX export for you?

Wantcha commented 1 year ago


All I know is that the problem starts once I chain several InteractionNetworks together in a list. If I use just one as the processor_network, it works fine.

Wantcha commented 1 year ago

Update: the InteractionNetwork by itself works fine, but storing it in a list breaks everything. A single, plain InteractionNetwork exports fine, whereas a single InteractionNetwork stored in a one-element list and called from that list breaks.
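To answer the "which child module" question without bisecting by hand, here is a minimal sketch of a diagnostic. The helper name `find_unregistered_modules` is hypothetical, not a PyTorch or PyG API, and it relies on an internal, undocumented detail: `nn.Module` keeps registered children in its `_modules` dict, so modules held in ordinary lists, tuples or dicts appear only in the instance `__dict__`. It walks the registered module tree and reports any `nn.Module` hidden in such a container.

```python
from typing import List

import torch.nn as nn


def find_unregistered_modules(root: nn.Module) -> List[str]:
    """Hypothetical helper: list nn.Module objects hidden inside plain Python
    lists, tuples or dicts, where nn.Module cannot register them."""
    found = []
    for prefix, module in root.named_modules():
        for name, value in vars(module).items():
            # Registered children are kept in the internal `_modules` dict;
            # skip it so that only ordinary containers are inspected.
            if name == "_modules":
                continue
            if isinstance(value, (list, tuple)):
                items = enumerate(value)
            elif isinstance(value, dict):
                items = value.items()
            else:
                continue
            for key, item in items:
                if isinstance(item, nn.Module):
                    path = f"{prefix}.{name}" if prefix else name
                    found.append(f"{path}[{key!r}]")
    return found


class WithList(nn.Module):
    def __init__(self):
        super().__init__()
        # Same pattern as GraphNetwork in the reproducer: a plain list.
        self.processors = [nn.Linear(4, 4), nn.Linear(4, 4)]


class WithModuleList(nn.Module):
    def __init__(self):
        super().__init__()
        self.processors = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])


print(find_unregistered_modules(WithList()))        # ['processors[0]', 'processors[1]']
print(find_unregistered_modules(WithModuleList()))  # []
```

Running it on the reproducer's `GraphNetwork` should point straight at `processor_networks`, since its layers live in a plain list.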

Wantcha commented 1 year ago

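SOLVED: using a PyTorch nn.ModuleList instead of a plain Python list fixes it. Modules stored in a plain list are not registered as submodules, so their parameters are missing from model.parameters() and state_dict(), which is most likely why ONNX tracing did not handle them properly.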
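A minimal sketch of the fix, assuming a simplified stand-in for `InteractionNetworkModule`. The `StandInProcessor` and `FixedGraphNetwork` names are hypothetical, and plain `nn.Linear` layers replace the PyG message passing so the snippet runs without `torch_geometric`. It keeps the reproducer's residual loop but stores the processors in an `nn.ModuleList`, then checks that their weights are now part of `state_dict()`.

```python
import torch
import torch.nn as nn


class StandInProcessor(nn.Module):
    """Hypothetical stand-in for InteractionNetworkModule: returns updates for
    node and edge features so the residual loop below has the same shape."""

    def __init__(self, latent: int):
        super().__init__()
        self.node_mlp = nn.Linear(latent, latent)
        self.edge_mlp = nn.Linear(latent, latent)

    def forward(self, node_feats: torch.Tensor, edge_feats: torch.Tensor):
        return self.node_mlp(node_feats), self.edge_mlp(edge_feats)


class FixedGraphNetwork(nn.Module):
    def __init__(self, latent: int = 8, message_passing_steps: int = 5):
        super().__init__()
        # nn.ModuleList instead of [] registers every processor as a submodule,
        # so its weights show up in parameters() and state_dict().
        self.processor_networks = nn.ModuleList(
            [StandInProcessor(latent) for _ in range(message_passing_steps)]
        )

    def forward(self, node_feats: torch.Tensor, edge_feats: torch.Tensor) -> torch.Tensor:
        # Same residual pattern as GraphNetwork.forward in the reproducer.
        for processor_network in self.processor_networks:
            d_nodes, d_edges = processor_network(node_feats, edge_feats)
            node_feats = node_feats + d_nodes
            edge_feats = edge_feats + d_edges
        return node_feats


model = FixedGraphNetwork()
print(len(model.state_dict()))  # 20: 5 processors x 2 layers x (weight, bias)
out = model(torch.rand(300, 8), torch.rand(1000, 8))
print(out.shape)  # torch.Size([300, 8])
```

In the reproducer itself the change is one line, `self.processor_networks = nn.ModuleList()`; the existing `append` calls keep working because `nn.ModuleList` provides `append`. The resulting model is then passed to `torch.onnx.export` as in the original post.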