Closed rusty1s closed 2 years ago
Cool :) I thought about extending the "Loading CSV" tutorial to showcase how one would apply a GNN on this one. I already started integrating the random link split behaviour, see here. The next task would be to create a heterogeneous GNN model, and train it in a supervised fashion against ratings in the training set. WDYT?
Ah yes, that is indeed a nice start already! :)
This example feels a bit different from "typical" link prediction problems, in that I don't think you can really use a contrastive loss with negative edges: a missing edge in this graph just means we want to predict what rating there should be for each edge of type ('user', 'rates', 'movie'). So we don't want to train the algorithm to separate "likely" from "unlikely" edges. I think that's fine though; I see this as an edge classification problem, and it seems a relevant example, reminiscent of predicting how users would rate products on online stores.
Can you check that my plan for this fits your idea about what you would like?
- A GNN with 2 layers
- A DistMult decoder to get scores for each edge label (from 0 to 5), so here I am - kind of - treating the edge labels as six different edge types. If DistMult doesn't do the trick, I could try a bilinear, RESCAL-type decoder.
- A softmax on the 6 scores I got from the decoder per training supervision edge, to get something that "looks like" class probabilities; pick the class with the highest probability and use a loss suitable for a multiclass problem statement (such as torch.nn.NLLLoss).

I'm happy to try this approach, I just wanted to check if you already had some kind of plan that is quite different from mine, so we don't waste too much time. Thanks! :)
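To make the decoder idea concrete, here is a minimal sketch of the DistMult-style scoring I have in mind: one learnable relation vector per rating class, a softmax over the 6 class scores per supervision edge, and an NLL loss. All shapes and names (`hidden`, `distmult_scores`, the fake encoder outputs) are illustrative assumptions, not the final PR code.

```python
import torch
import torch.nn.functional as F

num_classes = 6   # ratings 0..5, treated as 6 "edge types"
hidden = 16       # embedding size assumed to come from the GNN encoder

# One learnable relation vector per rating class (DistMult uses a
# diagonal relation matrix, i.e. a vector, per relation).
rel = torch.nn.Parameter(torch.randn(num_classes, hidden))

def distmult_scores(z_user, z_movie):
    # DistMult: score_r(u, m) = sum_d z_u[d] * R_r[d] * z_m[d]
    # Broadcasting over classes gives one score per rating class
    # for every supervision edge: shape [num_edges, num_classes].
    return (z_user.unsqueeze(1) * rel * z_movie.unsqueeze(1)).sum(dim=-1)

# Stand-in encoder output for 4 supervision edges (illustrative only).
z_user = torch.randn(4, hidden)
z_movie = torch.randn(4, hidden)
target = torch.tensor([0, 3, 5, 2])          # ground-truth ratings

scores = distmult_scores(z_user, z_movie)    # [4, 6]
log_probs = F.log_softmax(scores, dim=-1)    # softmax over the 6 classes
loss = F.nll_loss(log_probs, target)         # multiclass NLL loss
pred = log_probs.argmax(dim=-1)              # predicted rating per edge
```

A bilinear (RESCAL-type) decoder would just replace the per-class vector `rel` with a full `hidden x hidden` matrix per class.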
You are right, it's more of an edge classification problem in which no negative sampling is needed. Nonetheless, the model should be able to predict the ratings of unknown users/movies. Your approach sounds correct, and matches with the one I have in mind. Let me know how it goes :)
Alright, just a quick status update: I've put something together and it is learning, but the performance is not amazing, so I want to improve it a bit. The average test accuracy reaches about 40% after 400 epochs, which I guess is better than random for a 6-class problem, but there are a few things I want to try to make it better before sharing it.
I'm afraid I only have time to do this in my evenings so progress is perhaps a bit slow. Hope that's ok.
Sure, please feel free to submit a PR early, so I can help with it :)
Just a quick message to let you know I’m on holiday. I’ll get back to this again next week.
No worries, we are not in a rush. Enjoy your free days :)
Hello, I am now also studying how to make link predictions on heterogeneous graphs. I would like to ask how your project is progressing. Thank you very much for your reply.
@liyongkang123: Many apologies for the slow reply! I was on holiday and when I got back work was busy. You can see the progress I've last shared here, but I'm not finished yet. I've got a generalisation issue (i.e. the model works well on the training set but performs poorly on the validation and test sets), and have just started working again on fixing that.
Is there anything specific that you want to find out about?
Hello! I'm thinking of picking this one up, if that's helpful. Did you have a specific dataset to use in mind already?