snap-stanford / ogb

Benchmark datasets, data loaders, and evaluators for graph machine learning
https://ogb.stanford.edu
MIT License
1.89k stars 397 forks source link

OGB graph prop pred examples torch-mlir compatibility #436

Open dvlp-r opened 1 year ago

dvlp-r commented 1 year ago

Hi, I am opening this issue to ask a question. I am trying to use one of your examples about graph classification (https://github.com/snap-stanford/ogb/tree/master/examples/graphproppred/mol).

What I have recently done is modify the example so that a TorchScript version of the model can be created successfully. This is because I am trying to lower the model to torch-mlir. When I try to do so, I encounter an error which, as the torch-mlir developers explained, means that my model contains a tuple constant like x=(0,0). They suggested that I replace this tuple with a list, like x=[0,0]. Unfortunately, I am new to this and have not been able to find where the tuple comes from. I have included the torch-mlir error below for completeness.

Exception: 
PyTorch TorchScript module -> torch-mlir Object Graph IR import failed with:
### Importer C++ Exception:
see diagnostics
### Importer Diagnostics:
error: unhandled prim::Constant node: %37 : (int, int) = prim::Constant[value=(0, 0)]()

Could you please help me find this tuple so that your models can be made compatible with torch-mlir? Below are the files of your model that I am using. The only changes I made were to make it compatible with TorchScript (and, for simplicity, only the code for the GIN model has been kept).

Thank you in advance for your help.

main.py

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

import sys

from PIL import Image
import requests

import torch
from torchvision import transforms
from tqdm import tqdm
from torch_geometric.loader import DataLoader
from gnn import GNN
import torch.optim as optim

from ogb.graphproppred import PygGraphPropPredDataset, Evaluator

import torch_mlir
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend

def train(model, device, loader, optimizer, task_type):
    """Run one training epoch of ``model`` over ``loader``.

    Args:
        model: graph-level model called as
            ``model(x, edge_index, edge_attr, batch)``.
        device: device every batch is moved to before the forward pass.
        loader: iterable of PyG batches exposing ``x``, ``edge_index``,
            ``edge_attr``, ``batch`` and ``y``.
        optimizer: optimizer that steps ``model``'s parameters.
        task_type: dataset task type. If it contains ``"classification"``,
            the module-level ``cls_criterion`` is used; otherwise
            ``reg_criterion`` is used.

    Batches containing a single node, or a single graph
    (``batch.batch[-1] == 0``), are skipped. With ``batch_size=1`` that
    means every batch is skipped.
    """
    model.train()

    for batch in tqdm(loader, desc="Iteration"):
        batch = batch.to(device)

        if batch.x.shape[0] == 1 or batch.batch[-1] == 0:
            continue

        # The model's forward takes the graph tensors explicitly, not the
        # Batch object itself.
        pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        optimizer.zero_grad()
        ## ignore nan targets (unlabeled) when computing training loss.
        is_labeled = batch.y == batch.y
        criterion = cls_criterion if "classification" in task_type else reg_criterion
        loss = criterion(pred.to(torch.float32)[is_labeled], batch.y.to(torch.float32)[is_labeled])
        loss.backward()
        optimizer.step()

def eval(model, device, loader, evaluator):
    """Score ``model`` on every usable batch of ``loader`` with ``evaluator``.

    The model is put in eval mode and called as
    ``model(x, edge_index, edge_attr, batch)`` under ``torch.no_grad()``.
    Batches whose node-feature tensor has exactly one row are skipped.
    Targets are reshaped to the prediction's shape. Both sides are moved to
    CPU, concatenated along dim 0, converted to NumPy and passed to
    ``evaluator.eval`` under the keys ``"y_true"`` and ``"y_pred"``.

    Returns:
        Whatever ``evaluator.eval`` returns.
    """
    model.eval()
    targets = []
    outputs = []

    for graph_batch in tqdm(loader, desc="Iteration"):
        graph_batch = graph_batch.to(device)

        if graph_batch.x.shape[0] == 1:
            continue

        with torch.no_grad():
            pred = model(graph_batch.x, graph_batch.edge_index,
                         graph_batch.edge_attr, graph_batch.batch)

        targets.append(graph_batch.y.view(pred.shape).detach().cpu())
        outputs.append(pred.detach().cpu())

    return evaluator.eval({
        "y_true": torch.cat(targets, dim=0).numpy(),
        "y_pred": torch.cat(outputs, dim=0).numpy(),
    })

def predictions(torch_model, jit_model):
    """Print test-set scores for the eager model and then the compiled model.

    Relies on the module-level ``device``, ``test_loader`` and ``evaluator``.
    """
    runs = (
        ("PyTorch prediction", torch_model),
        ("torch-mlir prediction", jit_model),
    )
    for label, candidate in runs:
        score = eval(candidate, device, test_loader, evaluator)
        print(label)
        print(score)

cls_criterion = torch.nn.BCEWithLogitsLoss()
reg_criterion = torch.nn.MSELoss()

# Graphs per training batch. Must be greater than 1: `train` skips every batch
# that contains a single graph, so batch_size=1 would never update the model.
TRAIN_BATCH_SIZE = 32

device = torch.device("cpu")

### automatic data loading and splitting
dataset = PygGraphPropPredDataset(name='ogbg-molhiv')

split_idx = dataset.get_idx_split()

### automatic evaluator. takes dataset name as input
evaluator = Evaluator("ogbg-molhiv")

train_loader = DataLoader(dataset[split_idx["train"]], batch_size=TRAIN_BATCH_SIZE, shuffle=True,
                          num_workers=0)
# Validation and test use one graph per batch, matching the single-graph
# example input passed to torch_mlir.compile below.
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=1, shuffle=False,
                          num_workers=0)
