MKLab-ITI / JGNN

A Fast Graph Neural Network Library written in Native Java
Apache License 2.0

Automate training for graph classification #11

Status: closed by maniospas 2 months ago

maniospas commented 3 months ago

Let ModelTraining or a similar class support graph classification tasks to streamline training.

maniospas commented 3 months ago

Progress: began separating NodeClassification, as a training strategy, from the base ModelTraining class. As of v1.3.28, both classes have been moved under the adhoc package.
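This separation can be pictured as a template-method design. A base class owns the generic epoch loop, and a strategy subclass decides how batches are formed for its task. The following is a minimal sketch under that assumption. `BaseTraining` and `NodeClassificationStrategy` are hypothetical names, not JGNN's classes. `trainOnBatch` only counts samples where a real trainer would run forward and backward passes.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch, NOT JGNN's API: the base class owns the generic loop,
// while each training strategy supplies task-specific batching.
abstract class BaseTraining {
    protected int epochs = 1;

    void setEpochs(int epochs) {
        this.epochs = epochs;
    }

    // Task-specific hook: each strategy decides how samples are grouped.
    protected abstract List<List<Integer>> createBatches();

    // Task-specific hook: process one batch of sample indices.
    protected abstract void trainOnBatch(List<Integer> batch);

    // Generic loop shared by every strategy.
    final void train() {
        for (int epoch = 0; epoch < epochs; epoch++)
            for (List<Integer> batch : createBatches())
                trainOnBatch(batch);
    }
}

// Node classification strategy: batches are disjoint subsets of node indices.
class NodeClassificationStrategy extends BaseTraining {
    private final int numNodes;
    private final int numBatches;
    int samplesSeen = 0; // stand-in for real training work

    NodeClassificationStrategy(int numNodes, int numBatches) {
        this.numNodes = numNodes;
        this.numBatches = numBatches;
    }

    @Override
    protected List<List<Integer>> createBatches() {
        List<List<Integer>> batches = new ArrayList<>();
        for (int b = 0; b < numBatches; b++)
            batches.add(new ArrayList<>());
        // Round-robin assignment keeps batch sizes within one of each other.
        for (int node = 0; node < numNodes; node++)
            batches.get(node % numBatches).add(node);
        return batches;
    }

    @Override
    protected void trainOnBatch(List<Integer> batch) {
        samplesSeen += batch.size(); // placeholder for forward/backward passes
    }
}
```

For example, `new NodeClassificationStrategy(10, 3)` with `setEpochs(2)` and `train()` visits each of the 10 nodes once per epoch. A graph classification strategy would only need to override `createBatches` to group graph indices instead.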

maniospas commented 2 months ago

In the final scheme, trainers are passed all necessary data and internally override the data-setting and batch-generation methods. The resulting interface automates the training of attributed graph functions as follows:

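// dtrain: training dataset exposing graphs, node features and graph labels (defined elsewhere)
ModelTraining trainer = new AGFTraining()  // for any attributed graph function
        .setGraphs(dtrain.graphs)
        .setNodeFeatures(dtrain.features)
        .setGraphLabels(dtrain.labels)
        .setValidationSplit(0.2)  // hold out a 0.2 fraction of the graphs for validation
        .setEpochs(300)
        .setOptimizer(new Adam(0.001))
        .setLoss(new CategoricalCrossEntropy())
        //.setNumBatches(10)  // optional: split training data into 10 batches
        //.setParallelizedStochasticGradientDescent(true)  // optional: parallelize SGD across batches
        .setValidationLoss(new VerboseLoss(new CategoricalCrossEntropy(), new Accuracy()));  // validation loss that also reports accuracy

// builder: assumed to be a ModelBuilder defining the architecture (defined elsewhere)
Model model = builder.getModel()
        .init(new XavierNormal())  // Xavier-normal parameter initialization
        .train(trainer);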
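The idea that a trainer receives all data up front and generates its own validation split and batches can be illustrated without JGNN. The following is a minimal sketch assuming graphs are shuffled with a fixed seed, a leading fraction is held out for validation, and the rest is dealt into training batches. `GraphTrainerSketch` is a hypothetical class that only mirrors the setter names above. It is not how AGFTraining is implemented, and graphs are stand-in strings.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical sketch, not JGNN's AGFTraining: data arrives through fluent
// setters, and the trainer derives the validation split and batches itself.
class GraphTrainerSketch {
    private List<String> graphs = new ArrayList<>(); // stand-ins for graph objects
    private double validationSplit = 0.0;
    private int numBatches = 1;
    private final long seed = 0; // fixed seed keeps the split identical across calls

    GraphTrainerSketch setGraphs(List<String> graphs) {
        this.graphs = new ArrayList<>(graphs);
        return this;
    }

    GraphTrainerSketch setValidationSplit(double fraction) {
        if (fraction < 0 || fraction >= 1)
            throw new IllegalArgumentException("split must be in [0, 1)");
        this.validationSplit = fraction;
        return this;
    }

    GraphTrainerSketch setNumBatches(int numBatches) {
        this.numBatches = numBatches;
        return this;
    }

    // Deterministically shuffled graph indices.
    private List<Integer> shuffledIndices() {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < graphs.size(); i++)
            indices.add(i);
        Collections.shuffle(indices, new Random(seed));
        return indices;
    }

    int validationSize() {
        return (int) Math.round(graphs.size() * validationSplit);
    }

    // The first validationSize() shuffled graphs are held out.
    List<Integer> validationIndices() {
        return shuffledIndices().subList(0, validationSize());
    }

    // The remaining graphs are dealt round-robin into numBatches batches.
    List<List<Integer>> trainingBatches() {
        List<Integer> indices = shuffledIndices();
        List<Integer> train = indices.subList(validationSize(), indices.size());
        List<List<Integer>> batches = new ArrayList<>();
        for (int b = 0; b < numBatches; b++)
            batches.add(new ArrayList<>());
        for (int i = 0; i < train.size(); i++)
            batches.get(i % numBatches).add(train.get(i));
        return batches;
    }
}
```

With 10 graphs, `setValidationSplit(0.2)` and `setNumBatches(4)`, this sketch holds out 2 graphs and produces 4 disjoint training batches of 2 graphs each. Keeping the split inside the trainer means callers never handle index bookkeeping themselves, which is the convenience the issue aims for.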