pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

TorchScript error in self.propagate #7137

Closed dvlp-r closed 1 year ago

dvlp-r commented 1 year ago

🐛 Describe the bug

Hi, I am trying to save a TorchScript version of a GNN model created with PyTorch Geometric. The model is the one provided by OGB in its GitHub repository (https://github.com/snap-stanford/ogb/tree/master/examples/graphproppred/mol).

After modifying the OGB mol_encoder class (I used a @torch.jit.interface to work around the "Expected integer literal for index" error), marking the re-defined GINConv operator as jittable(), and adding the propagate and forward type annotations, I am now encountering the error below and do not know how to solve it.

RuntimeError: 
Arguments for call are not valid.
The following variants are available:

  propagate__0(__torch__.GINConvJittable_f74bb0.GINConvJittable_f74bb0 self, Tensor edge_index, Tensor x, Tensor edge_attr, (int, int)? size) -> Tensor:
  Expected a value of type 'Tensor' for argument 'edge_attr' but instead found type 'int'.

  propagate__1(__torch__.GINConvJittable_f74bb0.GINConvJittable_f74bb0 self, __torch__.torch_sparse.tensor.SparseTensor edge_index, Tensor x, Tensor edge_attr, (int, int)? size) -> Tensor:
  Expected a value of type '__torch__.torch_sparse.tensor.SparseTensor (of Python compilation unit at: 0x600001f88018)' for argument 'edge_index' but instead found type 'Tensor'.

The original call is:
  File "/var/folders/fv/2kfwzrfd34g51h7_rsfb3pb00000gn/T/dvlpr_pyg/tmpk2s6lstw.py", line 205
    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
        edge_embedding = self.bond_encoder(edge_attr)
        return self.mlp((1 + self.eps) * x + self.propagate(edge_index, x=x, edge_attr=edge_embedding, size=None))
                                             ~~~~~~~~~~~~~~ <--- HERE

Thank you in advance for your help. Even if you do not know about this specific error, I would be really thankful for any working way to save the TorchScript of the OGB model I linked.

Environment

rusty1s commented 1 year ago

What happens if you run this as


edge_embedding = self.bond_encoder(edge_attr)
assert isinstance(edge_embedding, Tensor)
return self.mlp((1 + self.eps) * x + self.propagate(edge_index, x=x, edge_attr=edge_embedding, size=None))
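
The assert should act as a type refinement hint for the TorchScript compiler: it otherwise infers the return type of bond_encoder from its body, and since the OGB encoder accumulates into a variable initialized with 0, the inferred type can end up as int instead of Tensor, which matches the "found type 'int'" part of your error.
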
dvlp-r commented 1 year ago

Hi, thank you for your reply. Your suggestion helped, but I am now facing several other problems. I have created a file to try to export the GIN model:

import torch
from ogb.graphproppred.mol_encoder import BondEncoder, AtomEncoder
from torch_geometric.loader import DataLoader
import torch.optim as optim
from torch_geometric.nn import MessagePassing, global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, \
    Set2Set
import torch.nn.functional as F

from gnn import GNN

import argparse
import numpy as np
from tqdm import tqdm

# importing OGB
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator

cls_criterion = torch.nn.BCEWithLogitsLoss()
reg_criterion = torch.nn.MSELoss()

### GIN convolution along the graph structure
class GINConv(MessagePassing):
    propagate_type = {'x': torch.Tensor, 'edge_attr': torch.Tensor}
    def __init__(self, emb_dim):
        '''
            emb_dim (int): node embedding dimensionality
        '''

        super(GINConv, self).__init__(aggr="add")

        self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2 * emb_dim), torch.nn.BatchNorm1d(2 * emb_dim),
                                       torch.nn.ReLU(), torch.nn.Linear(2 * emb_dim, emb_dim))
        self.eps = torch.nn.Parameter(torch.Tensor([0]))

        self.bond_encoder = BondEncoder(emb_dim=emb_dim)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
        edge_embedding = self.bond_encoder(edge_attr)
        assert isinstance(edge_embedding, torch.Tensor)
        return self.mlp((1 + self.eps) * x + self.propagate(edge_index, x=x, edge_attr=edge_embedding, size=None))

    def message(self, x_j, edge_attr):
        return F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out

class GNN(torch.nn.Module):

    def __init__(self, num_tasks=10, num_layer=5, emb_dim=300,
                 gnn_type='gin', residual=False, drop_ratio=0.5, JK="last", graph_pooling="mean"):
        """
            num_tasks (int): number of labels to be predicted
            virtual_node (bool): whether to add virtual node or not
        """

        super(GNN, self).__init__()

        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.graph_pooling = graph_pooling

        ### GNN to generate node embeddings
        self.gnn_node = GNN_node(num_layer, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual,
                                 gnn_type=gnn_type)

        ### Pooling function to generate whole-graph embeddings
        if self.graph_pooling == "sum":
            self.pool = global_add_pool
        elif self.graph_pooling == "mean":
            self.pool = global_mean_pool
        elif self.graph_pooling == "max":
            self.pool = global_max_pool
        elif self.graph_pooling == "attention":
            self.pool = GlobalAttention(
                gate_nn=torch.nn.Sequential(torch.nn.Linear(emb_dim, 2 * emb_dim), torch.nn.BatchNorm1d(2 * emb_dim),
                                            torch.nn.ReLU(), torch.nn.Linear(2 * emb_dim, 1)))
        elif self.graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps=2)
        else:
            raise ValueError("Invalid graph pooling type.")

        if graph_pooling == "set2set":
            self.graph_pred_linear = torch.nn.Linear(2 * self.emb_dim, self.num_tasks)
        else:
            self.graph_pred_linear = torch.nn.Linear(self.emb_dim, self.num_tasks)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor, batch) -> torch.Tensor:
        """
        forward function used for prediction
        """
        h_node = self.gnn_node(x, edge_index, edge_attr)

        h_graph = self.pool(h_node, batch)

        return self.graph_pred_linear(h_graph)

### GNN to generate node embedding
class GNN_node(torch.nn.Module):
    """
    Output:
        node representations
    """

    def __init__(self, num_layer, emb_dim, drop_ratio=0.5, JK="last", residual=False, gnn_type='gin'):
        '''
            emb_dim (int): node embedding dimensionality
            num_layer (int): number of GNN message passing layers
        '''

        super(GNN_node, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual
        self.layers = [1, 2, 3, 4, 5]

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.atom_encoder = AtomEncoder(emb_dim)

        ###List of GNNs
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        for layer in range(num_layer):
            self.convs.append(GINConv(emb_dim).jittable())

            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
        ### computing input node embedding

        h_list = [self.atom_encoder(x)]
        for layer in range(self.num_layer):

            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)

            if layer == self.num_layer - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)

            if self.residual:
                h += h_list[layer]

            h_list.append(h)

        ### Different implementations of Jk-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            for layer in range(self.num_layer + 1):
                node_representation += h_list[layer]

        return node_representation

def train(model, device, loader, optimizer, task_type):
    """
    function used to train a model
    """
    model.train()

    for step, batch in enumerate(tqdm(loader, desc="Iteration")):
        batch = batch.to(device)
        x, edge_index, edge_attr, batch_f = batch.x, batch.edge_index, batch.edge_attr, batch.batch

        if batch.x.shape[0] == 1 or batch.batch[-1] == 0:
            pass
        else:
            pred = model(x, edge_index, edge_attr, batch_f)
            optimizer.zero_grad()
            ## ignore nan targets (unlabeled) when computing training loss.
            is_labeled = batch.y == batch.y
            if "classification" in task_type:
                loss = cls_criterion(pred.to(torch.float32)[is_labeled], batch.y.to(torch.float32)[is_labeled])
            else:
                loss = reg_criterion(pred.to(torch.float32)[is_labeled], batch.y.to(torch.float32)[is_labeled])
            loss.backward()
            optimizer.step()

def eval(model, device, loader, evaluator):
    """
    function used to evaluate the model and make inference
    """
    model.eval()
    y_true = []
    y_pred = []

    for step, batch in enumerate(tqdm(loader, desc="Iteration")):
        batch = batch.to(device)
        x, edge_index, edge_attr, batch_f = batch.x, batch.edge_index, batch.edge_attr, batch.batch

        if batch.x.shape[0] == 1:
            pass
        else:
            with torch.no_grad():
                pred = model(x, edge_index, edge_attr, batch_f)

            y_true.append(batch.y.view(pred.shape).detach().cpu())
            y_pred.append(pred.detach().cpu())

    y_true = torch.cat(y_true, dim=0).numpy()
    y_pred = torch.cat(y_pred, dim=0).numpy()

    input_dict = {"y_true": y_true, "y_pred": y_pred}

    return evaluator.eval(input_dict)

def main():
    """
    main function which allow the user to train a new model and make inference (graph classification)
    """
    # Training settings
    parser = argparse.ArgumentParser(description='GNN baselines on ogbgmol* data with Pytorch Geometrics')
    parser.add_argument('--drop_ratio', type=float, default=0.5,
                        help='dropout ratio (default: 0.5)')
    parser.add_argument('--num_layer', type=int, default=5,
                        help='number of GNN message passing layers (default: 5)')
    parser.add_argument('--emb_dim', type=int, default=300,
                        help='dimensionality of hidden units in GNNs (default: 300)')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--num_workers', type=int, default=0,
                        help='number of workers (default: 0)')
    args = parser.parse_args()

    device = torch.device("cpu")

    ### automatic data loading and splitting
    dataset = PygGraphPropPredDataset(name="ogbg-molhiv")

    split_idx = dataset.get_idx_split()

    ### automatic evaluator. takes dataset name as input
    evaluator = Evaluator("ogbg-molhiv")

    train_loader = DataLoader(dataset[split_idx["train"]], batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers)
    valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=args.batch_size, shuffle=False,
                              num_workers=args.num_workers)
    test_loader = DataLoader(dataset[split_idx["test"]], batch_size=args.batch_size, shuffle=False,
                             num_workers=args.num_workers)

    model = GNN(gnn_type='gin', num_tasks=dataset.num_tasks, num_layer=args.num_layer, emb_dim=args.emb_dim,
                drop_ratio=args.drop_ratio).to(device)

    optimizer = optim.Adam(model.parameters(), lr=0.001)

    valid_curve = []
    test_curve = []
    train_curve = []

    model_scripted = torch.jit.script(model)     # Export to TorchScript
    model_scripted.save("gin-script.pt")    # Save

    for epoch in range(1, 1 + 1):
        print("=====Epoch {}".format(epoch))
        print('Training...')
        train(model, device, train_loader, optimizer, dataset.task_type)

        print('Evaluating...')
        train_perf = eval(model, device, train_loader, evaluator)
        # valid_perf = eval(model, device, valid_loader, evaluator)
        # test_perf = eval(model, device, test_loader, evaluator)

        # print({'Train': train_perf, 'Validation': valid_perf, 'Test': test_perf})

        # train_curve.append(train_perf[dataset.eval_metric])
        # valid_curve.append(valid_perf[dataset.eval_metric])
        # test_curve.append(test_perf[dataset.eval_metric])

    # if 'classification' in dataset.task_type:
    #    best_val_epoch = np.argmax(np.array(valid_curve))
    #    best_train = max(train_curve)
    # else:
    #    best_val_epoch = np.argmin(np.array(valid_curve))
    #    best_train = min(train_curve)

    print('Finished training!')
    # print('Best validation score: {}'.format(valid_curve[best_val_epoch]))
    # print('Test score: {}'.format(test_curve[best_val_epoch]))

if __name__ == "__main__":
    main()

I am now blocked with the following error.

Expected integer literal for index. ModuleList/Sequential indexing is only supported with integer literals. Enumeration is supported, e.g. 'for index, v in enumerate(self): ...':
  File "/Users/dvlpr/PyCharmProjects/gnn-acceleration-master-thesis/gnn/gin.py", line 146
        for layer in range(self.num_layer):

            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
                ~~~~~~~~~~~~~~~~~ <--- HERE
            h = self.batch_norms[layer](h)

I also faced this problem in the mol_encoder file and was able to solve it there as shown below, but I am not able to make the same approach work here.

import torch
from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims

full_atom_feature_dims = get_atom_feature_dims()
full_bond_feature_dims = get_bond_feature_dims()

@torch.jit.interface
class ModuleInterface(torch.nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        pass

class AtomEncoder(torch.nn.Module):

    def __init__(self, emb_dim):
        super(AtomEncoder, self).__init__()

        self.atom_embedding_list = torch.nn.ModuleList()

        for i, dim in enumerate(full_atom_feature_dims):
            emb = torch.nn.Embedding(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.atom_embedding_list.append(emb)

    def forward(self, x):
        x_embedding = 0
        for i in range(x.shape[1]):
            submodule: ModuleInterface = self.atom_embedding_list[i]
            x_embedding += submodule.forward(x[:,i])

        return x_embedding

class BondEncoder(torch.nn.Module):

    def __init__(self, emb_dim):
        super(BondEncoder, self).__init__()

        self.bond_embedding_list = torch.nn.ModuleList()

        for i, dim in enumerate(full_bond_feature_dims):
            emb = torch.nn.Embedding(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.bond_embedding_list.append(emb)

    def forward(self, edge_attr):
        bond_embedding = 0
        for i in range(edge_attr.shape[1]):
            submodule: ModuleInterface = self.bond_embedding_list[i]
            bond_embedding += submodule.forward(edge_attr[:,i])

        return bond_embedding

if __name__ == '__main__':
    from loader import GraphClassificationPygDataset
    dataset = GraphClassificationPygDataset(name = 'tox21')
    atom_enc = AtomEncoder(100)
    bond_enc = BondEncoder(100)

    print(atom_enc(dataset[0].x))
    print(bond_enc(dataset[0].edge_attr))

Thank you in advance for your help.

rusty1s commented 1 year ago

Yeah, it can get a bit tricky to make TorchScript work. In your case, I believe you need to directly iterate over convs and batch_norms:


for conv, norm in zip(self.convs, self.batch_norms):
     ...
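
Applied to your GNN_node.forward, the loop could look like this sketch (assuming the dropout and residual logic stays as in your listing; indexing the plain Python list h_list with a loop variable is fine, it is only ModuleList indexing that requires literals):

h_list = [self.atom_encoder(x)]
for layer, (conv, norm) in enumerate(zip(self.convs, self.batch_norms)):
    h = norm(conv(h_list[layer], edge_index, edge_attr))
    if layer == self.num_layer - 1:
        # no ReLU for the last layer
        h = F.dropout(h, self.drop_ratio, training=self.training)
    else:
        h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
    if self.residual:
        h = h + h_list[layer]
    h_list.append(h)
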
dvlp-r commented 1 year ago

Hi @rusty1s, thank you for your help. I have finally saved my model in TorchScript. However, I still have two problems and I hope you can help me with them.

When I try to reload my scripted model and run inference, the following error appears:

Traceback (most recent call last):
  File "/Users/dvlpr/PyCharmProjects/.../gnn/gin.py", line 327, in <module>
    main()
  File "/Users/dvlpr/PyCharmProjects/.../gnn/gin.py", line 319, in main
    print(model_load(x, edge_index, edge_attr, batch_f))
  File "/Users/dvlpr/miniconda3/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__.py", line 21, in forward
    _0 = __torch__.torch_geometric.nn.pool.glob.global_mean_pool
    gnn_node = self.gnn_node
    h_node = (gnn_node).forward(x, edge_index, edge_attr, )
              ~~~~~~~~~~~~~~~~~ <--- HERE
    h_graph = _0(h_node, batch, None, )
    graph_pred_linear = self.graph_pred_linear
  File "code/__torch__.py", line 45, in forward
    _2 = uninitialized(Tensor)
    atom_encoder = self.atom_encoder
    _3 = (atom_encoder).forward(x, )
          ~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    ops.prim.RaiseException("AssertionError: ")
    return _2
  File "code/__torch__/ogb/graphproppred/mol_encoder.py", line 15, in forward
      _0 = torch.select(torch.slice(x), 1, i)
      x_embedding0 = torch.add((submodule).forward(_0, ), x_embedding)
      x_embedding = annotate(int, x_embedding0)
                    ~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    return x_embedding
class BondEncoder(Module):

Traceback of TorchScript, original code (most recent call last):
  File "/Users/dvlpr/PyCharmProjects/.../gnn/gin.py", line 99, in forward
        """

        h_node = self.gnn_node(x, edge_index, edge_attr)
                 ~~~~~~~~~~~~~ <--- HERE

        h_graph = self.pool(h_node, batch)
  File "/Users/dvlpr/PyCharmProjects/.../gnn/gin.py", line 144, in forward
        ### computing input node embedding

        h_list = [self.atom_encoder(x)]
                  ~~~~~~~~~~~~~~~~~ <--- HERE

        for layer, (conv, norm) in enumerate(zip(self.convs, self.batch_norms)):
  File "/Users/dvlpr/miniconda3/envs/pytorch/lib/python3.10/site-packages/ogb/graphproppred/mol_encoder.py", line 28, in forward
        for i in range(x.shape[1]):
            submodule: ModuleInterface = self.atom_embedding_list[i]
            x_embedding += submodule.forward(x[:,i])
            ~~~~~~~~~~~ <--- HERE

        return x_embedding
RuntimeError: Cannot input a tensor of dimension other than 0 as a scalar argument

Looking on the web, I have not found any solution for it. Separately, I am trying to lower this model down to MLIR using torch-mlir. However, when I do, an error tells me that somewhere in my model there is a tuple that looks like x = (0,0), but I do not recognize any tuple like this. Am I missing something?

rusty1s commented 1 year ago

I am not entirely sure. Can you try to replace += with add_()?


x_embedding.add_(submodule.forward(x[:,i]))
dvlp-r commented 1 year ago

Unfortunately it does not work since x_embedding is an integer.

  File "/Users/dvlpr/miniconda3/envs/pytorch/lib/python3.10/site-packages/ogb/graphproppred/mol_encoder.py", line 28, in forward
    x_embedding.add_(submodule.forward(x[:,i]))
AttributeError: 'int' object has no attribute 'add_'
rusty1s commented 1 year ago

Why is x_embedding an integer? It should be a tensor IMO.

dvlp-r commented 1 year ago

That code is part of the mol_encoder.py file (used for the MolHIV dataset) provided by OGB. The whole codebase I am using comes from Open Graph Benchmark; the only change I have made is using the jit interface to make it scriptable.

rusty1s commented 1 year ago

I see, I think this should be changed to:


x_embeddings = [self.atom_embedding_list[i](x[:, i]) for i in range(x.shape[1])]
return sum(x_embeddings)
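
If you also want to avoid the non-literal ModuleList indexing in the same pass, iterating over the list directly should work too (a sketch, assuming x has one column per embedding in the list):

x_embeddings = []
for i, emb in enumerate(self.atom_embedding_list):
    x_embeddings.append(emb(x[:, i]))
return torch.stack(x_embeddings, dim=0).sum(dim=0)
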
dvlp-r commented 1 year ago

Hi @rusty1s, thank you for your help. Apparently the problem was the jit interface I was using in the mol_encoder.py file: after some testing I found that deleting the jit interface and hardcoding the for loops solved the problem (a sketch of what I mean is below). I can now use the scripted model to run inference correctly. The last thing I would like to ask is whether you have any suggestion about the error I get when lowering the model down to torch-mlir. I know that is not the purpose of this issue, but I cannot figure out where a tuple like (0,0) could be in my model. Anyway, thank you so much for your help. Feel free to close this issue.
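
For reference, the hardcoded version looks roughly like this (sketched with three feature columns for brevity; the real encoders need one literal index per column of x):

x_embedding = self.atom_embedding_list[0](x[:, 0])
x_embedding = x_embedding + self.atom_embedding_list[1](x[:, 1])
x_embedding = x_embedding + self.atom_embedding_list[2](x[:, 2])
return x_embedding
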

rusty1s commented 1 year ago

I am not familiar with torch-mlir at all :(

dvlp-r commented 1 year ago

Thank you anyway; I will try asking on the OGB repo to see if someone has experience with it. Thank you again!