dmsgnn / master-thesis

MLIR-based FPGA toolchain for Graph Neural Network acceleration using High-Level Synthesis. Developed for a Master of Science research thesis.

Export/Load Model in TorchScript Format #2

Closed · dmsgnn closed this 1 year ago

dmsgnn commented 1 year ago

description

Since Torch-MLIR's most tested path down to the Torch MLIR dialect is the TorchScript one, the aim of this issue is to change the model export and loading from state_dict to TorchScript.
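
In practice, the change is from persisting only the weights to persisting a whole TorchScript program; a minimal sketch of the two flows (the model class and file names are hypothetical):

    import torch

    model = GNN()  # hypothetical eager PyTorch model

    # before: save and restore only the weights
    torch.save(model.state_dict(), "gnn_weights.pt")
    model.load_state_dict(torch.load("gnn_weights.pt"))

    # after: save and restore the whole program as TorchScript
    scripted = torch.jit.script(model)
    scripted.save("gnn_ts.pt")
    reloaded = torch.jit.load("gnn_ts.pt")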

dmsgnn commented 1 year ago

problems

Creating the TorchScript from the current model implementation is not possible due to some errors. In particular:

  1. RuntimeError: Module 'GCNConv' has no attribute 'inspector' (This attribute exists on the Python module, but we failed to convert Python type: 'torch_geometric.nn.conv.utils.inspector.Inspector' to a TorchScript type. Only tensors and (possibly nested) tuples of tensors, lists, or dicts are supported as inputs or outputs of traced functions, but instead got value of type Inspector. Its type was inferred; try adding a type annotation for the attribute.):

    File "/Users/dvlpr/miniconda3/envs/thesisPyg/lib/python3.7/site-packages/torch_geometric/nn/conv/message_passing.py", line 534
    
        self._explain = explain
        self.inspector.inspect(self.explain_message, pop_first=True)
        ~~~~~~~~~~~~~~ <--- HERE
        self.__user_args__ = self.inspector.keys(methods).difference(
            self.special_args)
  2. RuntimeError: Expected integer literal for index. ModuleList/Sequential indexing is only supported with integer literals. Enumeration is supported, e.g. 'for index, v in enumerate(self): ...':

    File "/Users/dvlpr/miniconda3/envs/thesisPyg/lib/python3.7/site-packages/ogb/graphproppred/mol_encoder.py", line 22
        x_embedding = 0
        for i in range(x.shape[1]):
            x_embedding += self.atom_embedding_list[i](x[:,i])
                           ~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    
        return x_embedding
dmsgnn commented 1 year ago

Error 2 can be solved by using a JIT interface. The OGB file mol_encoder.py must be changed from

import torch
from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims

full_atom_feature_dims = get_atom_feature_dims()
full_bond_feature_dims = get_bond_feature_dims()

class AtomEncoder(torch.nn.Module):

    def __init__(self, emb_dim):
        super(AtomEncoder, self).__init__()

        self.atom_embedding_list = torch.nn.ModuleList()

        for i, dim in enumerate(full_atom_feature_dims):
            emb = torch.nn.Embedding(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.atom_embedding_list.append(emb)

    def forward(self, x):
        x_embedding = 0
        for i in range(x.shape[1]):
            x_embedding += self.atom_embedding_list[i](x[:,i])

        return x_embedding

class BondEncoder(torch.nn.Module):

    def __init__(self, emb_dim):
        super(BondEncoder, self).__init__()

        self.bond_embedding_list = torch.nn.ModuleList()

        for i, dim in enumerate(full_bond_feature_dims):
            emb = torch.nn.Embedding(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.bond_embedding_list.append(emb)

    def forward(self, edge_attr):
        bond_embedding = 0
        for i in range(edge_attr.shape[1]):
            bond_embedding += self.bond_embedding_list[i](edge_attr[:,i])

        return bond_embedding

if __name__ == '__main__':
    from loader import GraphClassificationPygDataset
    dataset = GraphClassificationPygDataset(name = 'tox21')
    atom_enc = AtomEncoder(100)
    bond_enc = BondEncoder(100)

    print(atom_enc(dataset[0].x))
    print(bond_enc(dataset[0].edge_attr))

to

import torch
from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims

full_atom_feature_dims = get_atom_feature_dims()
full_bond_feature_dims = get_bond_feature_dims()

@torch.jit.interface
class ModuleInterface(torch.nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        pass

class AtomEncoder(torch.nn.Module):

    def __init__(self, emb_dim):
        super(AtomEncoder, self).__init__()

        self.atom_embedding_list = torch.nn.ModuleList()

        for i, dim in enumerate(full_atom_feature_dims):
            emb = torch.nn.Embedding(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.atom_embedding_list.append(emb)

    def forward(self, x):
        x_embedding = 0
        for i in range(x.shape[1]):
            submodule: ModuleInterface = self.atom_embedding_list[i]
            x_embedding += submodule.forward(x[:,i])

        return x_embedding

class BondEncoder(torch.nn.Module):

    def __init__(self, emb_dim):
        super(BondEncoder, self).__init__()

        self.bond_embedding_list = torch.nn.ModuleList()

        for i, dim in enumerate(full_bond_feature_dims):
            emb = torch.nn.Embedding(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.bond_embedding_list.append(emb)

    def forward(self, edge_attr):
        bond_embedding = 0
        for i in range(edge_attr.shape[1]):
            submodule: ModuleInterface = self.bond_embedding_list[i]
            bond_embedding += submodule.forward(edge_attr[:,i])

        return bond_embedding

if __name__ == '__main__':
    from loader import GraphClassificationPygDataset
    dataset = GraphClassificationPygDataset(name = 'tox21')
    atom_enc = AtomEncoder(100)
    bond_enc = BondEncoder(100)

    print(atom_enc(dataset[0].x))
    print(bond_enc(dataset[0].edge_attr))
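
If the interface approach works, the encoders should now compile with torch.jit.script; a minimal sanity check (a sketch, not part of the OGB file):

    import torch

    # scripting succeeds only once the ModuleList indexing issue is resolved
    scripted_atom_enc = torch.jit.script(AtomEncoder(100))
    scripted_bond_enc = torch.jit.script(BondEncoder(100))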
dmsgnn commented 1 year ago

alternative solution

Since exporting the model as TorchScript did not seem possible, a new script has been created (for now, only for the GIN model) to train the model and evaluate it on the training set. At the end, this script saves the model using torch.jit.trace and performs a little confirmation that the inference results of the original model and the traced one match. The main difference in the new script is the forward function's parameters: it no longer receives the batched data, which was not traceable, but instead receives all the required inputs already split.

the forward function before

    def forward(self, batched_data):

after

    def forward(self, x, edge_index, edge_attr, batch):

This change has affected both the GNN and GNN_node classes.
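
For reference, the trace-and-verify flow described above looks roughly like this (a sketch; the variable names are assumptions):

    import torch

    model.eval()
    example_inputs = (x, edge_index, edge_attr, batch)  # inputs already split

    # trace the model on concrete example inputs
    traced = torch.jit.trace(model, example_inputs)

    # little confirmation: the traced model should match the original one
    with torch.no_grad():
        assert torch.allclose(model(*example_inputs), traced(*example_inputs))

    traced.save("gin_traced.pt")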

dmsgnn commented 1 year ago

Error 1 can be solved by annotating the Inspector class with @torch.jit.script, like this:

import torch

@torch.jit.script
class Inspector(object):
    ...
dmsgnn commented 1 year ago

new error

Once errors 1 and 2 were solved, a new error appeared.

RuntimeError: 'Tensor (inferred)' object has no attribute or method 'name'.:

File "/Users/dvlpr/miniconda3/envs/pytorch/lib/python3.10/site-packages/torch_geometric/nn/conv/utils/inspector.py", line 30
    def _implements(self, cls, func_name: str) -> bool:
        if cls.__name__ == 'MessagePassing':
           ~~~~~~~~~~~~ <--- HERE
            return False
        if func_name in cls.__dict__.keys():

common factor

Except for error 2, the other errors are caused by the MessagePassing class, which is usually not used in classic (non-graph) neural networks.

dmsgnn commented 1 year ago

update

Looking at the example provided by Torch-MLIR (https://raw.githubusercontent.com/llvm/torch-mlir/main/examples/torchscript_resnet18.py), it seems that it is not explicitly necessary to export the model using TorchScript.

This is probably done implicitly by the framework. If that is the case, the above problems could still be present. To be tested once Torch-MLIR is installed.
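
For reference, the linked example never calls torch.jit.script explicitly; it hands the eager module to Torch-MLIR, roughly like this (a sketch based on the example; treat the exact API as an assumption):

    import torch
    import torch_mlir

    # torch_mlir.compile scripts the module internally before lowering it
    module = torch_mlir.compile(model, example_inputs,
                                output_type="linalg-on-tensors")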

dmsgnn commented 1 year ago

Issue closed (for the moment) since the TorchScript has been successfully exported. The changes made to the model are the following:

gin.py

  1. The GINConv component, inside the GNN_node class, has been declared as jittable (see the sketch after the links below).

before: https://github.com/dmsgnn/master-thesis/blob/53f4541e9870f17c1010eee2a1dcc137712cfdff/conv.py#L103
after: https://github.com/dmsgnn/master-thesis/blob/f0cd703f6b2ce43dd09deef75c63ae96c5ae9e68/gnn/gin.py#L138
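
In PyG this amounts to calling .jittable() on the MessagePassing module, roughly (a sketch; the mlp argument is hypothetical):

    # returns a TorchScript-compatible variant of the convolution
    conv = GINConv(mlp).jittable()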

  2. The parameters of the forward function had to be changed. In order to create the TorchScript, this function can't take only the batched data as input, because the dictionary is not a scriptable parameter. So I decomposed the input into the different elements contained in the batched data (see the sketch after the links below). In addition, all the forward parameters had to be explicitly annotated with their types.

before:
https://github.com/dmsgnn/master-thesis/blob/53f4541e9870f17c1010eee2a1dcc137712cfdff/gnn.py#L63
https://github.com/dmsgnn/master-thesis/blob/53f4541e9870f17c1010eee2a1dcc137712cfdff/conv.py#L27
https://github.com/dmsgnn/master-thesis/blob/53f4541e9870f17c1010eee2a1dcc137712cfdff/conv.py#L111
after:
https://github.com/dmsgnn/master-thesis/blob/f0cd703f6b2ce43dd09deef75c63ae96c5ae9e68/gnn/gin.py#L96
https://github.com/dmsgnn/master-thesis/blob/f0cd703f6b2ce43dd09deef75c63ae96c5ae9e68/gnn/gin.py#L41
https://github.com/dmsgnn/master-thesis/blob/f0cd703f6b2ce43dd09deef75c63ae96c5ae9e68/gnn/gin.py#L142
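
The decomposed, fully annotated signature looks roughly like this (a sketch; the exact types are at the links above):

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor,
                edge_attr: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        ...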

  3. The types passed to the propagate function must be explicitly annotated, following the required convention (dictionary or comment). So I added a new line (the dictionary form) in the GINConv class (see the sketch after the link below).

after: https://github.com/dmsgnn/master-thesis/blob/f0cd703f6b2ce43dd09deef75c63ae96c5ae9e68/gnn/gin.py#L26
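
The dictionary convention looks roughly like this (a sketch; the exact types here are assumptions):

    from torch import Tensor

    class GINConv(MessagePassing):
        # tells .jittable() the types of the kwargs passed to propagate()
        propagate_type = {'x': Tensor, 'edge_attr': Tensor}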

  4. I added some isinstance assertions for elements whose types were not correctly recognized (see the sketch after the links below).

after:
https://github.com/dmsgnn/master-thesis/blob/f0cd703f6b2ce43dd09deef75c63ae96c5ae9e68/gnn/gin.py#L43
https://github.com/dmsgnn/master-thesis/blob/f0cd703f6b2ce43dd09deef75c63ae96c5ae9e68/gnn/gin.py#L148
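
The pattern is a plain isinstance assertion, which lets TorchScript narrow a type it could not infer (a sketch):

    assert isinstance(edge_attr, torch.Tensor)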

  5. Changed the for loop over the layers from integer-literal indexing to enumeration (see the sketch after the links below).

before: https://github.com/dmsgnn/master-thesis/blob/53f4541e9870f17c1010eee2a1dcc137712cfdff/conv.py#L117
after: https://github.com/dmsgnn/master-thesis/blob/f0cd703f6b2ce43dd09deef75c63ae96c5ae9e68/gnn/gin.py#L147
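
For illustration, the shape of the change (a sketch; self.convs and the call signature are assumptions):

    # before: integer-literal indexing into a ModuleList is not scriptable
    for layer in range(self.num_layer):
        h = self.convs[layer](h, edge_index, edge_attr)

    # after: enumeration, which TorchScript supports
    for layer, conv in enumerate(self.convs):
        h = conv(h, edge_index, edge_attr)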

mol_encoder.py

  1. Changed the for loops in the mol_encoder file. The JIT-interface solution was not working correctly, so the for loops have been hardcoded by manually unrolling them (a temporary solution, but the only one working at the moment).

before

        for i in range(x.shape[1]):
            x_embedding += self.atom_embedding_list[i](x[:,i])
        for i in range(edge_attr.shape[1]):
            bond_embedding += self.bond_embedding_list[i](edge_attr[:,i])

after

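            # AtomEncoder.forward: unrolled over the 9 atom feature columns
            # (the length of full_atom_feature_dims)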
            x_embedding += self.atom_embedding_list[0](x[:,0])
            x_embedding += self.atom_embedding_list[1](x[:,1])
            x_embedding += self.atom_embedding_list[2](x[:,2])
            x_embedding += self.atom_embedding_list[3](x[:,3])
            x_embedding += self.atom_embedding_list[4](x[:,4])
            x_embedding += self.atom_embedding_list[5](x[:,5])
            x_embedding += self.atom_embedding_list[6](x[:,6])
            x_embedding += self.atom_embedding_list[7](x[:,7])
            x_embedding += self.atom_embedding_list[8](x[:,8])
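
            # BondEncoder.forward: unrolled over the 3 bond feature columns
            # (the length of full_bond_feature_dims)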
            bond_embedding += self.bond_embedding_list[0](edge_attr[:,0])
            bond_embedding += self.bond_embedding_list[1](edge_attr[:,1])
            bond_embedding += self.bond_embedding_list[2](edge_attr[:,2])