What happens if you run this as:
edge_embedding = self.bond_encoder(edge_attr)
assert isinstance(edge_embedding, Tensor)
return self.mlp((1 + self.eps) * x + self.propagate(edge_index, x=x, edge_attr=edge_embedding, size=None))
Hi, thank you for your reply. Your suggestion actually helped me, but I am now facing several problems. I have created a file to try to export the GIN model:
import torch
from ogb.graphproppred.mol_encoder import BondEncoder, AtomEncoder
from torch_geometric.loader import DataLoader
import torch.optim as optim
from torch_geometric.nn import MessagePassing, global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, \
    Set2Set
import torch.nn.functional as F
from gnn import GNN
import argparse
import numpy as np
from tqdm import tqdm

# importing OGB
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator

cls_criterion = torch.nn.BCEWithLogitsLoss()
reg_criterion = torch.nn.MSELoss()
### GIN convolution along the graph structure
class GINConv(MessagePassing):
    propagate_type = {'x': torch.Tensor, 'edge_attr': torch.Tensor}

    def __init__(self, emb_dim):
        '''
        emb_dim (int): node embedding dimensionality
        '''
        super(GINConv, self).__init__(aggr="add")
        self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, 2 * emb_dim), torch.nn.BatchNorm1d(2 * emb_dim),
                                       torch.nn.ReLU(), torch.nn.Linear(2 * emb_dim, emb_dim))
        self.eps = torch.nn.Parameter(torch.Tensor([0]))
        self.bond_encoder = BondEncoder(emb_dim=emb_dim)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
        edge_embedding = self.bond_encoder(edge_attr)
        assert isinstance(edge_embedding, torch.Tensor)
        return self.mlp((1 + self.eps) * x + self.propagate(edge_index, x=x, edge_attr=edge_embedding, size=None))

    def message(self, x_j, edge_attr):
        return F.relu(x_j + edge_attr)

    def update(self, aggr_out):
        return aggr_out
class GNN(torch.nn.Module):
    def __init__(self, num_tasks=10, num_layer=5, emb_dim=300,
                 gnn_type='gin', residual=False, drop_ratio=0.5, JK="last", graph_pooling="mean"):
        """
        num_tasks (int): number of labels to be predicted
        virtual_node (bool): whether to add virtual node or not
        """
        super(GNN, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        self.emb_dim = emb_dim
        self.num_tasks = num_tasks
        self.graph_pooling = graph_pooling

        ### GNN to generate node embeddings
        self.gnn_node = GNN_node(num_layer, emb_dim, JK=JK, drop_ratio=drop_ratio, residual=residual,
                                 gnn_type=gnn_type)

        ### Pooling function to generate whole-graph embeddings
        if self.graph_pooling == "sum":
            self.pool = global_add_pool
        elif self.graph_pooling == "mean":
            self.pool = global_mean_pool
        elif self.graph_pooling == "max":
            self.pool = global_max_pool
        elif self.graph_pooling == "attention":
            self.pool = GlobalAttention(
                gate_nn=torch.nn.Sequential(torch.nn.Linear(emb_dim, 2 * emb_dim), torch.nn.BatchNorm1d(2 * emb_dim),
                                            torch.nn.ReLU(), torch.nn.Linear(2 * emb_dim, 1)))
        elif self.graph_pooling == "set2set":
            self.pool = Set2Set(emb_dim, processing_steps=2)
        else:
            raise ValueError("Invalid graph pooling type.")

        if graph_pooling == "set2set":
            self.graph_pred_linear = torch.nn.Linear(2 * self.emb_dim, self.num_tasks)
        else:
            self.graph_pred_linear = torch.nn.Linear(self.emb_dim, self.num_tasks)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor, batch) -> torch.Tensor:
        """
        forward function used for prediction
        """
        h_node = self.gnn_node(x, edge_index, edge_attr)
        h_graph = self.pool(h_node, batch)
        return self.graph_pred_linear(h_graph)
### GNN to generate node embedding
class GNN_node(torch.nn.Module):
    """
    Output:
        node representations
    """

    def __init__(self, num_layer, emb_dim, drop_ratio=0.5, JK="last", residual=False, gnn_type='gin'):
        '''
        emb_dim (int): node embedding dimensionality
        num_layer (int): number of GNN message passing layers
        '''
        super(GNN_node, self).__init__()
        self.num_layer = num_layer
        self.drop_ratio = drop_ratio
        self.JK = JK
        ### add residual connection or not
        self.residual = residual
        self.layers = [1, 2, 3, 4, 5]

        if self.num_layer < 2:
            raise ValueError("Number of GNN layers must be greater than 1.")

        self.atom_encoder = AtomEncoder(emb_dim)

        ### List of GNNs
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        for layer in range(num_layer):
            self.convs.append(GINConv(emb_dim).jittable())
            self.batch_norms.append(torch.nn.BatchNorm1d(emb_dim))

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: torch.Tensor) -> torch.Tensor:
        ### computing input node embedding
        h_list = [self.atom_encoder(x)]
        for layer in range(self.num_layer):
            h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            h = self.batch_norms[layer](h)
            if layer == self.num_layer - 1:
                # remove relu for the last layer
                h = F.dropout(h, self.drop_ratio, training=self.training)
            else:
                h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
            if self.residual:
                h += h_list[layer]
            h_list.append(h)

        ### Different implementations of JK-concat
        if self.JK == "last":
            node_representation = h_list[-1]
        elif self.JK == "sum":
            node_representation = 0
            for layer in range(self.num_layer + 1):
                node_representation += h_list[layer]

        return node_representation
def train(model, device, loader, optimizer, task_type):
    """
    function used to train a model
    """
    model.train()

    for step, batch in enumerate(tqdm(loader, desc="Iteration")):
        batch = batch.to(device)
        x, edge_index, edge_attr, batch_f = batch.x, batch.edge_index, batch.edge_attr, batch.batch

        if batch.x.shape[0] == 1 or batch.batch[-1] == 0:
            pass
        else:
            pred = model(x, edge_index, edge_attr, batch_f)
            optimizer.zero_grad()
            ## ignore nan targets (unlabeled) when computing training loss.
            is_labeled = batch.y == batch.y
            if "classification" in task_type:
                loss = cls_criterion(pred.to(torch.float32)[is_labeled], batch.y.to(torch.float32)[is_labeled])
            else:
                loss = reg_criterion(pred.to(torch.float32)[is_labeled], batch.y.to(torch.float32)[is_labeled])
            loss.backward()
            optimizer.step()
def eval(model, device, loader, evaluator):
    """
    function used to evaluate the model and make inference
    """
    model.eval()
    y_true = []
    y_pred = []

    for step, batch in enumerate(tqdm(loader, desc="Iteration")):
        batch = batch.to(device)
        x, edge_index, edge_attr, batch_f = batch.x, batch.edge_index, batch.edge_attr, batch.batch

        if batch.x.shape[0] == 1:
            pass
        else:
            with torch.no_grad():
                pred = model(x, edge_index, edge_attr, batch_f)
            y_true.append(batch.y.view(pred.shape).detach().cpu())
            y_pred.append(pred.detach().cpu())

    y_true = torch.cat(y_true, dim=0).numpy()
    y_pred = torch.cat(y_pred, dim=0).numpy()
    input_dict = {"y_true": y_true, "y_pred": y_pred}
    return evaluator.eval(input_dict)
def main():
    """
    main function which allows the user to train a new model and make inference (graph classification)
    """
    # Training settings
    parser = argparse.ArgumentParser(description='GNN baselines on ogbgmol* data with Pytorch Geometrics')
    parser.add_argument('--drop_ratio', type=float, default=0.5,
                        help='dropout ratio (default: 0.5)')
    parser.add_argument('--num_layer', type=int, default=5,
                        help='number of GNN message passing layers (default: 5)')
    parser.add_argument('--emb_dim', type=int, default=300,
                        help='dimensionality of hidden units in GNNs (default: 300)')
    parser.add_argument('--batch_size', type=int, default=32,
                        help='input batch size for training (default: 32)')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--num_workers', type=int, default=0,
                        help='number of workers (default: 0)')
    args = parser.parse_args()

    device = torch.device("cpu")

    ### automatic data loading and splitting
    dataset = PygGraphPropPredDataset(name="ogbg-molhiv")
    split_idx = dataset.get_idx_split()

    ### automatic evaluator. takes dataset name as input
    evaluator = Evaluator("ogbg-molhiv")

    train_loader = DataLoader(dataset[split_idx["train"]], batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers)
    valid_loader = DataLoader(dataset[split_idx["valid"]], batch_size=args.batch_size, shuffle=False,
                              num_workers=args.num_workers)
    test_loader = DataLoader(dataset[split_idx["test"]], batch_size=args.batch_size, shuffle=False,
                             num_workers=args.num_workers)

    model = GNN(gnn_type='gin', num_tasks=dataset.num_tasks, num_layer=args.num_layer, emb_dim=args.emb_dim,
                drop_ratio=args.drop_ratio).to(device)

    optimizer = optim.Adam(model.parameters(), lr=0.001)

    valid_curve = []
    test_curve = []
    train_curve = []

    model_scripted = torch.jit.script(model)  # Export to TorchScript
    model_scripted.save("gin-script.pt")  # Save

    for epoch in range(1, 1 + 1):
        print("=====Epoch {}".format(epoch))
        print('Training...')
        train(model, device, train_loader, optimizer, dataset.task_type)

        print('Evaluating...')
        train_perf = eval(model, device, train_loader, evaluator)
        # valid_perf = eval(model, device, valid_loader, evaluator)
        # test_perf = eval(model, device, test_loader, evaluator)
        # print({'Train': train_perf, 'Validation': valid_perf, 'Test': test_perf})
        # train_curve.append(train_perf[dataset.eval_metric])
        # valid_curve.append(valid_perf[dataset.eval_metric])
        # test_curve.append(test_perf[dataset.eval_metric])

    # if 'classification' in dataset.task_type:
    #     best_val_epoch = np.argmax(np.array(valid_curve))
    #     best_train = max(train_curve)
    # else:
    #     best_val_epoch = np.argmin(np.array(valid_curve))
    #     best_train = min(train_curve)

    print('Finished training!')
    # print('Best validation score: {}'.format(valid_curve[best_val_epoch]))
    # print('Test score: {}'.format(test_curve[best_val_epoch]))


if __name__ == "__main__":
    main()
I am now blocked with the following error.
Expected integer literal for index. ModuleList/Sequential indexing is only supported with integer literals. Enumeration is supported, e.g. 'for index, v in enumerate(self): ...':
  File "/Users/dvlpr/PyCharmProjects/gnn-acceleration-master-thesis/gnn/gin.py", line 146
    for layer in range(self.num_layer):
        h = self.convs[layer](h_list[layer], edge_index, edge_attr)
            ~~~~~~~~~~~~~~~~~ <--- HERE
        h = self.batch_norms[layer](h)
I also faced this problem in the mol_encoder file and was able to solve it in the following way, but now I am not able to make the rest work.
import torch
from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims

full_atom_feature_dims = get_atom_feature_dims()
full_bond_feature_dims = get_bond_feature_dims()


@torch.jit.interface
class ModuleInterface(torch.nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        pass


class AtomEncoder(torch.nn.Module):
    def __init__(self, emb_dim):
        super(AtomEncoder, self).__init__()
        self.atom_embedding_list = torch.nn.ModuleList()
        for i, dim in enumerate(full_atom_feature_dims):
            emb = torch.nn.Embedding(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.atom_embedding_list.append(emb)

    def forward(self, x):
        x_embedding = 0
        for i in range(x.shape[1]):
            submodule: ModuleInterface = self.atom_embedding_list[i]
            x_embedding += submodule.forward(x[:, i])
        return x_embedding


class BondEncoder(torch.nn.Module):
    def __init__(self, emb_dim):
        super(BondEncoder, self).__init__()
        self.bond_embedding_list = torch.nn.ModuleList()
        for i, dim in enumerate(full_bond_feature_dims):
            emb = torch.nn.Embedding(dim, emb_dim)
            torch.nn.init.xavier_uniform_(emb.weight.data)
            self.bond_embedding_list.append(emb)

    def forward(self, edge_attr):
        bond_embedding = 0
        for i in range(edge_attr.shape[1]):
            submodule: ModuleInterface = self.bond_embedding_list[i]
            bond_embedding += submodule.forward(edge_attr[:, i])
        return bond_embedding


if __name__ == '__main__':
    from loader import GraphClassificationPygDataset

    dataset = GraphClassificationPygDataset(name='tox21')
    atom_enc = AtomEncoder(100)
    bond_enc = BondEncoder(100)
    print(atom_enc(dataset[0].x))
    print(bond_enc(dataset[0].edge_attr))
Thank you in advance for your help.
Yeah, it can get a bit tricky to make TorchScript work. In your case, I believe you need to directly iterate over convs and batch_norms:

for conv, norm in zip(self.convs, self.batch_norms):
    ...
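Concretely, the loop in GNN_node.forward would then become something like this (a sketch; the dropout, residual, and JK logic from your code stays unchanged, only the module access changes):

h_list = [self.atom_encoder(x)]
for layer, (conv, norm) in enumerate(zip(self.convs, self.batch_norms)):
    # conv and norm come from iterating the ModuleLists directly,
    # so no non-literal indexing is needed
    h = conv(h_list[layer], edge_index, edge_attr)
    h = norm(h)
    if layer == self.num_layer - 1:
        # remove relu for the last layer
        h = F.dropout(h, self.drop_ratio, training=self.training)
    else:
        h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
    if self.residual:
        h += h_list[layer]
    h_list.append(h)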
Hi @rusty1s, thank you for your help. I have finally saved my model in TorchScript. However, I still have two problems and I hope you can help me with them.
When I try to re-load my script and make inference, an error appears.
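The reloading code is roughly the following (a minimal sketch, not my exact file; model_load is the result of torch.jit.load, and which DataLoader the batch comes from is an assumption):

model_load = torch.jit.load("gin-script.pt")
model_load.eval()
batch = next(iter(test_loader)).to(device)
x, edge_index, edge_attr, batch_f = batch.x, batch.edge_index, batch.edge_attr, batch.batch
print(model_load(x, edge_index, edge_attr, batch_f))

The traceback is: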
Traceback (most recent call last):
  File "/Users/dvlpr/PyCharmProjects/.../gnn/gin.py", line 327, in <module>
    main()
  File "/Users/dvlpr/PyCharmProjects/.../gnn/gin.py", line 319, in main
    print(model_load(x, edge_index, edge_attr, batch_f))
  File "/Users/dvlpr/miniconda3/envs/pytorch/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__.py", line 21, in forward
    _0 = __torch__.torch_geometric.nn.pool.glob.global_mean_pool
    gnn_node = self.gnn_node
    h_node = (gnn_node).forward(x, edge_index, edge_attr, )
              ~~~~~~~~~~~~~~~~~ <--- HERE
    h_graph = _0(h_node, batch, None, )
    graph_pred_linear = self.graph_pred_linear
  File "code/__torch__.py", line 45, in forward
    _2 = uninitialized(Tensor)
    atom_encoder = self.atom_encoder
    _3 = (atom_encoder).forward(x, )
          ~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    ops.prim.RaiseException("AssertionError: ")
    return _2
  File "code/__torch__/ogb/graphproppred/mol_encoder.py", line 15, in forward
    _0 = torch.select(torch.slice(x), 1, i)
    x_embedding0 = torch.add((submodule).forward(_0, ), x_embedding)
    x_embedding = annotate(int, x_embedding0)
                  ~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    return x_embedding
class BondEncoder(Module):
Traceback of TorchScript, original code (most recent call last):
  File "/Users/dvlpr/PyCharmProjects/.../gnn/gin.py", line 99, in forward
    """
    h_node = self.gnn_node(x, edge_index, edge_attr)
             ~~~~~~~~~~~~~ <--- HERE
    h_graph = self.pool(h_node, batch)
  File "/Users/dvlpr/PyCharmProjects/.../gnn/gin.py", line 144, in forward
    ### computing input node embedding
    h_list = [self.atom_encoder(x)]
              ~~~~~~~~~~~~~~~~~ <--- HERE
    for layer, (conv, norm) in enumerate(zip(self.convs, self.batch_norms)):
  File "/Users/dvlpr/miniconda3/envs/pytorch/lib/python3.10/site-packages/ogb/graphproppred/mol_encoder.py", line 28, in forward
    for i in range(x.shape[1]):
        submodule: ModuleInterface = self.atom_embedding_list[i]
        x_embedding += submodule.forward(x[:,i])
        ~~~~~~~~~~~ <--- HERE
    return x_embedding
RuntimeError: Cannot input a tensor of dimension other than 0 as a scalar argument
Looking on the web, I have not found any solution for it. Also, what I am trying to do is lower this model down to MLIR using torch-mlir. However, when I try, an error tells me that somewhere in my model I have a tuple that looks like x = (0, 0), but I do not recognize any tuple like this. Am I missing something?
I am not entirely sure. Can you try to replace += with add_()?

x_embedding.add_(submodule.forward(x[:, i]))
Unfortunately it does not work since x_embedding is an integer.
File "/Users/dvlpr/miniconda3/envs/pytorch/lib/python3.10/site-packages/ogb/graphproppred/mol_encoder.py", line 28, in forward
    x_embedding.add_(submodule.forward(x[:,i]))
AttributeError: 'int' object has no attribute 'add_'
Why is x_embedding an integer? It should be a tensor IMO.
That code is part of the mol_encoder.py file (used for the MolHIV dataset) provided by OGB. The whole code I am using is provided by Open Graph Benchmark; the only change I have made is using the jit interface to create the TorchScript.
I see, I think this should be changed to:

x_embeddings = [self.atom_embedding_list[i](x[:, i]) for i in range(x.shape[1])]
return sum(x_embeddings)
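Combined with the enumeration hint from the earlier "Expected integer literal for index" error, a fully TorchScript-friendly variant could also avoid the non-literal ModuleList indexing altogether (a sketch, not the official OGB code):

def forward(self, x: torch.Tensor) -> torch.Tensor:
    x_embeddings = []
    for i, emb in enumerate(self.atom_embedding_list):
        # enumerating the ModuleList avoids non-literal indexing,
        # which is the one form TorchScript explicitly supports
        x_embeddings.append(emb(x[:, i]))
    # stacking and summing keeps the running value a Tensor, so TorchScript
    # never infers it as int (the root of the add_() failure above)
    return torch.stack(x_embeddings, dim=0).sum(dim=0)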
Hi @rusty1s, thank you for your help. Apparently the problem was the jit interface I was using in the mol_encoder.py file. After some testing, I found that deleting the jit interface and hardcoding the for loops solved the problem. I can now use the scripted model to make inference correctly.
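By hardcoding I mean unrolling the per-feature loop with integer literals, roughly like this (a sketch of the bond-encoder variant; OGB's get_bond_feature_dims() currently has three entries, and the atom encoder is analogous with its nine):

def forward(self, edge_attr):
    # integer-literal indices are the one form of ModuleList indexing
    # that TorchScript accepts without a jit interface
    bond_embedding = self.bond_embedding_list[0](edge_attr[:, 0])
    bond_embedding = bond_embedding + self.bond_embedding_list[1](edge_attr[:, 1])
    bond_embedding = bond_embedding + self.bond_embedding_list[2](edge_attr[:, 2])
    return bond_embedding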
The last thing I would like to ask is whether you have any suggestions about the error I receive when lowering the model down to torch-mlir. I know it is not the purpose of this issue, but I cannot figure out where a tuple like (0, 0) could be in my model. Anyway, I really thank you so much for your help. Feel free to close this issue.
I am not familiar with torch-mlir at all :(
Thank you anyway, I will try asking on the OGB repo to see if someone has experience with it. Thank you again!
🐛 Describe the bug
Hi, I am trying to save the TorchScript of a GNN model created with PyTorch Geometric. The model is the one provided by OGB in its GitHub repository (https://github.com/snap-stanford/ogb/tree/master/examples/graphproppred/mol).
After modifying the OGB mol_encoder class (I used a jit interface to avoid the "Expected integer literal for index" error), marking the re-defined GINConv operator as jittable(), and adding the propagate and forward types, I am now encountering this error and do not know how to solve it.
Thank you in advance for your help. If you do not know about this error but are able to provide a working solution to save the TorchScript of the OGB model I linked, I would be really thankful. Thanks everyone in advance!
Environment

How you installed PyTorch and PyG (conda, pip, source): pip
Any other relevant information (e.g., version of torch-scatter):