A native Java library for Graph Neural Networks.
GCNBuilder
to FastBuilder
Graph Neural Networks (GNNs) are getting more and more popular; they can reason from relational information and perform inference from small datasets. JGNN implements GNNs for native Java applications to supports cross-platform machine learning, such as on Android, without dedicated hardware or firmware.
To install the latest working version, you can include it as a Maven or Gradle dependency
by following the instructions of its JitPack distribution:
As a good out-of-the-box graph neural network for node classification, you can try the following architecture. To create your own graph adjacency matrices, node features, and labels look at the data creation tutorial.
Dataset dataset = new Cora();
Matrix adjacency = dataset.graph().setMainDiagonal(1).setToSymmetricNormalization();
Matrix nodeFeatures = dataset.features();
Matrix nodeLabels = dataset.labels();
Slice nodes = dataset.samples().getSlice().shuffle(100);
long numClasses = nodeLabels.getCols();
ModelBuilder modelBuilder = new FastBuilder(adjacency, nodeFeatures)
.config("reg", 0.005)
.config("hidden", 16)
.config("classes", numClasses)
.layer("h{l+1}=relu(h{l}@matrix(features, hidden, reg)+vector(hidden))")
.layer("h{l+1}=h{l}@matrix(hidden, classes)+vector(classes)")
.rememberAs("0")
.constant("a", 0.9)
.layerRepeat("h{l+1} = a*(dropout(A, 0.5)@h{l})+(1-a)*h{0}", 10)
.classify();
ModelTraining trainer = new ModelTraining()
.setOptimizer(new Adam(0.01))
.setEpochs(300)
.setPatience(100)
.setLoss(new CategoricalCrossEntropy())
.setVerbose(true)
.setValidationLoss(new CategoricalCrossEntropy());
Model model = modelBuilder.getModel()
.init(new XavierNormal())
.train(trainer,
nodes.samplesAsFeatures(),
nodeLabels,
nodes.range(0, 0.6),
nodes.range(0.6, 0.8));
Matrix output = model.predict(nodes.samplesAsFeatures()).get(0).cast(Matrix.class);
double acc = 0;
for(Long node : nodes.range(0.8, 1)) {
Matrix trueLabels = dataset.labels().accessRow(node).asRow();
Tensor nodeOutput = output.accessRow(node).asRow();
acc += nodeOutput.argmax()==trueLabels.argmax()?1:0;
}
System.out.println("Acc\t "+acc/nodes.range(0.8, 1).size());
Feel free to contribute in any way, for example through the issue tracker. Please check out the contribution guidelines when bringing modifications to the code base.
@article{krasanakis2023101459,
title = {JGNN: Graph Neural Networks on native Java},
journal = {SoftwareX},
volume = {23},
pages = {101459},
year = {2023},
issn = {2352-7110},
doi = {https://doi.org/10.1016/j.softx.2023.101459},
url = {https://www.sciencedirect.com/science/article/pii/S2352711023001553},
author = {Emmanouil Krasanakis and Symeon Papadopoulos and Ioannis Kompatsiaris}
}