Closed dmsgnn closed 1 year ago
Creating the TorchScript from the current model implementation is not possible because of two errors:
1. RuntimeError: Module 'GCNConv' has no attribute 'inspector' (This attribute exists on the Python module, but we failed to convert Python type: 'torch_geometric.nn.conv.utils.inspector.Inspector' to a TorchScript type. Only tensors and (possibly nested) tuples of tensors, lists, or dictsare supported as inputs or outputs of traced functions, but instead got value of type Inspector. Its type was inferred; try adding a type annotation for the attribute.):
   File "/Users/dvlpr/miniconda3/envs/thesisPyg/lib/python3.7/site-packages/torch_geometric/nn/conv/message_passing.py", line 534
       self._explain = explain
       self.inspector.inspect(self.explain_message, pop_first=True)
       ~~~~~~~~~~~~~~ <--- HERE
       self.__user_args__ = self.inspector.keys(methods).difference(
           self.special_args)
2. RuntimeError: Expected integer literal for index. ModuleList/Sequential indexing is only supported with integer literals. Enumeration is supported, e.g. 'for index, v in enumerate(self): ...':
   File "/Users/dvlpr/miniconda3/envs/thesisPyg/lib/python3.7/site-packages/ogb/graphproppred/mol_encoder.py", line 22
       x_embedding = 0
       for i in range(x.shape[1]):
           x_embedding += self.atom_embedding_list[i](x[:,i])
                          ~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
       return x_embedding
Error 2 can be solved by using a TorchScript module interface (torch.jit.interface). The OGB file mol_encoder.py must be changed from
import torch
from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims

full_atom_feature_dims = get_atom_feature_dims()
full_bond_feature_dims = get_bond_feature_dims()

class AtomEncoder(torch.nn.Module):
    def __init__(self, emb_dim):
        super(AtomEncoder, self).__init__()
        self.atom_embedding_list = torch.nn.ModuleList()
        for i, dim in enumerate(full_atom_feature_dims):
            emb = torch.nn.Embedding(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.atom_embedding_list.append(emb)

    def forward(self, x):
        x_embedding = 0
        for i in range(x.shape[1]):
            x_embedding += self.atom_embedding_list[i](x[:,i])
        return x_embedding

class BondEncoder(torch.nn.Module):
    def __init__(self, emb_dim):
        super(BondEncoder, self).__init__()
        self.bond_embedding_list = torch.nn.ModuleList()
        for i, dim in enumerate(full_bond_feature_dims):
            emb = torch.nn.Embedding(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.bond_embedding_list.append(emb)

    def forward(self, edge_attr):
        bond_embedding = 0
        for i in range(edge_attr.shape[1]):
            bond_embedding += self.bond_embedding_list[i](edge_attr[:,i])
        return bond_embedding

if __name__ == '__main__':
    from loader import GraphClassificationPygDataset
    dataset = GraphClassificationPygDataset(name = 'tox21')
    atom_enc = AtomEncoder(100)
    bond_enc = BondEncoder(100)
    print(atom_enc(dataset[0].x))
    print(bond_enc(dataset[0].edge_attr))
to
import torch
from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims

full_atom_feature_dims = get_atom_feature_dims()
full_bond_feature_dims = get_bond_feature_dims()

@torch.jit.interface
class ModuleInterface(torch.nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        pass

class AtomEncoder(torch.nn.Module):
    def __init__(self, emb_dim):
        super(AtomEncoder, self).__init__()
        self.atom_embedding_list = torch.nn.ModuleList()
        for i, dim in enumerate(full_atom_feature_dims):
            emb = torch.nn.Embedding(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.atom_embedding_list.append(emb)

    def forward(self, x):
        x_embedding = 0
        for i in range(x.shape[1]):
            submodule: ModuleInterface = self.atom_embedding_list[i]
            x_embedding += submodule.forward(x[:,i])
        return x_embedding

class BondEncoder(torch.nn.Module):
    def __init__(self, emb_dim):
        super(BondEncoder, self).__init__()
        self.bond_embedding_list = torch.nn.ModuleList()
        for i, dim in enumerate(full_bond_feature_dims):
            emb = torch.nn.Embedding(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.bond_embedding_list.append(emb)

    def forward(self, edge_attr):
        bond_embedding = 0
        for i in range(edge_attr.shape[1]):
            submodule: ModuleInterface = self.bond_embedding_list[i]
            bond_embedding += submodule.forward(edge_attr[:,i])
        return bond_embedding

if __name__ == '__main__':
    from loader import GraphClassificationPygDataset
    dataset = GraphClassificationPygDataset(name = 'tox21')
    atom_enc = AtomEncoder(100)
    bond_enc = BondEncoder(100)
    print(atom_enc(dataset[0].x))
    print(bond_enc(dataset[0].edge_attr))
Since exporting the model by scripting it with TorchScript did not seem possible, a new script has been created (for now only for the GIN model) to train the model and evaluate it on the training set. At the end, this script saves the model using tracing (torch.jit.trace) and runs a small check that compares the inference results of the original model and the traced one. The main difference in the new script is the signature of the forward function: it no longer receives the batched data object, which was not traceable, and instead receives all the required tensors already split.
The forward function before:
def forward(self, batched_data):
and after:
def forward(self, x, edge_index, edge_attr, batch):
This change has affected both the GNN and GNN_node classes.
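The trace-and-compare step described above could look roughly like the following. This is only a sketch: ToyGraphModel is a hypothetical stand-in for the GIN model (it is not the thesis code and does not use torch_geometric), and it assumes a single graph per batch so that pooling stays simple. The point is that a forward function taking plain tensors can be passed to torch.jit.trace with a tuple of example inputs.

```python
import torch


class ToyGraphModel(torch.nn.Module):
    """Hypothetical stand-in for the GIN model; not the thesis code."""

    def __init__(self, in_dim: int = 4, edge_dim: int = 3, hidden: int = 8):
        super().__init__()
        self.node_lin = torch.nn.Linear(in_dim, hidden)
        self.edge_lin = torch.nn.Linear(edge_dim, hidden)
        self.out = torch.nn.Linear(hidden, 1)

    def forward(self, x, edge_index, edge_attr, batch):
        h = self.node_lin(x)
        src, dst = edge_index[0], edge_index[1]
        # Message = source node features plus projected edge features.
        msg = torch.relu(h[src] + self.edge_lin(edge_attr))
        # Sum the messages into their destination nodes.
        h = h.index_add(0, dst, msg)
        # Sum-pool the nodes; this toy assumes one graph (batch is all zeros).
        pooled = torch.zeros(1, h.size(1)).index_add(0, batch, h)
        return self.out(pooled)


torch.manual_seed(0)
model = ToyGraphModel().eval()

# Example inputs, already split into the four tensors forward expects.
x = torch.randn(5, 4)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
edge_attr = torch.randn(4, 3)
batch = torch.zeros(5, dtype=torch.long)
example_inputs = (x, edge_index, edge_attr, batch)

# Tracing records the operations executed on the example inputs.
traced = torch.jit.trace(model, example_inputs)

# Compare the original model and the traced one on the same inputs.
with torch.no_grad():
    outputs_match = torch.allclose(model(*example_inputs), traced(*example_inputs))
print("outputs match:", outputs_match)
```

Note that tracing only records the path taken for the example inputs, so data-dependent Python control flow in forward would be frozen into the trace.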
Error 1 can be addressed by annotating the Inspector class with torch.jit.script, like this:
import torch
@torch.jit.script
class Inspector(object):
Once errors 1 and 2 are solved, a new error appears:
RuntimeError: 'Tensor (inferred)' object has no attribute or method 'name'.:
  File "/Users/dvlpr/miniconda3/envs/pytorch/lib/python3.10/site-packages/torch_geometric/nn/conv/utils/inspector.py", line 30
    def _implements(self, cls, func_name: str) -> bool:
        if cls.__name__ == 'MessagePassing':
           ~~~~~~~~~~~~ <--- HERE
            return False
        if func_name in cls.__dict__.keys():
Except for error 2, the errors are caused by the MessagePassing class, which is not normally used in classic neural networks.
Looking at the example provided by torch-mlir (https://raw.githubusercontent.com/llvm/torch-mlir/main/examples/torchscript_resnet18.py), it does not seem to be necessary to export the model explicitly with TorchScript.
The framework probably does it implicitly. If so, the problems above could still be present. To be tested once torch-mlir is installed.
Issue closed (for now) since the TorchScript has been exported successfully. The changes made to the model are the following:
before https://github.com/dmsgnn/master-thesis/blob/53f4541e9870f17c1010eee2a1dcc137712cfdff/conv.py#L103 after https://github.com/dmsgnn/master-thesis/blob/f0cd703f6b2ce43dd09deef75c63ae96c5ae9e68/gnn/gin.py#L138
before https://github.com/dmsgnn/master-thesis/blob/53f4541e9870f17c1010eee2a1dcc137712cfdff/gnn.py#L63 https://github.com/dmsgnn/master-thesis/blob/53f4541e9870f17c1010eee2a1dcc137712cfdff/conv.py#L27 https://github.com/dmsgnn/master-thesis/blob/53f4541e9870f17c1010eee2a1dcc137712cfdff/conv.py#L111 after https://github.com/dmsgnn/master-thesis/blob/f0cd703f6b2ce43dd09deef75c63ae96c5ae9e68/gnn/gin.py#L96 https://github.com/dmsgnn/master-thesis/blob/f0cd703f6b2ce43dd09deef75c63ae96c5ae9e68/gnn/gin.py#L41 https://github.com/dmsgnn/master-thesis/blob/f0cd703f6b2ce43dd09deef75c63ae96c5ae9e68/gnn/gin.py#L142
after https://github.com/dmsgnn/master-thesis/blob/f0cd703f6b2ce43dd09deef75c63ae96c5ae9e68/gnn/gin.py#L26
after https://github.com/dmsgnn/master-thesis/blob/f0cd703f6b2ce43dd09deef75c63ae96c5ae9e68/gnn/gin.py#L43 https://github.com/dmsgnn/master-thesis/blob/f0cd703f6b2ce43dd09deef75c63ae96c5ae9e68/gnn/gin.py#L148
before https://github.com/dmsgnn/master-thesis/blob/53f4541e9870f17c1010eee2a1dcc137712cfdff/conv.py#L117 after https://github.com/dmsgnn/master-thesis/blob/f0cd703f6b2ce43dd09deef75c63ae96c5ae9e68/gnn/gin.py#L147
before
for i in range(x.shape[1]):
    x_embedding += self.atom_embedding_list[i](x[:,i])
for i in range(edge_attr.shape[1]):
    bond_embedding += self.bond_embedding_list[i](edge_attr[:,i])
after
x_embedding += self.atom_embedding_list[0](x[:,0])
x_embedding += self.atom_embedding_list[1](x[:,1])
x_embedding += self.atom_embedding_list[2](x[:,2])
x_embedding += self.atom_embedding_list[3](x[:,3])
x_embedding += self.atom_embedding_list[4](x[:,4])
x_embedding += self.atom_embedding_list[5](x[:,5])
x_embedding += self.atom_embedding_list[6](x[:,6])
x_embedding += self.atom_embedding_list[7](x[:,7])
x_embedding += self.atom_embedding_list[8](x[:,8])
bond_embedding += self.bond_embedding_list[0](edge_attr[:,0])
bond_embedding += self.bond_embedding_list[1](edge_attr[:,1])
bond_embedding += self.bond_embedding_list[2](edge_attr[:,2])
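The error message for error 2 also mentions that enumeration over a ModuleList is supported, which suggests a loop-based alternative to writing one line per integer literal. The sketch below is not the change made in the repository: the class names, the make_embeddings helper, and the feature sizes are made up for illustration, and the accumulator starts from a zero tensor rather than the integer 0 so that its type stays a Tensor under scripting.

```python
import torch


def make_embeddings(dims, emb_dim: int) -> torch.nn.ModuleList:
    # Hypothetical helper: one embedding table per categorical feature column.
    return torch.nn.ModuleList([torch.nn.Embedding(d, emb_dim) for d in dims])


class IndexedSum(torch.nn.Module):
    """Indexes the ModuleList with a loop variable, as in the original encoder."""

    def __init__(self, dims, emb_dim: int):
        super().__init__()
        self.emb_dim = emb_dim
        self.embeddings = make_embeddings(dims, emb_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.zeros([x.size(0), self.emb_dim], device=x.device)
        for i in range(x.size(1)):
            out = out + self.embeddings[i](x[:, i])
        return out


class EnumeratedSum(torch.nn.Module):
    """Iterates with enumerate, which TorchScript can unroll at compile time."""

    def __init__(self, dims, emb_dim: int):
        super().__init__()
        self.emb_dim = emb_dim
        self.embeddings = make_embeddings(dims, emb_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.zeros([x.size(0), self.emb_dim], device=x.device)
        for i, emb in enumerate(self.embeddings):
            out = out + emb(x[:, i])
        return out


dims = [5, 3, 4]  # made-up feature cardinalities

# Scripting the variable-index version is expected to be rejected.
try:
    torch.jit.script(IndexedSum(dims, 8))
    indexed_version_scripts = True
except Exception:
    indexed_version_scripts = False
print("variable-index version scripts:", indexed_version_scripts)

# The enumerate version scripts and matches eager execution.
module = EnumeratedSum(dims, 8).eval()
scripted = torch.jit.script(module)
x = torch.stack([torch.randint(0, d, (6,)) for d in dims], dim=1)
with torch.no_grad():
    enumerate_matches = torch.allclose(module(x), scripted(x))
print("enumerate version matches eager:", enumerate_matches)
```

Unlike the hard-coded lines, this form does not need to be edited if the number of feature columns changes, but it does assume the input has exactly one column per embedding table.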
description
Since Torch-MLIR's most tested path down to the Torch MLIR dialect is the TorchScript one, the aim of this issue is to change how the model is exported and loaded, from state_dict to TorchScript.
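To make the difference between the two export routes concrete, here is a minimal sketch. SmallNet is a hypothetical placeholder model, and in-memory buffers are used instead of files only to keep the example self-contained. With state_dict only the weights are stored, so the Python class is needed again at load time, while a TorchScript archive carries both the code and the weights.

```python
import io

import torch


class SmallNet(torch.nn.Module):
    """Hypothetical placeholder model, used only for illustration."""

    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.lin(x))


torch.manual_seed(0)
model = SmallNet().eval()
x = torch.randn(3, 4)

# state_dict route: save the weights, then rebuild the class to load them.
weights_buffer = io.BytesIO()
torch.save(model.state_dict(), weights_buffer)
weights_buffer.seek(0)
restored = SmallNet()
restored.load_state_dict(torch.load(weights_buffer))
restored.eval()

# TorchScript route: the archive contains the compiled code and the weights,
# so torch.jit.load does not need the SmallNet class definition.
script_buffer = io.BytesIO()
torch.jit.save(torch.jit.script(model), script_buffer)
script_buffer.seek(0)
loaded = torch.jit.load(script_buffer)

with torch.no_grad():
    state_dict_ok = torch.allclose(model(x), restored(x))
    torchscript_ok = torch.allclose(model(x), loaded(x))
print("state_dict round trip ok:", state_dict_ok)
print("TorchScript round trip ok:", torchscript_ok)
```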