test_loader = DataLoader(dataset[split_idx["test"]], batch_size=1, shuffle=False,
                         num_workers=0)

gin = GNN(gnn_type='gin', num_tasks=dataset.num_tasks, num_layer=5, emb_dim=300,
          drop_ratio=0.5).to(device)

optimizer = optim.Adam(gin.parameters(), lr=0.001)

train(gin, device, train_loader, optimizer, dataset.task_type)
print("Validation")
print(eval(gin, device, valid_loader, evaluator))

# Put dropout/BatchNorm into inference mode before compiling.
gin.eval()

# The first test graph serves as the example input for compilation
# (raises StopIteration if the test split is empty).
example = next(iter(test_loader)).to(device)
module = torch_mlir.compile(
    gin,
    (example.x, example.edge_index, example.edge_attr, example.batch),
    output_type="linalg-on-tensors",
)

backend = refbackend.RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(module)
jit_module = backend.load(compiled)

# NOTE(review): `predictions` -> `eval` calls `jit_model.eval()` and
# `jit_model(x, edge_index, edge_attr, batch)` with torch tensors. Confirm that
# the object returned by `backend.load` supports both, and that the compiled
# module accepts graphs whose sizes differ from the example input above.
predictions(gin, jit_module)

gnn.py

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
import torch.nn.functional as F
from torch_geometric.nn.inits import uniform

from conv import GNN_node

from torch_scatter import scatter_mean
import time

class GNN(torch.nn.Module):
    """Graph-level predictor: node encoder, graph pooling, then a linear head.

    ``GNN_node`` produces node embeddings, ``self.pool`` reduces them to one
    vector per graph (using the ``batch`` assignment vector), and
    ``graph_pred_linear`` maps that vector to ``num_tasks`` outputs.
    """

    def __init__(self, num_tasks=10, num_layer=5, emb_dim=300,
                 gnn_type='gin', residual=False, drop_ratio=0.5, JK="last", graph_pooling="mean"):
        """
        Args:
            num_tasks (int): number of labels to be predicted per graph.
            num_layer (int): number of message-passing layers in ``GNN_node``.
            emb_dim (int): node embedding dimensionality.
            gnn_type (str): convolution type, forwarded to ``GNN_node``.
            residual (bool): whether ``GNN_node`` adds residual connections.
            drop_ratio (float): dropout probability, forwarded to ``GNN_node``.
            JK (str): jumping-knowledge mode, forwarded to ``GNN_node``.
            graph_pooling (str): one of "sum", "mean", "max", "attention"
                or "set2set".

        Raises:
            ValueError: if ``graph_pooling`` is not a supported type.
        """

        super().__init__()

        # Validate up front so an invalid option fails before any submodule
        # is built.
        if graph_pooling not in ("sum", "mean", "max", "attention", "set2set"):
            raise ValueError("Invalid graph pooling type.")

        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.graph_pooling = graph_pooling

        ### GNN to generate node embeddings
        self.gnn_node = GNN_node(num_layer, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual,
                                 gnn_type=gnn_type)

        ### Pooling function to generate whole-graph embeddings
        if graph_pooling == "sum":
            self.pool = global_add_pool
        elif graph_pooling == "mean":
            self.pool = global_mean_pool
        elif graph_pooling == "max":
            self.pool = global_max_pool
        elif graph_pooling == "attention":
            self.pool = GlobalAttention(
                gate_nn=torch.nn.Sequential(torch.nn.Linear(emb_dim, 2 * emb_dim), torch.nn.BatchNorm1d(2 * emb_dim),
                                            torch.nn.ReLU(), torch.nn.Linear(2 * emb_dim, 1)))
        else:  # "set2set" (validated above)
            self.pool = Set2Set(emb_dim, processing_steps=2)

        # Set2Set produces 2 * emb_dim features per graph; the others emb_dim.
        pooled_dim = 2 * self.emb_dim if graph_pooling == "set2set" else self.emb_dim
        self.graph_pred_linear = torch.nn.Linear(pooled_dim, self.num_tasks)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor,
                batch: torch.Tensor) -> torch.Tensor:
        """Return per-graph predictions of width ``num_tasks``.

        ``batch`` maps each node to its graph and is consumed by the pooling
        function.
        """
        h_node = self.gnn_node(x, edge_index, edge_attr)

        h_graph = self.pool(h_node, batch)

        return self.graph_pred_linear(h_graph)

