snap-stanford / neural-subgraph-learning-GNN

340 stars 64 forks source link

Scores for subgraphs and non-subgraphs #27

Open jcrangel opened 2 years ago

jcrangel commented 2 years ago

Hello, thanks for coding such a great project. I'm trying to score whether a graph is a subgraph or not using the code in alignment.py. I create a subgraph with `graph, neigh = utils.sample_neigh([target], 7)` and score it with `score = model.predict(model(ttarget, tquery))`. For comparison, I also create a non-subgraph using

Gno = nx.Graph()
Gno.add_edges_from([(43, 39), (43, 14),(43,60)]). 

But I get a larger score for the non-subgraph than for the subgraph:

Subgraph score 338.0877380371094
Non subgraph score 487.8809509277344

Am I computing the score correctly? Here's the complete code:


import sys, os
sys.path.insert(0, os.path.abspath(".."))
import argparse
import matplotlib.pyplot as plt
plt.rcParams.update({'font.size': 16})
import random
import networkx as nx
from common import data
from common import models
from common import utils
from subgraph_matching.config import parse_encoder

import torch

def subgraph_score(emb_target, emb_query, scoring_model=None, device=None):
    """Score ``emb_query`` as a candidate subgraph of ``emb_target``.

    Both embeddings are NumPy arrays as produced by
    ``model.emb_model(...).detach().cpu().numpy()``. A single embedding may be
    given either as a 1-D vector (e.g. ``embs[4]``) or as a one-row batch
    (e.g. the full output for a one-graph batch). Both inputs are promoted to
    2-D so the pair always has matching rank.

    Args:
        emb_target: embedding of the candidate super-graph.
        emb_query: embedding of the candidate subgraph.
        scoring_model: object exposing ``__call__(target, query)`` and
            ``predict(...)``. Defaults to the module-level ``model``.
        device: torch device to score on. Defaults to ``utils.get_device()``.

    Returns:
        The single value returned by ``scoring_model.predict`` as a Python
        float.

    Raises:
        ValueError: if ``predict`` returns more than one element
        (``Tensor.item`` requires exactly one).

    NOTE(review): for the upstream ``OrderEmbedder``, ``predict`` appears to
    return an order-violation penalty, so LOWER means "more likely a
    subgraph". Confirm against ``common/models.py`` before treating the
    number as a probability or "higher is better" score.
    """
    if scoring_model is None:
        scoring_model = model  # module-level model loaded by the script
    if device is None:
        device = utils.get_device()
    # Promote 1-D vectors to a batch of one so target and query always have
    # the same rank (embs[i] is 1-D, a one-graph batch output is 2-D).
    ttarget = torch.atleast_2d(torch.from_numpy(emb_target).float()).to(device)
    tquery = torch.atleast_2d(torch.from_numpy(emb_query).float()).to(device)
    # Pure inference: no need to build an autograd graph.
    with torch.no_grad():
        pred = scoring_model.predict(scoring_model(ttarget, tquery))
    return pred.item()

def _is_induced_subgraph(big, small):
    """Return True if some node-induced subgraph of ``big`` is isomorphic to ``small``.

    Node ids are ignored: subgraph matching is up to isomorphism, so a query
    whose node labels never occur in ``big`` can still be a subgraph of it.
    Induced matching is used because the positive query below is built with
    ``graph.subgraph(neigh)``, which is node-induced.
    """
    return nx.algorithms.isomorphism.GraphMatcher(big, small).subgraph_is_isomorphic()


parser = argparse.ArgumentParser()

# Now we load the model and a dataset to analyze embeddings on, here ENZYMES.
utils.parse_optimizer(parser)
parse_encoder(parser)
args = parser.parse_args("")  # use config defaults; ignore sys.argv
args.model_path = os.path.join("..", args.model_path)

# The dataset loaded below is fixed, so report that name. args.dataset is only
# the parser default and is never actually loaded here.
dataset_name = "enzymes"
print("Using dataset {}".format(dataset_name))
model = models.OrderEmbedder(1, args.hidden_dim, args)
model.to(utils.get_device())
model.eval()
model.load_state_dict(torch.load(args.model_path,
    map_location=utils.get_device()))

train, test, task = data.load_dataset(dataset_name)

N_MOTIFS = 10      # number of sampled target neighborhoods
MOTIF_SIZE = 29    # nodes per sampled target neighborhood
QUERY_SIZE = 7     # nodes per sampled (true) subgraph query
TARGET_IDX = 4     # which sampled neighborhood to use as the target

motifs = []
for _ in range(N_MOTIFS):
    graph, neigh = utils.sample_neigh(train, MOTIF_SIZE)
    motifs.append(graph.subgraph(neigh))

batch = utils.batch_nx_graphs(motifs)
with torch.no_grad():
    embs = model.emb_model(batch).detach().cpu().numpy()

max_n_edges = max(len(m.edges) for m in motifs)
max_n_nodes = max(len(m) for m in motifs)

target = motifs[TARGET_IDX]
emb_target = embs[TARGET_IDX]
print('target nodes:', target.nodes)
# nx.draw(target, with_labels=True)

# Positive query: an induced neighborhood sampled from the target itself.
graph, neigh = utils.sample_neigh([target], QUERY_SIZE)
sub_query = graph.subgraph(neigh)
query = utils.batch_nx_graphs([sub_query])
with torch.no_grad():
    emb_query = model.emb_model(query).detach().cpu().numpy()
# nx.draw(sub_query, with_labels=True)

print('subgraph nodes:', sub_query.nodes)
print('Subgraph score', subgraph_score(emb_target, emb_query))
print('Subgraph is a subgraph of target (ground truth):',
      _is_induced_subgraph(target, sub_query))

# "Non subgraph" query. Node ids play no role in subgraph matching: this
# 3-node path is an induced subgraph of any target that has a node with two
# non-adjacent neighbours, so check the ground truth before comparing scores.
Gno = nx.Graph()
Gno.add_edges_from([(50, 55), (56, 55)])
# nx.draw(Gno, with_labels=True)
query = utils.batch_nx_graphs([Gno])
with torch.no_grad():
    emb_query = model.emb_model(query).detach().cpu().numpy()
print('Non subgraph nodes:', Gno.nodes)
print('Non subgraph score', subgraph_score(emb_target, emb_query))
print('Non subgraph is actually a subgraph of target (ground truth):',
      _is_induced_subgraph(target, Gno))