tinca opened 4 months ago
Hello! Thanks for the suggestion. Opened #8 to track progress on distributing with source code.
With respect to explanations:
The SortPooling example showcases graph classification. This means that you have multiple graphs and want to learn to predict a label for each of them. Each graph has its own NxN adjacency matrix and its own NxF matrix of node features (where N is the number of nodes), but also a vector of length D that holds the one-hot encoding of its label. D is the number of labels.
For example, [0, 1, 0, 0] means that we have D=4 types of graphs (let's call them label0, label1, label2, label3), and this particular graph is of type label1. If you make predictions of length D, you would know which label your system predicts by taking the argmax. For example, predictions [0.1, 0.3, 0.8, 0.5] are of label2.
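To make that concrete, here is a minimal plain-Java sketch of the argmax readout (just the arithmetic on the numbers above, no JGNN API involved):

```java
public class ArgmaxExample {
    public static void main(String[] args) {
        double[] oneHotLabel = {0.0, 1.0, 0.0, 0.0};   // ground truth: label1
        double[] predictions = {0.1, 0.3, 0.8, 0.5};   // model output

        // argmax: the position of the largest prediction is the predicted label
        int predicted = 0;
        for (int i = 1; i < predictions.length; i++)
            if (predictions[i] > predictions[predicted])
                predicted = i;
        System.out.println("predicted label" + predicted); // prints: predicted label2
    }
}
```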
The whole point of the GNN architecture for graph classification is to take a graph and extract an invariant prediction vector, where invariance means that shuffling the node indices everywhere in the same way does not affect the outcome (e.g., the line graph A--B should be treated the same as the line graph B--A). This is achieved by first applying some equivariant predictive mechanism (which effectively means that the i-th row of the result corresponds to the i-th node), and then applying a trick to make it invariant.
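In standard GNN notation (my addition, not from the library's docs), with adjacency matrix A, node features X, and any permutation matrix P:

```latex
\text{invariant:}\quad   f(PAP^\top,\, PX) = f(A, X)        % graph-level output unchanged
\text{equivariant:}\quad g(PAP^\top,\, PX) = P\, g(A, X)    % output rows permuted the same way as the nodes
```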
The simplest equivariant mechanism would be to apply the same dense neural layer to all nodes, but GNNs are more complicated than MLPs in that they also use the adjacency matrix to force every node to take into account what is being learned for its neighbors and, through them, the whole graph.
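As a rough plain-Java sketch of that idea (one propagation step H' = A·H·W on raw arrays; this is purely illustrative and not the library's actual implementation):

```java
// One message-passing step: each node's new embedding mixes its neighbors'
// embeddings (via the adjacency matrix A) through a shared dense layer W.
public class PropagationSketch {
    // returns A * H * W, where A is NxN, H is NxF, W is FxD
    static double[][] propagate(double[][] A, double[][] H, double[][] W) {
        int N = A.length, F = H[0].length, D = W[0].length;
        double[][] AH = new double[N][F];
        for (int i = 0; i < N; i++)            // aggregate neighbors: AH = A * H
            for (int j = 0; j < N; j++)
                for (int f = 0; f < F; f++)
                    AH[i][f] += A[i][j] * H[j][f];
        double[][] out = new double[N][D];     // shared dense layer: out = AH * W
        for (int i = 0; i < N; i++)
            for (int d = 0; d < D; d++)
                for (int f = 0; f < F; f++)
                    out[i][d] += AH[i][f] * W[f][d];
        return out;  // row i still corresponds to node i (equivariance)
    }
}
```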
SortPooling is one mechanism that can perform this conversion from equivariance to invariance. You can add more learnable neural layers after converting the equivariant node embeddings into an invariant representation for the whole graph. Note that everything is differentiable, so backpropagation can actually learn the architecture's parameters by comparing the final invariant output (a tensor of length D) with the one-hot encoding of the label, also of length D. The comparison of these two is made through the loss function.
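The core trick, as described in the original SortPooling paper, is to sort the node embedding rows by their last channel and keep exactly k rows, so every graph yields a fixed-size tensor regardless of node numbering. A simplified plain-Java illustration (the paper additionally breaks ties using earlier channels, which is omitted here):

```java
import java.util.Arrays;

public class SortPoolingSketch {
    // Sorts node embeddings (NxF) by their last channel, descending,
    // then truncates or zero-pads to exactly k rows. Because rows are
    // re-ordered by value rather than by node index, the result no
    // longer depends on how the nodes were numbered (invariance).
    static double[][] sortPool(double[][] embeddings, int k) {
        int F = embeddings[0].length;
        double[][] sorted = embeddings.clone();
        Arrays.sort(sorted, (a, b) -> Double.compare(b[F - 1], a[F - 1]));
        double[][] pooled = new double[k][F];  // zero-initialized (padding)
        for (int i = 0; i < Math.min(k, sorted.length); i++)
            pooled[i] = sorted[i];
        return pooled;  // fixed kxF output, ready for further dense layers
    }
}
```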
I guess the part you are confused about is the command labels.add(new DenseTensor(2).put(0, 1.0));
in the source code. This means that we are creating a DenseTensor (that is, a vector) of two elements and put the value 1.0 in position 0. If you print it, it will show something like [1.0, 0.0]. The next command new DenseTensor(2).put(1, 1.0)
sets the value 1.0 in position 1. If you print it, it will show [0.0, 1.0]. Basically, D=2 here and our two labels are whether the graph is of type ff or ft. Notice that ft is a certain type of graph and ff is a permutation that changes something in the node features (example obtained from #2).
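Putting the two together, the label list is built like this (the import paths and list type are my assumption of the 1.x layout, and which position maps to which class is illustrative; adjust to your version):

```java
import java.util.ArrayList;
import java.util.List;
// package paths are assumed from the JGNN 1.x layout; adjust if they differ
import mklab.JGNN.core.Tensor;
import mklab.JGNN.core.tensor.DenseTensor;

public class LabelSetup {
    public static void main(String[] args) {
        List<Tensor> labels = new ArrayList<>();
        labels.add(new DenseTensor(2).put(0, 1.0)); // prints as [1.0, 0.0]; one class (say, ff)
        labels.add(new DenseTensor(2).put(1, 1.0)); // prints as [0.0, 1.0]; the other class (ft)
        System.out.println(labels.get(0) + " " + labels.get(1));
    }
}
```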
The reason our dataset is twice its given size is that, for every positive example ft, we created a negative example ff. Positives and negatives are basically our two classes; that is, we want to learn to differentiate between ft and ff. You could also create more negative examples per positive to improve the learning task (given that positive samples all share the same property and we want to be able to pinpoint it), or adopt adversarial learning strategies if you had only positive examples (this is a huge subject, so I would rather not address it here).
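As a hypothetical sketch of that augmentation (my guess at one such corruption, not the example's actual code; `shuffleRows` is an illustrative helper):

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

public class NegativeSampling {
    // Builds a negative (ff) example from a positive (ft) one by randomly
    // permuting the rows of its node feature matrix while keeping the
    // adjacency matrix intact, so structure and features no longer match.
    static double[][] shuffleRows(double[][] features, Random rng) {
        List<double[]> rows = new ArrayList<>(Arrays.asList(features));
        Collections.shuffle(rows, rng);
        return rows.toArray(new double[0][]);
    }
}
```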
Honestly, if this is not your intended use case, I would recommend starting from node classification instead (and not graph classification). In that case, look at architectures such as GCN and APPNP in examples/nodeClassification. Node classification is easier to understand in toy settings.
You may also want to read parts of the tutorial, though the same issues as in #6 will occur because the tutorial also needs to be updated to the 1.2 API. You can also skip to the part of the tutorial that covers SortPooling here.
Hi maniospas,
Just back from holidays, hence my late answer. Thank you for your detailed reply. What you wrote makes an excellent README candidate for the examples! And thank you for taking up my request on the source artifacts.
While away, I more or less figured out how labels are supposed to work, but it is always good to be reassured. I do need graph classification, as I've been thinking about relating chemical structures to their physical/chemical properties (that is, after training, predicting the latter from the structures).
I did read the tutorials, but admittedly in a bit of a hurry, so it's time to re-read them more carefully.
Is this Issues section an appropriate place for this type of conversation, or do you have a forum somewhere else?
Hi and thanks for the feedback. I was also on holidays so I am a bit late with the reply.
I am currently writing an updated tutorial that covers these points, plus other weak spots in the existing explanations. It should be finished by next week.
Opening issues is the preferred way right now; in fact, requests for clarification are especially welcome.
Hi, after successfully running SortPooling, I am experimenting with my own data set while keeping the architecture, and I have a couple of questions:
The above question certainly shows that my basic understanding is incomplete, so it is no wonder that I am getting exceptions when running the architecture with my data, such as:
java.lang.IllegalArgumentException: Mismatched matrix sizes between SparseMatrix (13,13) 12/169 entries and DenseMatrix (features 1, hidden 8)
And finally a little request: it would be very useful if the source artifact were also published to JitPack. That way the IDE can show the JavaDocs/sources as soon as they are needed. Thank you in advance.