DeepGraphLearning / graphvite

GraphVite: A General and High-performance Graph Embedding System
https://graphvite.io
Apache License 2.0

app.solver.model isn't loaded correctly #72

Open qiangxinglin opened 4 years ago

qiangxinglin commented 4 years ago

Hi,

I trained a GraphApplication and saved the model to disk. When I load the model back, app.solver.model is empty, so the link prediction task raises the following exception:

----> 1 app.link_prediction(file_name='test.txt')

/opt/conda/lib/python3.7/site-packages/graphvite/application/application.py in link_prediction(self, H, T, Y, file_name, filter_H, filter_T, filter_file)
433         vertex_embeddings = self.solver.vertex_embeddings
434         context_embeddings = self.solver.context_embeddings
--> 435         model = LinkPredictor(self.solver.model, vertex_embeddings, context_embeddings)
436         model = model.cuda()
437 

/opt/conda/lib/python3.7/site-packages/graphvite/application/network.py in __init__(self, score_function, *embeddings, **kwargs)
 52             self.score_function = score_function
 53         else:
---> 54             self.score_function = getattr(LinkPredictor, score_function)
 55         self.kwargs = kwargs
 56         self.embeddings = nn.ModuleList()

AttributeError: type object 'LinkPredictor' has no attribute ''
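The error comes from the `else` branch in `LinkPredictor.__init__` shown above: when the score function is not already a callable, its name is looked up on the class with `getattr`, so an empty model name turns into `getattr(LinkPredictor, '')`. Below is a minimal stand-alone sketch of that lookup. The class is a hypothetical stub named `LinkPredictor` only so the message matches; the `callable()` check (the condition is hidden in the traceback) and the `dot_product` score function are assumptions, not graphvite's actual code.

```python
class LinkPredictor:
    """Hypothetical stub mirroring the name lookup in graphvite's LinkPredictor.__init__."""

    @staticmethod
    def dot_product(head, tail):
        # Hypothetical placeholder score function.
        return sum(h * t for h, t in zip(head, tail))

    def __init__(self, score_function):
        if callable(score_function):  # assumed condition; not visible in the traceback
            self.score_function = score_function
        else:
            # Resolve the score function by name on the class.
            self.score_function = getattr(LinkPredictor, score_function)


def lookup_error_message(model_name):
    """Return the AttributeError message for a model name, or None if the lookup succeeds."""
    try:
        LinkPredictor(model_name)
    except AttributeError as error:
        return str(error)
    return None


print(lookup_error_message("dot_product"))  # None: the name exists on the class
print(lookup_error_message(""))  # an empty name reproduces the message from the traceback
```

In other words, the exception is only a symptom: anything that leaves the model name empty will fail at this lookup.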

I suspect the problem is in GraphApplication.set_parameters, which copies the vertex and context embeddings but not the solver's model:

def set_parameters(self, model):
    mapping = self.get_mapping(self.graph.id2name, model.graph.name2id)
    self.solver.vertex_embeddings[:] = model.solver.vertex_embeddings[mapping]
    self.solver.context_embeddings[:] = model.solver.context_embeddings[mapping]
KiddoZhu commented 4 years ago

Hi. Thanks for pointing out that bug.

app.solver.model is only set when you call app.train(). As a workaround, call app.train() with num_epoch=0 and specify the other hyperparameters (e.g., model) as you want. This won't touch the learned parameters.
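To see why this workaround helps, here is a toy, pure-Python model of the behavior described in this thread: the solver's model name is set only by `train()`, `set_parameters` copies the embeddings only, and zero epochs leave the embeddings untouched. `ToySolver`, `ToyApplication`, and their methods are hypothetical simplifications (the vertex-name mapping is omitted, and `"LINE"` is just an example model name), not graphvite classes; the real solver is a compiled extension whose internals may differ.

```python
class ToySolver:
    """Hypothetical stand-in for a graphvite solver."""

    def __init__(self, num_vertex, dim):
        self.vertex_embeddings = [[0.0] * dim for _ in range(num_vertex)]
        self.context_embeddings = [[0.0] * dim for _ in range(num_vertex)]
        self.model = ""  # stays empty until train() is called

    def train(self, model, num_epoch):
        # The model name is recorded on every call, even with zero epochs.
        self.model = model
        for _ in range(num_epoch):
            # Stand-in for one epoch of embedding updates.
            for row in self.vertex_embeddings:
                row[0] += 1.0


class ToyApplication:
    """Hypothetical stand-in for a GraphApplication."""

    def __init__(self, num_vertex=3, dim=2):
        self.solver = ToySolver(num_vertex, dim)

    def train(self, model="LINE", num_epoch=1):
        self.solver.train(model, num_epoch)

    def set_parameters(self, other):
        # Like the quoted set_parameters: copy the embeddings, but not solver.model.
        self.solver.vertex_embeddings[:] = [row[:] for row in other.solver.vertex_embeddings]
        self.solver.context_embeddings[:] = [row[:] for row in other.solver.context_embeddings]


trained = ToyApplication()
trained.train(model="LINE", num_epoch=5)

restored = ToyApplication()
restored.set_parameters(trained)
print(repr(restored.solver.model))  # '' : the empty name that breaks link prediction

restored.train(model="LINE", num_epoch=0)  # the suggested workaround
print(restored.solver.model)  # LINE
print(restored.solver.vertex_embeddings == trained.solver.vertex_embeddings)  # True
```

With the real library, the same idea is to restore the saved parameters, then call app.train() with num_epoch=0 and the hyperparameters used originally before calling app.link_prediction(); check the graphvite documentation for the exact signatures.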