if __name__ == '__main__':
    GNN(num_tasks=10)

conv.py

import torch
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool, global_add_pool
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
from torch_geometric.utils import degree

import math
import time

### GIN convolution along the graph structure
class GINConv(MessagePassing):
    """GIN convolution that also uses bond (edge) features.

    Each edge message is ``relu(x_j + bond_encoder(edge_attr))``. Messages
    are summed per target node (``aggr="add"``), and the node update is
    ``mlp((1 + eps) * x + aggregated)``. Here ``eps`` is a learnable scalar
    initialised to 0, and ``mlp`` is
    Linear(emb_dim, 2*emb_dim) -> BatchNorm1d -> ReLU -> Linear(2*emb_dim, emb_dim).
    """
    # Static argument types of `propagate`, read by PyG when `.jittable()`
    # builds the TorchScript-compatible variant of this layer.
    propagate_type = {'x': torch.Tensor, 'edge_attr': torch.Tensor}

    def __init__(self, emb_dim):
        '''
            emb_dim (int): node embedding dimensionality; both the input and
                output width of the layer.
        '''

        super(GINConv, self).__init__(aggr="add")

        self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2 * emb_dim), torch.nn.BatchNorm1d(2 * emb_dim),
                                       torch.nn.ReLU(), torch.nn.Linear(2 * emb_dim, emb_dim))
        # Learnable weight of the node's own features, starting at 0.
        self.eps = torch.nn.Parameter(torch.Tensor([0]))

        self.bond_encoder = BondEncoder(emb_dim=emb_dim)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
        # Embed categorical bond features into emb_dim-wide vectors.
        edge_embedding = self.bond_encoder(edge_attr)
        # Presumably narrows the encoder's (unannotated) return value to a
        # Tensor for TorchScript -- confirm it is still needed.
        assert isinstance(edge_embedding, torch.Tensor)
        # size=None lets PyG infer the number of nodes from `x`.
        return self.mlp((1 + self.eps) * x + self.propagate(edge_index, x=x, edge_attr=edge_embedding, size=None))

    def message(self, x_j, edge_attr):
        # x_j: per-edge features of the neighbouring (source-side) node,
        # gathered by PyG because of the `_j` suffix. edge_attr: the bond
        # embedding passed to `propagate`.
        return F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        # Identity update: the combination with the node's own features
        # happens in `forward`.
        return aggr_out

### GNN to generate node embedding
class GNN_node(torch.nn.Module):
    """
    Stack of GIN message-passing layers that turns atom features into node
    embeddings.

    Output:
        node representations, combined across layers according to ``JK``
    """

    def __init__(self, num_layer, emb_dim, drop_ratio=0.5, JK="last", residual=False, gnn_type='gin'):
        '''
            num_layer (int): number of GNN message passing layers (must be >= 2)
            emb_dim (int): node embedding dimensionality
            drop_ratio (float): dropout probability applied after every layer
            JK (str): how layer outputs are combined. "last" returns the final
                layer's output; "sum" adds the input embedding and every
                layer's output.
            residual (bool): add each layer's input to its output
            gnn_type (str): convolution type; only 'gin' is implemented here

        Raises:
            ValueError: if num_layer < 2, or JK / gnn_type is unsupported.
        '''

        super().__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")
        if JK not in ("last", "sum"):
            raise ValueError(f"Unsupported JK mode: {JK!r} (expected 'last' or 'sum').")
        if gnn_type != 'gin':
            # Previously any value silently produced a GIN model.
            raise ValueError(f"Unsupported gnn_type: {gnn_type!r}; only 'gin' is implemented.")

        self.atom_encoder = AtomEncoder(emb_dim)

        ###List of GNNs
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for _ in range(num_layer):
            # .jittable() yields the TorchScript-compatible variant of the PyG layer.
            self.convs.append(GINConv(emb_dim).jittable())
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
        """Return node embeddings for the graph(s) described by the inputs."""
        ### computing input node embedding
        h_list = [self.atom_encoder(x)]

        for layer, (conv, norm) in enumerate(zip(self.convs, self.batch_norms)):

            # Presumably a TorchScript type refinement of the list element.
            assert isinstance(h_list[layer], torch.Tensor)
            h = conv(h_list[layer], edge_index, edge_attr)
            h = norm(h)

            if layer == self.num_layer - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)

            if self.residual:
                # Out-of-place: when dropout is a no-op, `h` may be the ReLU
                # output itself, and modifying that in place breaks autograd.
                h = h + h_list[layer]

            h_list.append(h)

        ### Different implementations of Jk-concat
        if self.JK == "sum":
            node_representation = h_list[0]
            for h_layer in h_list[1:]:
                node_representation = node_representation + h_layer
        else:  # "last" (validated in __init__)
            node_representation = h_list[-1]

        return node_representation

