google / neural-tangents

Fast and Easy Infinite Neural Networks in Python
https://iclr.cc/virtual_2020/poster_SklD9yrFPS.html
Apache License 2.0

How to analyze transductive node classification #151

Open joneswong opened 2 years ago

joneswong commented 2 years ago

Transductive learning is very common, e.g., node classification on Cora, Citeseer, and Pubmed. I intend to analyze GNN models, e.g., a 2-layer GCN, in the NTK regime.

Since I have used neural_tangents.stax.Aggregate to analyze graph-level tasks, I think it can be generalized to such node-level tasks as follows:

  1. Extract the 2-hop neighborhood of each node, pad it to a common number of nodes, and treat the resulting subgraph as an instance in which the target node has index 0.
  2. Apply DotGeneral before the GlobalSumPool layer to mask out the representations of all nodes other than node 0.
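A minimal NumPy sketch of the two steps above, assuming the whole graph is given as a dense adjacency matrix plus a node-feature matrix. two_hop_instance and target_readout are hypothetical helpers, not neural_tangents APIs, and the readout uses a plain mask-and-sum in place of stax.DotGeneral followed by stax.GlobalSumPool. For a 2-layer GCN the output at a node depends only on nodes within 2 hops, so the induced subgraph is enough; if the model uses a normalized adjacency, normalize on the full graph first and pass that matrix in, so the sliced edge weights keep the full-graph degrees.

```python
import numpy as np


def two_hop_instance(adj, features, target, max_nodes):
    """Build one padded instance from the 2-hop neighborhood of `target`.

    Node 0 of the instance is `target`; padded slots get zero adjacency,
    zero features and node id -1.
    """
    a = adj != 0
    one_hop = np.flatnonzero(a[target])
    # Nodes adjacent to any 1-hop neighbor (empty if `target` is isolated).
    two_hop = np.flatnonzero(a[one_hop].any(axis=0))
    neighbors = np.setdiff1d(np.union1d(one_hop, two_hop), [target])
    nodes = np.concatenate([[target], neighbors]).astype(int)
    k = nodes.size
    if k > max_nodes:
        raise ValueError(f"neighborhood has {k} nodes, more than max_nodes={max_nodes}")

    sub_adj = np.zeros((max_nodes, max_nodes), dtype=adj.dtype)
    sub_adj[:k, :k] = adj[np.ix_(nodes, nodes)]  # induced subgraph, target first
    sub_x = np.zeros((max_nodes, features.shape[1]), dtype=features.dtype)
    sub_x[:k] = features[nodes]
    node_ids = np.full(max_nodes, -1, dtype=int)
    node_ids[:k] = nodes
    return sub_adj, sub_x, node_ids


def target_readout(node_repr):
    """Zero out every node except index 0, then sum over the node axis."""
    mask = np.zeros((node_repr.shape[0], 1), dtype=node_repr.dtype)
    mask[0] = 1.0
    return (mask * node_repr).sum(axis=0)


# Usage: path graph 0-1-2-3-4 with 2-dimensional node features.
adj = np.zeros((5, 5))
for i in range(4):
    adj[i, i + 1] = adj[i + 1, i] = 1.0
x = np.arange(10, dtype=float).reshape(5, 2)
sub_adj, sub_x, ids = two_hop_instance(adj, x, target=2, max_nodes=6)
print(ids)                    # [ 2  0  1  3  4 -1]
print(target_readout(sub_x))  # [4. 5.]
```

Padded slots have all-zero adjacency rows, so a sum-style aggregation over the padded instance ignores them.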

Am I wrong? Could you give me a hand? Thanks!

sschoenholz commented 2 years ago

Hi there! Sorry for the delay. I'm not totally familiar with transductive learning in the GP setting. I will note that after the stax.Aggregate layer the kernel will be of shape (batch_size, batch_size, n_nodes, n_nodes). If I understand your setting correctly, the batch_size is not relevant because you have a single large instance. In that case the kernel will effectively have shape (n_nodes, n_nodes). I would think that selecting subgraphs and masking would be analogous to selecting submatrices. Perhaps you could rephrase the transductive task as standard GP inference, where $K_{train,train}$ is the kernel formed from the observed nodes and $K_{train,test}$ is the kernel between the observed nodes and the nodes you would like to perform inference on. If you did that, you could use the NT predict functions to perform the inference.
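A sketch of the submatrix route under stated assumptions: gcn_nngp_kernel is a hand-written infinite-width NNGP recursion for a bias-free, two-layer ReLU GCN (the NNGP, not the NTK, and not a neural_tangents function), used here only as a stand-in for the (n_nodes, n_nodes) kernel one would obtain from a stax model containing stax.Aggregate. The train and test blocks are sliced out of that single kernel and plugged into the standard GP posterior mean $K_{test,train} (K_{train,train} + \epsilon I)^{-1} Y_{train}$. In practice the routines in nt.predict (for example gp_inference) would replace the hypothetical gp_posterior_mean helper; the relative jitter diag_reg is an assumed value for numerical stability.

```python
import numpy as np


def normalize_adj(adj):
    """Symmetric GCN normalization D^{-1/2} (A + I) D^{-1/2}."""
    a = adj + np.eye(adj.shape[0])
    d_inv_sqrt = 1.0 / np.sqrt(a.sum(axis=1))
    return a * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]


def relu_expectation(cov):
    """E[relu(u) relu(u)^T] for u ~ N(0, cov): the degree-1 arc-cosine kernel."""
    std = np.sqrt(np.diag(cov))
    norm = np.outer(std, std)
    cos = np.clip(cov / np.maximum(norm, 1e-12), -1.0, 1.0)
    theta = np.arccos(cos)
    return norm * (np.sin(theta) + (np.pi - theta) * cos) / (2.0 * np.pi)


def gcn_nngp_kernel(adj_norm, x, w_std=1.0):
    """Hypothetical NNGP kernel of f(X) = A relu(A X W1) W2 (no biases) at
    infinite width, weights ~ N(0, w_std^2 / fan_in). Shape (n_nodes, n_nodes)."""
    k1 = w_std**2 * adj_norm @ (x @ x.T / x.shape[1]) @ adj_norm.T
    return w_std**2 * adj_norm @ relu_expectation(k1) @ adj_norm.T


def gp_posterior_mean(k, y_train, train_idx, test_idx, diag_reg=1e-4):
    """GP posterior mean on test nodes using submatrices of one full kernel."""
    k_train_train = k[np.ix_(train_idx, train_idx)]
    k_test_train = k[np.ix_(test_idx, train_idx)]
    # Jitter scaled by the mean diagonal of K_train,train (assumed choice).
    reg = diag_reg * np.trace(k_train_train) / len(train_idx)
    alpha = np.linalg.solve(k_train_train + reg * np.eye(len(train_idx)), y_train)
    return k_test_train @ alpha


# Usage: a 6-node path graph with one-hot node features; nodes 0 and 5 are labeled.
adj = np.zeros((6, 6))
for i in range(5):
    adj[i, i + 1] = adj[i + 1, i] = 1.0
k = gcn_nngp_kernel(normalize_adj(adj), np.eye(6))
train_idx, test_idx = [0, 5], [1, 2, 3, 4]
y_train = np.array([[1.0, 0.0], [0.0, 1.0]])  # one-hot labels
mean = gp_posterior_mean(k, y_train, train_idx, test_idx)
print(mean.argmax(axis=1))  # predicted class per test node
```

The whole graph is encoded once as a single kernel, so every labeled and unlabeled node shares the same message-passing structure; only the row and column indices decide which nodes act as training points.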

Let me know if this makes sense to you or if I have misunderstood something. Your setting sounds interesting and it would be fun to get it working!

yCobanoglu commented 1 year ago

I have worked on Graph Neural Gaussian Processes and Graph Neural Tangent Kernels for node classification/regression (transductive learning) using the Neural Tangents library.