Hi, I am opening this issue to ask a question. I am trying to use one of your examples on graph classification (https://github.com/snap-stanford/ogb/tree/master/examples/graphproppred/mol).
What I have recently done is modify the example so that a TorchScript version of the model can be created successfully. This is because I am trying to lower the model to torch-mlir. When I try to do so, I encounter an error which, according to the torch-mlir devs, means that my model contains a tuple somewhere, like x=(0,0). They suggested replacing this tuple with a list, like x=[0,0].
Unfortunately, I am new to this and I have not been able to spot the problem. I leave the torch-mlir error here for completeness.
Can you please help me spot this tuple, so that your models become compatible with torch-mlir? I leave here the files of your model that I am using. Some changes have been made, only to keep the code compatible with TorchScript (and, for simplicity, only the code for the GIN model has been preserved).
Thank you in advance for your help.
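Here is how I tried to look for the tuple myself (a minimal sketch; gin is the model built in main.py below). As far as I understand, Python tuples appear in the TorchScript IR as prim::TupleConstruct nodes, so one way to locate them is to grep the graph dump:

import torch

scripted = torch.jit.script(gin)
# Print every IR line that constructs a tuple.
for line in str(scripted.inlined_graph).splitlines():
    if "TupleConstruct" in line:
        print(line)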
main.py
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import sys
from PIL import Image
import requests
import torch
from torchvision import transforms
from tqdm import tqdm
from torch_geometric.loader import DataLoader
from gnn import GNN
import torch.optim as optim
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
import torch_mlir
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
def train(model, device, loader, optimizer, task_type):
    model.train()
    for step, batch in enumerate(tqdm(loader, desc="Iteration")):
        batch = batch.to(device)
        if batch.x.shape[0] == 1 or batch.batch[-1] == 0:
            pass
        else:
            # Unpack the batch to match the modified GNN.forward signature.
            pred = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
            optimizer.zero_grad()
            ## ignore nan targets (unlabeled) when computing training loss.
            is_labeled = batch.y == batch.y
            if "classification" in task_type:
                loss = cls_criterion(pred.to(torch.float32)[is_labeled], batch.y.to(torch.float32)[is_labeled])
            else:
                loss = reg_criterion(pred.to(torch.float32)[is_labeled], batch.y.to(torch.float32)[is_labeled])
            loss.backward()
            optimizer.step()
def eval(model, device, loader, evaluator):
    model.eval()
    y_true = []
    y_pred = []
    for step, batch in enumerate(tqdm(loader, desc="Iteration")):
        batch = batch.to(device)
        if batch.x.shape[0] == 1:
            pass
        else:
            x, edge_index, edge_attr, batch_f = batch.x, batch.edge_index, batch.edge_attr, batch.batch
            with torch.no_grad():
                pred = model(x, edge_index, edge_attr, batch_f)
            y_true.append(batch.y.view(pred.shape).detach().cpu())
            y_pred.append(pred.detach().cpu())
    y_true = torch.cat(y_true, dim=0).numpy()
    y_pred = torch.cat(y_pred, dim=0).numpy()
    input_dict = {"y_true": y_true, "y_pred": y_pred}
    return evaluator.eval(input_dict)
def predictions(torch_model, jit_model):
    pytorch_prediction = eval(torch_model, device, test_loader, evaluator)
    print("PyTorch prediction")
    print(pytorch_prediction)
    mlir_prediction = eval(jit_model, device, test_loader, evaluator)
    print("torch-mlir prediction")
    print(mlir_prediction)
cls_criterion = torch.nn.BCEWithLogitsLoss()
reg_criterion = torch.nn.MSELoss()
### automatic data loading and splitting
dataset = PygGraphPropPredDataset(name='ogbg-molhiv')
split_idx = dataset.get_idx_split()
### automatic evaluator. takes dataset name as input
evaluator = Evaluator("ogbg-molhiv")
train_loader = DataLoader(dataset[split_idx["train"]], batch_size=1, shuffle=True,
                          num_workers=0)
valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=1, shuffle=False,
                          num_workers=0)
test_loader = DataLoader(dataset[split_idx["test"]], batch_size=1, shuffle=False,
                         num_workers=0)
gin = GNN(gnn_type='gin', num_tasks=dataset.num_tasks, num_layer=5, emb_dim=300,
          drop_ratio=0.5).to("cpu")
optimizer = optim.Adam(gin.parameters(), lr=0.001)
device = torch.device("cpu")
train(gin, device, train_loader, optimizer, dataset.task_type)
eval(gin, device, valid_loader, evaluator)
gin.eval()
for step, batch in enumerate(tqdm(test_loader, desc="Iteration")):
    batch = batch.to(device)
    x, edge_index, edge_attr, batch_f = batch.x, batch.edge_index, batch.edge_attr, batch.batch
    module = torch_mlir.compile(gin, (x, edge_index, edge_attr, batch_f), output_type="linalg-on-tensors")
    break
backend = refbackend.RefBackendLinalgOnTensorsBackend()
compiled = backend.compile(module)
jit_module = backend.load(compiled)
predictions(gin, jit_module)
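One note on the last call above: if I understand the refbackend correctly, the module returned by backend.load() is invoked as jit_module.forward(...) with NumPy arrays rather than called like a torch.nn.Module, so passing jit_module straight into eval() may need a small adapter. This is only a sketch under that assumption (RefBackendWrapper is my own hypothetical helper):

import torch

class RefBackendWrapper(torch.nn.Module):
    # Hypothetical adapter: exposes the refbackend invoker with the same
    # calling convention as the PyTorch model used in eval() above.
    def __init__(self, invoker):
        super().__init__()
        self.invoker = invoker

    def forward(self, x, edge_index, edge_attr, batch):
        out = self.invoker.forward(x.numpy(), edge_index.numpy(),
                                   edge_attr.numpy(), batch.numpy())
        return torch.from_numpy(out)

# Usage: predictions(gin, RefBackendWrapper(jit_module))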
gnn.py
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
import torch.nn.functional as F
from torch_geometric.nn.inits import uniform
from conv import GNN_node
from torch_scatter import scatter_mean
import time
class GNN(torch.nn.Module):
    def __init__(self, num_tasks=10, num_layer=5, emb_dim=300,
                 gnn_type='gin', residual=False, drop_ratio=0.5, JK="last", graph_pooling="mean"):
        """
        num_tasks (int): number of labels to be predicted
        """
        super(GNN, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.graph_pooling = graph_pooling

        ### GNN to generate node embeddings
        self.gnn_node = GNN_node(num_layer, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual,
                                 gnn_type=gnn_type)

        ### Pooling function to generate whole-graph embeddings
        if self.graph_pooling == "sum":
            self.pool = global_add_pool
        elif self.graph_pooling == "mean":
            self.pool = global_mean_pool
        elif self.graph_pooling == "max":
            self.pool = global_max_pool
        elif self.graph_pooling == "attention":
            self.pool = GlobalAttention(
                gate_nn=torch.nn.Sequential(torch.nn.Linear(emb_dim, 2 * emb_dim), torch.nn.BatchNorm1d(2 * emb_dim),
                                            torch.nn.ReLU(), torch.nn.Linear(2 * emb_dim, 1)))
        elif self.graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps=2)
        else:
            raise ValueError("Invalid graph pooling type.")

        if graph_pooling == "set2set":
            self.graph_pred_linear = torch.nn.Linear(2 * self.emb_dim, self.num_tasks)
        else:
            self.graph_pred_linear = torch.nn.Linear(self.emb_dim, self.num_tasks)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor,
                batch: torch.Tensor) -> torch.Tensor:
        h_node = self.gnn_node(x, edge_index, edge_attr)
        h_graph = self.pool(h_node, batch)
        return self.graph_pred_linear(h_graph)


if __name__ == '__main__':
    GNN(num_tasks=10)
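In case it helps with reviewing the forward pass: self.pool(h_node, batch) reduces node embeddings to one row per graph using PyG's batch assignment vector. A minimal self-contained example with global_mean_pool, the default pooling here:

import torch
from torch_geometric.nn import global_mean_pool

# Five nodes across two graphs: nodes 0-2 belong to graph 0, nodes 3-4 to graph 1.
h_node = torch.arange(10, dtype=torch.float).view(5, 2)
batch = torch.tensor([0, 0, 0, 1, 1])
h_graph = global_mean_pool(h_node, batch)  # shape [2, 2], one row per graph
print(h_graph)  # tensor([[2., 3.], [7., 8.]])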
conv.py
import torch
from torch_geometric.nn import MessagePassing
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool, global_add_pool
from ogb.graphproppred.mol_encoder import AtomEncoder, BondEncoder
from torch_geometric.utils import degree
import math
import time
### GIN convolution along the graph structure
class GINConv(MessagePassing):
    propagate_type = {'x': torch.Tensor, 'edge_attr': torch.Tensor}

    def __init__(self, emb_dim):
        '''
        emb_dim (int): node embedding dimensionality
        '''
        super(GINConv, self).__init__(aggr="add")
        self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2 * emb_dim), torch.nn.BatchNorm1d(2 * emb_dim),
                                       torch.nn.ReLU(), torch.nn.Linear(2 * emb_dim, emb_dim))
        self.eps = torch.nn.Parameter(torch.Tensor([0]))
        self.bond_encoder = BondEncoder(emb_dim=emb_dim)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
        edge_embedding = self.bond_encoder(edge_attr)
        assert isinstance(edge_embedding, torch.Tensor)
        return self.mlp((1 + self.eps) * x + self.propagate(edge_index, x=x, edge_attr=edge_embedding, size=None))

    def message(self, x_j, edge_attr):
        return F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out
### GNN to generate node embedding
class GNN_node(torch.nn.Module):
    """
    Output:
        node representations
    """
    def __init__(self, num_layer, emb_dim, drop_ratio=0.5, JK="last", residual=False, gnn_type='gin'):
        '''
        emb_dim (int): node embedding dimensionality
        num_layer (int): number of GNN message passing layers
        '''
        super(GNN_node, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.atom_encoder = AtomEncoder(emb_dim)

        ### List of GNNs
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        for _ in range(num_layer):
            self.convs.append(GINConv(emb_dim).jittable())
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
        ### computing input node embedding
        h_list = [self.atom_encoder(x)]
        for layer, (conv, norm) in enumerate(zip(self.convs, self.batch_norms)):
            assert isinstance(h_list[layer], torch.Tensor)
            h = conv(h_list[layer], edge_index, edge_attr)
            h = norm(h)
            if layer == self.num_layer - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
            if self.residual:
                h += h_list[layer]
            h_list.append(h)

        ### Different implementations of JK-concat
        # if self.JK == "last":
        node_representation = h_list[-1]
        # elif self.JK == "sum":
        #     node_representation = 0
        #     for layer in range(self.num_layer + 1):
        #         node_representation += h_list[layer]
        return node_representation


if __name__ == "__main__":
    pass
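For context on the TorchScript-related changes in this file: the propagate_type class attribute plus the .jittable() call follow PyG's documented recipe for scripting MessagePassing modules. A stripped-down sketch of that pattern (the class and its names are illustrative only):

import torch
from torch import Tensor
from torch_geometric.nn import MessagePassing

class PlainConv(MessagePassing):
    # Declares the argument types that propagate() receives, so the
    # jittable transform can generate typed TorchScript code.
    propagate_type = {'x': Tensor}

    def __init__(self):
        super().__init__(aggr="add")

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        return self.propagate(edge_index, x=x, size=None)

    def message(self, x_j: Tensor) -> Tensor:
        return x_j

scripted = torch.jit.script(PlainConv().jittable())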
mol_encoder.py
import torch
from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims
full_atom_feature_dims = get_atom_feature_dims()
full_bond_feature_dims = get_bond_feature_dims()
class AtomEncoder(torch.nn.Module):
    def __init__(self, emb_dim):
        super(AtomEncoder, self).__init__()
        self.atom_embedding_list = torch.nn.ModuleList()
        for i, dim in enumerate(full_atom_feature_dims):
            emb = torch.nn.Embedding(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.atom_embedding_list.append(emb)

    def forward(self, x):
        x_embedding = 0
        x_embedding += self.atom_embedding_list[0](x[:, 0])
        x_embedding += self.atom_embedding_list[1](x[:, 1])
        x_embedding += self.atom_embedding_list[2](x[:, 2])
        x_embedding += self.atom_embedding_list[3](x[:, 3])
        x_embedding += self.atom_embedding_list[4](x[:, 4])
        x_embedding += self.atom_embedding_list[5](x[:, 5])
        x_embedding += self.atom_embedding_list[6](x[:, 6])
        x_embedding += self.atom_embedding_list[7](x[:, 7])
        x_embedding += self.atom_embedding_list[8](x[:, 8])
        return x_embedding


class BondEncoder(torch.nn.Module):
    def __init__(self, emb_dim):
        super(BondEncoder, self).__init__()
        self.bond_embedding_list = torch.nn.ModuleList()
        for i, dim in enumerate(full_bond_feature_dims):
            emb = torch.nn.Embedding(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.bond_embedding_list.append(emb)

    def forward(self, edge_attr):
        bond_embedding = 0
        bond_embedding += self.bond_embedding_list[0](edge_attr[:, 0])
        bond_embedding += self.bond_embedding_list[1](edge_attr[:, 1])
        bond_embedding += self.bond_embedding_list[2](edge_attr[:, 2])
        return bond_embedding


if __name__ == '__main__':
    from loader import GraphClassificationPygDataset
    dataset = GraphClassificationPygDataset(name='tox21')
    atom_enc = AtomEncoder(100)
    bond_enc = BondEncoder(100)
    print(atom_enc(dataset[0].x))
    print(bond_enc(dataset[0].edge_attr))
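As far as I understand, the unrolled column indexing in the two forward() methods above is the TorchScript-friendly rewrite of a loop over feature columns (indexing a ModuleList with a loop variable does not script, while constant indices do). The loop form it replaces would look roughly like this (a sketch, not the original code):

import torch

def sum_column_embeddings(x: torch.Tensor,
                          embeddings: torch.nn.ModuleList) -> torch.Tensor:
    # x holds integer features of shape [num_items, num_columns];
    # embeddings holds one nn.Embedding per column.
    out = embeddings[0](x[:, 0])
    for i, emb in enumerate(embeddings[1:], start=1):
        out = out + emb(x[:, i])
    return out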