mol_encoder.py

import torch
from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims 

full_atom_feature_dims = get_atom_feature_dims()
full_bond_feature_dims = get_bond_feature_dims()

class AtomEncoder(torch.nn.Module):
    """Embed categorical atom features by summing one embedding per column.

    Column ``i`` of the input indexes ``atom_embedding_list[i]``, whose
    vocabulary size is the ``i``-th entry of ``get_atom_feature_dims()``.
    """

    def __init__(self, emb_dim):
        super().__init__()

        self.atom_embedding_list = torch.nn.ModuleList()

        for dim in full_atom_feature_dims:
            emb = torch.nn.Embedding(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.atom_embedding_list.append(emb)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the sum of the 9 per-column embeddings of integer features ``x``.

        Assumes ``x`` has at least 9 columns of valid category indices.
        """
        # Start from the first embedding (a Tensor) instead of the int 0, so the
        # accumulator keeps one static type, as TorchScript requires.
        x_embedding = self.atom_embedding_list[0](x[:, 0])
        x_embedding = x_embedding + self.atom_embedding_list[1](x[:, 1])
        x_embedding = x_embedding + self.atom_embedding_list[2](x[:, 2])
        x_embedding = x_embedding + self.atom_embedding_list[3](x[:, 3])
        x_embedding = x_embedding + self.atom_embedding_list[4](x[:, 4])
        x_embedding = x_embedding + self.atom_embedding_list[5](x[:, 5])
        x_embedding = x_embedding + self.atom_embedding_list[6](x[:, 6])
        x_embedding = x_embedding + self.atom_embedding_list[7](x[:, 7])
        x_embedding = x_embedding + self.atom_embedding_list[8](x[:, 8])

        return x_embedding

class BondEncoder(torch.nn.Module):
    """Embed categorical bond features by summing one embedding per column.

    Column ``i`` of the input indexes ``bond_embedding_list[i]``, whose
    vocabulary size is the ``i``-th entry of ``get_bond_feature_dims()``.
    """

    def __init__(self, emb_dim):
        super().__init__()

        self.bond_embedding_list = torch.nn.ModuleList()

        for dim in full_bond_feature_dims:
            emb = torch.nn.Embedding(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.bond_embedding_list.append(emb)

    def forward(self, edge_attr: torch.Tensor) -> torch.Tensor:
        """Return the sum of the 3 per-column embeddings of ``edge_attr``.

        Assumes ``edge_attr`` has at least 3 columns of valid category indices.
        """
        # Tensor-typed from the start (not int 0) so TorchScript sees one type
        # and callers get an annotated Tensor return.
        bond_embedding = self.bond_embedding_list[0](edge_attr[:, 0])
        bond_embedding = bond_embedding + self.bond_embedding_list[1](edge_attr[:, 1])
        bond_embedding = bond_embedding + self.bond_embedding_list[2](edge_attr[:, 2])

        return bond_embedding

if __name__ == '__main__':
    # Smoke test: encode the atoms and bonds of the first ogbg-moltox21 molecule.
    # (The previous `loader.GraphClassificationPygDataset` import referred to a
    # module that is not part of the package.)
    from ogb.graphproppred import PygGraphPropPredDataset
    dataset = PygGraphPropPredDataset(name='ogbg-moltox21')
    atom_enc = AtomEncoder(100)
    bond_enc = BondEncoder(100)

    print(atom_enc(dataset[0].x))
    print(bond_enc(dataset[0].edge_attr))